Skip to content

Commit

Permalink
[FEAT][Test]
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed Apr 16, 2024
1 parent 053627b commit 99b15b2
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 12 deletions.
4 changes: 2 additions & 2 deletions servers/text_to_video/sample_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
"style": "example_style",
"n": 1,
"output_type": "mp4",
"output_path": "example_output_path"
"output_path": "example_output_path",
}

response = requests.post(url, headers=headers, data=json.dumps(data))

print(response.json())
print(response.json())
73 changes: 73 additions & 0 deletions servers/text_to_video/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import torch
from diffusers import (
AnimateDiffPipeline,
EulerDiscreteScheduler,
MotionAdapter,
)
from diffusers.utils import export_to_gif
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Load environment variables from .env file
load_dotenv()


def text_to_video(
task: str,
model_name: str = "ByteDance/AnimateDiff-Lightning",
guidance_scale: float = 1.0,
inference_steps: int = 4,
output_type: str = ".gif",
output_path: str = "animation.gif",
n: int = 1,
length: int = 60,
*args,
**kwargs,
):
"""
Converts a given text task into an animated video.
Args:
task (str): The text task to be converted into a video.
Returns:
str: The path to the exported GIF file.
"""

device = "cuda"
dtype = torch.float16

repo = model_name
ckpt = f"animatediff_lightning_{inference_steps}step_diffusers.safetensors"
base = "emilianJR/epiCRealism" # Choose to your favorite base model.
adapter = MotionAdapter().to(device, dtype)
adapter.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))

pipe = AnimateDiffPipeline.from_pretrained(
base, motion_adapter=adapter, torch_dtype=dtype
).to(device)

pipe.scheduler = EulerDiscreteScheduler.from_config(
pipe.scheduler.config,
timestep_spacing="trailing",
beta_schedule="linear",
)

# outputs = []
# for i in range(n):
# output = pipe(
# prompt=task,
# guidance_scale=guidance_scale,
# num_inference_steps=inference_steps,
# )
# outputs.append(output)
# if output_type == ".gif":
# out = export_to_gif([output], f"{output_path}_{i}.gif")
# else:
# out = export_to_video([output], f"{output_path}_{i}.mp4")
output = pipe(
prompt=task, guidance_scale=guidance_scale, num_inference_steps=inference_steps
)
output = export_to_gif(output.frames[0], output_path)
return output
16 changes: 6 additions & 10 deletions text_to_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@
EulerDiscreteScheduler,
MotionAdapter,
)
from diffusers.utils import export_to_gif, export_to_video
from diffusers.utils import export_to_gif
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from huggingface_hub import hf_hub_download
from loguru import logger
from safetensors.torch import load_file
Expand Down Expand Up @@ -65,12 +64,11 @@ def text_to_video(
base = "emilianJR/epiCRealism" # Choose to your favorite base model.
adapter = MotionAdapter().to(device, dtype)
adapter.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))

pipe = AnimateDiffPipeline.from_pretrained(
base, motion_adapter=adapter, torch_dtype=dtype
).to(device)



pipe.scheduler = EulerDiscreteScheduler.from_config(
pipe.scheduler.config,
timestep_spacing="trailing",
Expand All @@ -90,9 +88,7 @@ def text_to_video(
# else:
# out = export_to_video([output], f"{output_path}_{i}.mp4")
output = pipe(
prompt = task,
guidance_scale = guidance_scale,
num_inference_steps = inference_steps
prompt=task, guidance_scale=guidance_scale, num_inference_steps=inference_steps
)
output = export_to_gif(output.frames[0], output_path)
return output
Expand All @@ -104,7 +100,7 @@ async def create_chat_completion(
):
try:
logger.info(f"Request: {request}")

gen_params = dict(
model_name=request.model_name,
task=request.task,
Expand All @@ -117,7 +113,7 @@ async def create_chat_completion(
)

logger.info(f"Running text_to_video model with params: {gen_params}")

# try:
response = text_to_video(**gen_params)
logger.info(f"Response: {response}")
Expand Down

0 comments on commit 99b15b2

Please sign in to comment.