sv4d: fix readme;
rename video example folder;
add encoding_t as input parameter.
ymxie97 committed Aug 2, 2024
1 parent da40eba commit e90e953
Showing 22 changed files with 43 additions and 29 deletions.
10 changes: 5 additions & 5 deletions README.md
@@ -9,23 +9,23 @@
- We are releasing **[Stable Video 4D (SV4D)](https://huggingface.co/stabilityai/sv4d)**, a video-to-4D diffusion model for novel-view video synthesis. For research purposes:
- **SV4D** was trained to generate 40 frames (5 video frames x 8 camera views) at 576x576 resolution, given 5 context frames (the input video), and 8 reference views (synthesised from the first frame of the input video, using a multi-view diffusion model like SV3D) of the same size, ideally white-background images with one object.
- To generate longer novel-view videos (21 frames), we propose a novel sampling method using SV4D, by first sampling 5 anchor frames and then densely sampling the remaining frames while maintaining temporal consistency.
- You can run the community-build gradio demo locally by running `python -m scripts.demo.gradio_app_sv4d`.
- To run the community-built Gradio demo locally, run `python -m scripts.demo.gradio_app_sv4d`.
- Please check our [project page](https://sv4d.github.io), [tech report](https://sv4d.github.io/static/sv4d_technical_report.pdf) and [video summary](https://www.youtube.com/watch?v=RBP8vdAWTgk) for more details.

**QUICKSTART** : `python scripts/sampling/simple_video_sample_4d.py --input_path assets/sv4d_example_video/test_video1.mp4 --output_folder outputs/sv4d` (after downloading [sv4d.safetensors](https://huggingface.co/stabilityai/sv4d) and [sv3d_u.safetensors](https://huggingface.co/stabilityai/sv3d) from HuggingFace into `checkpoints/`)
**QUICKSTART** : `python scripts/sampling/simple_video_sample_4d.py --input_path assets/sv4d_videos/test_video1.mp4 --output_folder outputs/sv4d` (after downloading [sv4d.safetensors](https://huggingface.co/stabilityai/sv4d) and [sv3d_u.safetensors](https://huggingface.co/stabilityai/sv3d) from HuggingFace into `checkpoints/`)

To run **SV4D** on a single input video of 21 frames:
- Download SV3D models (`sv3d_u.safetensors` and `sv3d_p.safetensors`) from [here](https://huggingface.co/stabilityai/sv3d) and SV4D model (`sv4d.safetensors`) from [here](https://huggingface.co/stabilityai/sv4d) to `checkpoints/`
- Run `python scripts/sampling/simple_video_sample_4d.py --input_path <path/to/video>`
- `input_path` : The input video `<path/to/video>` can be
- a single video file in `gif` or `mp4` format, such as `assets/sv4d_example_video/test_video1.mp4`, or
- a single video file in `gif` or `mp4` format, such as `assets/sv4d_videos/test_video1.mp4`, or
- a folder containing images of video frames in `.jpg`, `.jpeg`, or `.png` format, or
- a file name pattern matching images of video frames.
- `num_steps` : defaults to 20; increase to 50 for better quality at the cost of longer sampling time.
- `sv3d_version` : To specify the SV3D model to generate reference multi-views, set `--sv3d_version=sv3d_u` for SV3D_u or `--sv3d_version=sv3d_p` for SV3D_p.
- `elevations_deg` : To generate novel-view videos at a specified elevation (default elevation is 10) using SV3D_p (default is SV3D_u), run `python scripts/sampling/simple_video_sample_4d.py --input_path test_video1.mp4 --sv3d_version sv3d_p --elevations_deg 30.0`
- **Background removal** : For input videos with plain background, (optionally) use [rembg](https://github.com/danielgatis/rembg) to remove background and crop video frames by setting `--remove_bg=True`. To obtain higher quality outputs on real-world input videos with noisy background, try segmenting the foreground object using [Cliipdrop](https://clipdrop.co/) before running SV4D.
- **Low VRAM environment** : To run on GPUs with low VRAM, try setting `--decoding_t=1` (of frames decoded at a time) or lower video resolution like `--img_size=512`.
- **Background removal** : For input videos with plain background, (optionally) use [rembg](https://github.com/danielgatis/rembg) to remove background and crop video frames by setting `--remove_bg=True`. To obtain higher quality outputs on real-world input videos with noisy background, try segmenting the foreground object using [Clipdrop](https://clipdrop.co/) or [SAM2](https://github.com/facebookresearch/segment-anything-2) before running SV4D.
- **Low VRAM environment** : To run on GPUs with low VRAM, try setting `--encoding_t=1` (number of frames encoded at a time) and `--decoding_t=1` (number of frames decoded at a time), or use a lower video resolution such as `--img_size=512` (a sketch follows this diff).

![tile](assets/sv4d.gif)

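For illustration, the low-VRAM options in the README above map directly onto the script's `sample()` entry point; a minimal sketch, assuming `sample()` accepts these keyword arguments as the CLI flags suggest (the values are illustrative, not recommendations):

```python
# Hypothetical direct call, equivalent to:
#   python scripts/sampling/simple_video_sample_4d.py \
#       --input_path assets/sv4d_videos/test_video1.mp4 --output_folder outputs/sv4d \
#       --encoding_t 1 --decoding_t 1 --img_size 512
from scripts.sampling.simple_video_sample_4d import sample

sample(
    input_path="assets/sv4d_videos/test_video1.mp4",
    output_folder="outputs/sv4d",
    encoding_t=1,  # encode one frame at a time to cut peak VRAM
    decoding_t=1,  # decode one frame at a time
    img_size=512,  # a lower resolution also reduces memory
)
```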
Binary file removed assets/sv4d_example_video/human_slow_black_bg.mp4
16 files renamed without changes (assets/sv4d_example_video/* → assets/sv4d_videos/*).
47 changes: 30 additions & 17 deletions scripts/demo/gradio_app_sv4d.py
@@ -14,6 +14,7 @@
from typing import List, Optional, Union
import torchvision

from sgm.modules.encoders.modules import VideoPredictionEmbedderWithEncoder
from scripts.demo.sv4d_helpers import (
    decode_latents,
    load_model,
@@ -138,6 +139,7 @@
def sample_anchor(
    input_path: str = "assets/test_image.png",  # Can either be image file or folder with image files
    seed: Optional[int] = None,
    encoding_t: int = 8,  # Number of frames encoded at a time! This eats most VRAM. Reduce if necessary.
    decoding_t: int = 4,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
    num_steps: int = 20,
    sv3d_version: str = "sv3d_u",  # sv3d_u or sv3d_p
@@ -205,6 +207,10 @@ def sample_anchor(
    sv3d_file = os.path.join(output_folder, "t000.mp4")
    save_video(sv3d_file, images_t0.unsqueeze(1))

    for emb in model.conditioner.embedders:
        if isinstance(emb, VideoPredictionEmbedderWithEncoder):
            emb.en_and_decode_n_samples_a_time = encoding_t
    model.en_and_decode_n_samples_a_time = decoding_t
    # Initialize image matrix
    img_matrix = [[None] * n_views for _ in range(n_frames)]
    for i, v in enumerate(subsampled_views):
@@ -413,6 +419,13 @@ def sample_all(
        maximum=100,
        step=1,
    )
    encoding_t = gr.Slider(
        label="Encode n frames at a time",
        info="Number of frames encoded at a time! This eats most VRAM. Reduce if necessary.",
        value=8,
        minimum=1,
        maximum=40,
    )
    decoding_t = gr.Slider(
        label="Decode n frames at a time",
        info="Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.",
@@ -440,7 +453,7 @@ def sample_all(

    generate_btn.click(
        fn=sample_anchor,
        inputs=[input_video, seed, decoding_t, denoising_steps],
        inputs=[input_video, seed, encoding_t, decoding_t, denoising_steps],
        outputs=[sv3d_video, anchor_video, anchor_frames],
        api_name="SV4D output (5 frames)",
    )
@@ -455,22 +468,22 @@ def sample_all(
    examples = gr.Examples(
        fn=preprocess_video,
        examples=[
            "./assets/sv4d_example_video/test_video1.mp4",
            "./assets/sv4d_example_video/test_video2.mp4",
            "./assets/sv4d_example_video/green_robot.mp4",
            "./assets/sv4d_example_video/dolphin.mp4",
            "./assets/sv4d_example_video/lucia_v000.mp4",
            "./assets/sv4d_example_video/snowboard_v000.mp4",
            "./assets/sv4d_example_video/stroller_v000.mp4",
            "./assets/sv4d_example_video/human5.mp4",
            "./assets/sv4d_example_video/bunnyman.mp4",
            "./assets/sv4d_example_video/hiphop_parrot.mp4",
            "./assets/sv4d_example_video/guppie_v0.mp4",
            "./assets/sv4d_example_video/wave_hello.mp4",
            "./assets/sv4d_example_video/pistol_v0.mp4",
            "./assets/sv4d_example_video/human7.mp4",
            "./assets/sv4d_example_video/monkey.mp4",
            "./assets/sv4d_example_video/train_v0.mp4",
            "./assets/sv4d_videos/test_video1.mp4",
            "./assets/sv4d_videos/test_video2.mp4",
            "./assets/sv4d_videos/green_robot.mp4",
            "./assets/sv4d_videos/dolphin.mp4",
            "./assets/sv4d_videos/lucia_v000.mp4",
            "./assets/sv4d_videos/snowboard_v000.mp4",
            "./assets/sv4d_videos/stroller_v000.mp4",
            "./assets/sv4d_videos/human5.mp4",
            "./assets/sv4d_videos/bunnyman.mp4",
            "./assets/sv4d_videos/hiphop_parrot.mp4",
            "./assets/sv4d_videos/guppie_v0.mp4",
            "./assets/sv4d_videos/wave_hello.mp4",
            "./assets/sv4d_videos/pistol_v0.mp4",
            "./assets/sv4d_videos/human7.mp4",
            "./assets/sv4d_videos/monkey.mp4",
            "./assets/sv4d_videos/train_v0.mp4",
        ],
        inputs=[input_video],
        run_on_click=True,
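The `encoding_t` parameter added above bounds how many conditioning frames the video encoder processes per forward pass, so peak activation memory scales with the chunk size rather than the full conditioning sequence. A minimal sketch of that chunking pattern (simplified; the real logic lives inside `VideoPredictionEmbedderWithEncoder`):

```python
import torch

def encode_in_chunks(encode, frames: torch.Tensor, n_at_a_time: int) -> torch.Tensor:
    # Encode a (T, C, H, W) batch of frames at most `n_at_a_time` at a time,
    # trading a longer loop for a smaller activation footprint.
    chunks = [
        encode(frames[i : i + n_at_a_time])
        for i in range(0, frames.shape[0], n_at_a_time)
    ]
    return torch.cat(chunks, dim=0)
```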
3 changes: 1 addition & 2 deletions scripts/demo/sv4d_helpers.py
@@ -264,7 +264,7 @@ def preprocess_video(input_path, remove_bg=False, n_frames=21, W=576, H=576, out

        images_v0.append(image)

    base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // 10
    base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // 12
    processed_file = os.path.join(output_folder, f"{base_count:06d}_process_input.mp4")
    imageio.mimwrite(processed_file, images_v0, fps=10)
    return processed_file
@@ -892,7 +892,6 @@ def denoiser(input, sigma, c):
    unload_module_gpu(model.model)
    unload_module_gpu(model.denoiser)
    load_module_gpu(model.first_stage_model)
    model.en_and_decode_n_samples_a_time = decoding_t
    if isinstance(model.first_stage_model.decoder, VideoDecoder):
        samples_x = model.decode_first_stage(
            samples_z, timesteps=default(decoding_t, T)
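On the decode side, the call above passes `timesteps=default(decoding_t, T)`, falling back to the full frame count `T` when no chunk size is given; `default` is the usual sgm-style helper, roughly:

```python
def default(val, d):
    # Return `val` unless it is None, otherwise fall back to `d`.
    return val if val is not None else d
```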
4 changes: 0 additions & 4 deletions scripts/sampling/configs/sv4d.yaml
@@ -1,7 +1,6 @@
N_TIME: 5
N_VIEW: 8
N_FRAMES: 40
ENCODE_N_A_TIME: 8

model:
  target: sgm.models.diffusion.DiffusionEngine
@@ -68,7 +67,6 @@ model:
      is_ae: True
      n_cond_frames: ${N_FRAMES}
      n_copies: 1
      en_and_decode_n_samples_a_time: ${ENCODE_N_A_TIME}
      encoder_config:
        target: sgm.models.autoencoder.AutoencoderKLModeOnly
        params:
@@ -133,7 +131,6 @@ model:
      is_ae: True
      n_cond_frames: ${N_VIEW}
      n_copies: 1
      en_and_decode_n_samples_a_time: ${ENCODE_N_A_TIME}
      sigma_sampler_config:
        target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler

@@ -144,7 +141,6 @@ model:
      is_ae: True
      n_cond_frames: ${N_TIME}
      n_copies: 1
      en_and_decode_n_samples_a_time: ${ENCODE_N_A_TIME}
      encoder_config:
        target: sgm.models.autoencoder.AutoencoderKLModeOnly
        params:
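With `ENCODE_N_A_TIME` removed from the config, the encode chunk size is applied at runtime instead. The pattern added in both sampling scripts amounts to the following helper (a sketch; the scripts inline these loops rather than defining a function):

```python
from sgm.modules.encoders.modules import VideoPredictionEmbedderWithEncoder

def set_chunk_sizes(model, encoding_t: int, decoding_t: int) -> None:
    # Bound how many frames each video-conditioning encoder handles per
    # forward pass, then set the matching chunk size for first-stage decoding.
    for emb in model.conditioner.embedders:
        if isinstance(emb, VideoPredictionEmbedderWithEncoder):
            emb.en_and_decode_n_samples_a_time = encoding_t
    model.en_and_decode_n_samples_a_time = decoding_t
```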
8 changes: 7 additions & 1 deletion scripts/sampling/simple_video_sample_4d.py
@@ -10,6 +10,7 @@
import torch
from fire import Fire

from sgm.modules.encoders.modules import VideoPredictionEmbedderWithEncoder
from scripts.demo.sv4d_helpers import (
    decode_latents,
    load_model,
@@ -35,6 +36,7 @@ def sample(
    motion_bucket_id: int = 127,
    cond_aug: float = 1e-5,
    seed: int = 23,
    encoding_t: int = 8,  # Number of frames encoded at a time! This eats most VRAM. Reduce if necessary.
    decoding_t: int = 4,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
    device: str = "cuda",
    elevations_deg: Optional[Union[float, List[float]]] = 10.0,
@@ -45,7 +47,7 @@
):
"""
Simple script to generate multiple novel-view videos conditioned on a video `input_path` or multiple frames, one for each
image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`.
image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t` and `encoding_t`.
"""
# Set model config
T = 5 # number of frames per sample
@@ -162,6 +164,10 @@ def sample(
        verbose,
    )
    model = initial_model_load(model)
    for emb in model.conditioner.embedders:
        if isinstance(emb, VideoPredictionEmbedderWithEncoder):
            emb.en_and_decode_n_samples_a_time = encoding_t
    model.en_and_decode_n_samples_a_time = decoding_t

    # Interleaved sampling for anchor frames
    t0, v0 = 0, 0
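Because the script imports `Fire` (visible at the top of the hunk), every keyword argument of `sample()` doubles as a CLI flag, so the new `encoding_t` is exposed as `--encoding_t` with no extra argument-parsing code; presumably the file ends with the standard dispatch:

```python
if __name__ == "__main__":
    Fire(sample)
```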
