From 99b15b24224714b7144cda71576df8ee7c81a6e2 Mon Sep 17 00:00:00 2001
From: Kye
Date: Tue, 16 Apr 2024 16:11:44 -0400
Subject: [PATCH] [FEAT][Test]

---
 servers/text_to_video/sample_request.py |  4 +-
 servers/text_to_video/test.py           | 73 +++++++++++++++++++++++++
 text_to_video.py                        | 16 ++----
 3 files changed, 81 insertions(+), 12 deletions(-)
 create mode 100644 servers/text_to_video/test.py

diff --git a/servers/text_to_video/sample_request.py b/servers/text_to_video/sample_request.py
index 74a723a..a1097cc 100644
--- a/servers/text_to_video/sample_request.py
+++ b/servers/text_to_video/sample_request.py
@@ -11,9 +11,9 @@
     "style": "example_style",
     "n": 1,
     "output_type": "mp4",
-    "output_path": "example_output_path"
+    "output_path": "example_output_path",
 }
 
 response = requests.post(url, headers=headers, data=json.dumps(data))
 
-print(response.json())
\ No newline at end of file
+print(response.json())
diff --git a/servers/text_to_video/test.py b/servers/text_to_video/test.py
new file mode 100644
index 0000000..6ed5aab
--- /dev/null
+++ b/servers/text_to_video/test.py
@@ -0,0 +1,73 @@
+import torch
+from diffusers import (
+    AnimateDiffPipeline,
+    EulerDiscreteScheduler,
+    MotionAdapter,
+)
+from diffusers.utils import export_to_gif
+from dotenv import load_dotenv
+from huggingface_hub import hf_hub_download
+from safetensors.torch import load_file
+
+# Load environment variables from .env file
+load_dotenv()
+
+
+def text_to_video(
+    task: str,
+    model_name: str = "ByteDance/AnimateDiff-Lightning",
+    guidance_scale: float = 1.0,
+    inference_steps: int = 4,
+    output_type: str = ".gif",
+    output_path: str = "animation.gif",
+    n: int = 1,
+    length: int = 60,
+    *args,
+    **kwargs,
+):
+    """
+    Converts a given text task into an animated video.
+
+    Args:
+        task (str): The text task to be converted into a video.
+
+    Returns:
+        str: The path to the exported GIF file.
+    """
+
+    device = "cuda"
+    dtype = torch.float16
+
+    repo = model_name
+    ckpt = f"animatediff_lightning_{inference_steps}step_diffusers.safetensors"
+    base = "emilianJR/epiCRealism"  # Choose your favorite base model.
+    adapter = MotionAdapter().to(device, dtype)
+    adapter.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
+
+    pipe = AnimateDiffPipeline.from_pretrained(
+        base, motion_adapter=adapter, torch_dtype=dtype
+    ).to(device)
+
+    pipe.scheduler = EulerDiscreteScheduler.from_config(
+        pipe.scheduler.config,
+        timestep_spacing="trailing",
+        beta_schedule="linear",
+    )
+
+    # outputs = []
+    # for i in range(n):
+    #     output = pipe(
+    #         prompt=task,
+    #         guidance_scale=guidance_scale,
+    #         num_inference_steps=inference_steps,
+    #     )
+    #     outputs.append(output)
+    #     if output_type == ".gif":
+    #         out = export_to_gif([output], f"{output_path}_{i}.gif")
+    #     else:
+    #         out = export_to_video([output], f"{output_path}_{i}.mp4")
+    output = pipe(
+        prompt=task, guidance_scale=guidance_scale, num_inference_steps=inference_steps
+    )
+    output = export_to_gif(output.frames[0], output_path)
+    return output
diff --git a/text_to_video.py b/text_to_video.py
index d5ebb5e..f19e8d9 100644
--- a/text_to_video.py
+++ b/text_to_video.py
@@ -7,11 +7,10 @@
     EulerDiscreteScheduler,
     MotionAdapter,
 )
-from diffusers.utils import export_to_gif, export_to_video
+from diffusers.utils import export_to_gif
 from dotenv import load_dotenv
 from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import FileResponse
 from huggingface_hub import hf_hub_download
 from loguru import logger
 from safetensors.torch import load_file
@@ -65,12 +64,11 @@ def text_to_video(
     base = "emilianJR/epiCRealism"  # Choose to your favorite base model.
     adapter = MotionAdapter().to(device, dtype)
     adapter.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
-    
+
     pipe = AnimateDiffPipeline.from_pretrained(
         base, motion_adapter=adapter, torch_dtype=dtype
     ).to(device)
-    
-    
+
     pipe.scheduler = EulerDiscreteScheduler.from_config(
         pipe.scheduler.config,
         timestep_spacing="trailing",
@@ -90,9 +88,7 @@
     #     else:
    #         out = export_to_video([output], f"{output_path}_{i}.mp4")
     output = pipe(
-        prompt = task,
-        guidance_scale = guidance_scale,
-        num_inference_steps = inference_steps
+        prompt=task, guidance_scale=guidance_scale, num_inference_steps=inference_steps
     )
     output = export_to_gif(output.frames[0], output_path)
     return output
@@ -104,7 +100,7 @@ async def create_chat_completion(
 ):
     try:
         logger.info(f"Request: {request}")
-        
+
         gen_params = dict(
             model_name=request.model_name,
             task=request.task,
@@ -117,7 +113,7 @@
         )
 
         logger.info(f"Running text_to_video model with params: {gen_params}")
-        
+
         # try:
         response = text_to_video(**gen_params)
         logger.info(f"Response: {response}")