-
Notifications
You must be signed in to change notification settings - Fork 347
/
predict.py
155 lines (136 loc) · 5.44 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md
import os
import sys
import argparse
import random
from omegaconf import OmegaConf
from einops import rearrange, repeat
import torch
import torchvision
from pytorch_lightning import seed_everything
from cog import BasePredictor, Input, Path
sys.path.insert(0, "scripts/evaluation")
from funcs import (
batch_ddim_sampling,
load_model_checkpoint,
load_image_batch,
get_filelist,
)
from utils.utils import instantiate_from_config
class Predictor(BasePredictor):
    """Cog predictor exposing text-to-video and image-to-video generation.

    Two models are loaded once in ``setup`` and kept on the instance:
    ``self.model_base`` serves the ``text2video`` task and ``self.model_i2v``
    serves the ``image2video`` task.
    """

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        self.model_base = self._load_model(
            "configs/inference_t2v_1024_v1.0.yaml",
            "checkpoints/base_1024_v1/model.ckpt",
        )
        self.model_i2v = self._load_model(
            "configs/inference_i2v_512_v1.0.yaml",
            "checkpoints/i2v_512_v1/model.ckpt",
        )

    @staticmethod
    def _load_model(config_path: str, ckpt_path: str):
        """Instantiate the model described by ``config_path`` and load ``ckpt_path`` into it.

        The model is built from the ``model`` section of the YAML config,
        moved to CUDA, populated from the checkpoint and switched to eval mode.
        """
        config = OmegaConf.load(config_path)
        model_config = config.pop("model", OmegaConf.create())
        model = instantiate_from_config(model_config)
        model = model.cuda()
        model = load_model_checkpoint(model, ckpt_path)
        model.eval()
        return model

    def predict(
        self,
        task: str = Input(
            description="Choose the task.",
            choices=["text2video", "image2video"],
            default="text2video",
        ),
        prompt: str = Input(
            description="Prompt for video generation.",
            default="A tiger walks in the forest, photorealistic, 4k, high definition.",
        ),
        image: Path = Input(
            description="Input image for image2video task.", default=None
        ),
        ddim_steps: int = Input(description="Number of denoising steps.", default=50),
        unconditional_guidance_scale: float = Input(
            description="Classifier-free guidance scale.", default=12.0
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed", default=None
        ),
        save_fps: int = Input(
            description="Frame per second for the generated video.", default=10
        ),
    ) -> Path:
        """Generate a single video for ``prompt`` and return the path of the MP4 file.

        ``text2video`` samples at 1024x576 with the base model; any other task
        value is handled as ``image2video`` (512x320, i2v model), which also
        conditions on ``image``.

        Raises:
            ValueError: if ``task`` is ``"image2video"`` and no ``image`` was given.
        """
        is_text2video = task == "text2video"
        if is_text2video:
            model, height, width, cond_fps = self.model_base, 576, 1024, 28
        else:
            model, height, width, cond_fps = self.model_i2v, 320, 512, 8

        # Explicit raise instead of `assert`: asserts are removed under
        # `python -O`, which would let a missing image reach the loader.
        if task == "image2video" and image is None:
            raise ValueError("Please provide image for image2video generation.")

        if seed is None:
            # Two random bytes -> seed in [0, 65535].
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")
        seed_everything(seed)

        n_samples = 1
        batch_size = 1
        ddim_eta = 1.0

        ## latent noise shape
        # Spatial dims are the pixel dims divided by 8; frame count and
        # channel count are read from the model.
        noise_shape = [
            batch_size,
            model.channels,
            model.temporal_length,
            height // 8,
            width // 8,
        ]

        ## inference
        # Nothing below needs gradients; disabling autograd avoids keeping
        # activations alive for a backward pass that never happens.
        with torch.no_grad():
            cond = self._build_condition(
                model, is_text2video, prompt, image, height, width, cond_fps, batch_size
            )
            batch_samples = batch_ddim_sampling(
                model,
                cond,
                noise_shape,
                n_samples,
                ddim_steps,
                ddim_eta,
                unconditional_guidance_scale,
            )

        out_path = "/tmp/output.mp4"
        self._write_video(batch_samples[0], out_path, n_samples, save_fps)
        return Path(out_path)

    @staticmethod
    def _build_condition(
        model, is_text2video, prompt, image, height, width, cond_fps, batch_size
    ):
        """Build the conditioning dict (``c_crossattn`` + ``fps``) passed to the sampler."""
        fps = torch.tensor([cond_fps] * batch_size).to(model.device).long()
        text_emb = model.get_learned_conditioning([prompt])
        if is_text2video:
            return {"c_crossattn": [text_emb], "fps": fps}

        # image2video: the image embedding is concatenated to the text
        # embedding along dim 1 before being used as cross-attention context.
        cond_images = load_image_batch([str(image)], (height, width))
        cond_images = cond_images.to(model.device)
        img_emb = model.get_image_embeds(cond_images)
        imtext_cond = torch.cat([text_emb, img_emb], dim=1)
        return {"c_crossattn": [imtext_cond], "fps": fps}

    @staticmethod
    def _write_video(vid_tensor, out_path, n_samples, save_fps):
        """Convert sampled frames with values in [-1, 1] to uint8 and write an H.264 MP4."""
        video = vid_tensor.detach().cpu()
        video = torch.clamp(video.float(), -1.0, 1.0)
        video = video.permute(2, 0, 1, 3, 4)  # t,n,c,h,w
        frame_grids = [
            torchvision.utils.make_grid(framesheet, nrow=int(n_samples))
            for framesheet in video
        ]  # [3, 1*h, n*w]
        grid = torch.stack(frame_grids, dim=0)  # stack in temporal dim [t, 3, n*h, w]
        grid = (grid + 1.0) / 2.0  # [-1, 1] -> [0, 1]
        # Scale to 0..255 and move channels last, the layout given to write_video.
        grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
        torchvision.io.write_video(
            out_path,
            grid,
            fps=save_fps,
            video_codec="h264",
            options={"crf": "10"},
        )