Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement CogVideo model #969

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions camel/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .anthropic_model import AnthropicModel
from .azure_openai_model import AzureOpenAIModel
from .base_model import BaseModelBackend
from .cogvideo_model import CogVideoModel
from .gemini_model import GeminiModel
from .groq_model import GroqModel
from .litellm_model import LiteLLMModel
Expand Down Expand Up @@ -53,4 +54,5 @@
'RekaModel',
'SambaModel',
'TogetherAIModel',
'CogVideoModel',
]
90 changes: 90 additions & 0 deletions camel/models/cogvideo_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========

import os
from typing import Any, Dict, Optional

import httpx
import requests

from camel.types import ModelType


class CogVideoModel:
    r"""CogVideo model API backend.

    Asynchronous client for a CogVideo generation service: it posts a text
    prompt plus configuration as JSON and returns the ``video_url`` field
    of the JSON response.
    """

    # Used when neither the `url` argument nor the environment provides one.
    _DEFAULT_URL = "http://localhost:8000/generate"

    def __init__(
        self,
        model_type: ModelType,
        model_config_dict: Dict[str, Any],
        url: Optional[str] = None,
        use_gpu: bool = True,
    ) -> None:
        r"""Constructor for CogVideo backend.

        Reference: https://github.com/THUDM/CogVideo

        Args:
            model_type (ModelType): Model for which backend is created
                such as CogVideoX-2B, CogVideoX-5B, etc.
            model_config_dict (Dict[str, Any]): A dictionary of parameters
                for the model configuration. Its entries are merged into
                every request payload.
            url (Optional[str]): The URL to the model service. When not
                given, the `COGVIDEO_API_BASE_URL` environment variable
                is used, falling back to
                'http://localhost:8000/generate'. (default: :obj:`None`)
            use_gpu (bool): Whether to use GPU for inference.
                (default: :obj:`True`)
        """
        self.model_type = model_type
        self.model_config_dict = model_config_dict
        # Review note: with a non-empty default for `url`, the
        # COGVIDEO_API_BASE_URL environment variable could never be read.
        # Resolution order is now: explicit argument, environment, default.
        self._url = (
            url
            or os.environ.get("COGVIDEO_API_BASE_URL")
            or self._DEFAULT_URL
        )
        self._use_gpu = use_gpu

    def _build_payload(self, prompt: str, **kwargs: Any) -> Dict[str, Any]:
        r"""Assemble the JSON body sent to the model service.

        Later entries win on key collisions: `model_config_dict` overrides
        the base fields, and `kwargs` overrides both.
        """
        return {
            "prompt": prompt,
            # Enum members are not JSON serializable; send the raw value.
            "model_type": getattr(self.model_type, "value", self.model_type),
            "use_gpu": self._use_gpu,
            **self.model_config_dict,
            **kwargs,
        }

    async def run(self, prompt: str, **kwargs: Any) -> str:
        r"""Run the CogVideo model to generate a video from a text prompt.

        Args:
            prompt (str): The text prompt to generate the video.
            **kwargs (Any): Additional arguments for the model request.

        Returns:
            str: The path or URL to the generated video.

        Raises:
            ValueError: If the configured URL is not a string, or the
                service response contains no `video_url`.
            Exception: If the HTTP request fails or the service returns an
                error status.
        """
        data = self._build_payload(prompt, **kwargs)

        if not isinstance(self._url, str):
            raise ValueError("URL should be a string.")

        # NOTE(review): no timeout is configured, so httpx's default
        # applies; video generation may need a longer one - confirm.
        async with httpx.AsyncClient() as client:
            try:
                response = await client.post(self._url, json=data)
                response.raise_for_status()
            except httpx.HTTPError as e:
                # The request is made with httpx, so httpx errors (not
                # `requests` errors) are the ones that can occur here.
                raise Exception("Error during CogVideo API call") from e

            video_url = response.json().get("video_url")
            if not video_url:
                raise ValueError(
                    "No video URL returned by the model service."
                )
            return video_url
24 changes: 24 additions & 0 deletions camel/types/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ class ModelType(Enum):
MISTRAL_MIXTRAL_8x22B = "open-mixtral-8x22b"
MISTRAL_CODESTRAL_MAMBA = "open-codestral-mamba"

# CogVideo Model
COGVIDEO = "cogvideo"
COGVIDEOX_2B = "cogvideox-2b"
COGVIDEOX_5B = "cogvideox-5b"

# Reka models
REKA_CORE = "reka-core"
REKA_FLASH = "reka-flash"
Expand Down Expand Up @@ -209,6 +214,17 @@ def is_gemini(self) -> bool:
return self in {ModelType.GEMINI_1_5_FLASH, ModelType.GEMINI_1_5_PRO}

@property
def is_cogvideo(self) -> bool:
    r"""Returns whether this type of model is a CogVideo model.

    Returns:
        bool: Whether this model type is one of the CogVideo variants.
    """
    return self in {
        ModelType.COGVIDEO,
        ModelType.COGVIDEOX_2B,
        ModelType.COGVIDEOX_5B,
    }

def is_reka(self) -> bool:
r"""Returns whether this type of models is Reka model.

Expand Down Expand Up @@ -395,6 +411,7 @@ class TaskType(Enum):
MULTI_CONDITION_IMAGE_CRAFT = "multi_condition_image_craft"
DEFAULT = "default"
VIDEO_DESCRIPTION = "video_description"
VIDEO_GENERATION = "video_generation"


class VectorDistance(Enum):
Expand Down Expand Up @@ -481,6 +498,8 @@ class ModelPlatformType(Enum):
REKA = "reka"
TOGETHER = "together"
OPENAI_COMPATIBILITY_MODEL = "openai-compatibility-model"
INTERNLM = "internlm"
COGVIDEO = "cogvideo"
SAMBA = "samba-nova"

@property
Expand Down Expand Up @@ -559,6 +578,11 @@ def is_samba(self) -> bool:
r"""Returns whether this platform is Samba Nova."""
return self is ModelPlatformType.SAMBA

@property
def is_cogvideo(self) -> bool:
    r"""Returns whether this platform is CogVideo."""
    # Identity check, matching the sibling `is_samba` property.
    return self is ModelPlatformType.COGVIDEO


class AudioModelType(Enum):
TTS_1 = "tts-1"
Expand Down
44 changes: 44 additions & 0 deletions examples/models/cogvideo_model_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========

import asyncio

from camel.models import CogVideoModel
from camel.types import ModelType


async def main():
    # Settings merged into every request sent to the CogVideo service
    generation_config = {"video_length": 4, "frame_rate": 8}

    # Create the CogVideo backend
    cogvideo = CogVideoModel(
        model_type=ModelType.COGVIDEOX_5B,
        model_config_dict=generation_config,
    )

    # Text prompt for video generation
    text_prompt = "A video of a cat playing with a ball."

    # Ask the service for a video and report either its URL or the failure
    try:
        video_url = await cogvideo.run(prompt=text_prompt)
        print(f"Generated video URL: {video_url}")
    except Exception as err:
        print(f"An error occurred: {err}")


if __name__ == "__main__":
    # Prerequisite: a CogVideo server, local or remote, must be running
    # and expose the generation endpoint. It has to be reachable at the
    # URL that the CogVideoModel instance is configured with.
    asyncio.run(main())
Loading