Commit

Stable Video Diffusion
Tim Dockhorn committed Nov 21, 2023
1 parent 477d8b9 commit 059d8e9
Showing 59 changed files with 5,418 additions and 1,646 deletions.
119 changes: 79 additions & 40 deletions README.md

## News

**November 21, 2023**

- We are releasing Stable Video Diffusion, an image-to-video model, for research purposes:
- [SVD](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid): This model was trained to generate 14
frames at resolution 576x1024 given a context frame of the same size.
We use the standard image encoder from SD 2.1, but replace the decoder with a temporally-aware `deflickering decoder`.
- [SVD-XT](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt): Same architecture as `SVD` but finetuned
for 25 frame generation.
- We provide a streamlit demo `scripts/demo/video_sampling.py` and a standalone python script `scripts/sampling/simple_video_sample.py` for inference of both models.
- Alongside the model, we release a [technical report](https://stability.ai/research/stable-video-diffusion-scaling-latent-video-diffusion-models-to-large-datasets).

![tile](assets/tile.gif)

**July 26, 2023**
- We are releasing two new open models with a
permissive [`CreativeML Open RAIL++-M` license](model_licenses/LICENSE-SDXL1.0) (see [Inference](#inference) for file
hashes):
- [SDXL-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0): An improved version
over `SDXL-base-0.9`.
- [SDXL-refiner-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0): An improved version
over `SDXL-refiner-0.9`.

![sample2](assets/001_with_eval.png)

**July 4, 2023**

- A technical report on SDXL is now available [here](https://arxiv.org/abs/2307.01952).

**June 22, 2023**


- We are releasing two new diffusion models for research purposes:
- `SDXL-base-0.9`: The base model was trained on a variety of aspect ratios on images with resolution 1024^2. The
base model uses [OpenCLIP-ViT/G](https://github.com/mlfoundations/open_clip)
and [CLIP-ViT/L](https://github.com/openai/CLIP/tree/main) for text encoding whereas the refiner model only uses
the OpenCLIP model.
- `SDXL-refiner-0.9`: The refiner has been trained to denoise small noise levels of high quality data and as such is
not expected to work as a text-to-image model; instead, it should only be used as an image-to-image model.

If you would like to access these models for your research, please apply using one of the following links:
[SDXL-0.9-Base model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9),
and [SDXL-0.9-Refiner](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).
This means that you can apply via either of the two links and, if granted, you will gain access to both.
Please log in to your Hugging Face Account with your organization email to request access.
**We plan to do a full release soon (July).**

### General Philosophy

Modularity is king. This repo implements a config-driven approach where we build and combine submodules by
calling `instantiate_from_config()` on objects defined in yaml configs. See `configs/` for many examples.
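
To make the pattern concrete, here is a minimal sketch of what such a config-driven factory does. It illustrates the `{target, params}` convention used throughout `configs/`, not the repository's exact implementation.

```python
import importlib
from typing import Any, Dict


def get_obj_from_str(path: str) -> Any:
    """Resolve a dotted path like 'sgm.modules.encoders.modules.GeneralConditioner' to a class."""
    module_name, cls_name = path.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), cls_name)


def instantiate_from_config(config: Dict[str, Any]) -> Any:
    """Build an object from a {'target': ..., 'params': {...}} dict, as used throughout configs/."""
    return get_obj_from_str(config["target"])(**config.get("params", {}))


# Usage: every sub-module in the YAML configs follows the same {target, params} shape.
layer = instantiate_from_config(
    {
        "target": "torch.nn.Linear",  # stand-in target; real configs point at sgm classes
        "params": {"in_features": 4, "out_features": 8},
    }
)
```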

### Changelog from the old `ldm` codebase

For training, we use [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/), but it should be easy to use other
training wrappers around the base modules. The core diffusion model class (formerly `LatentDiffusion`,
now `DiffusionEngine`) has been cleaned up:

- No more extensive subclassing! We now handle all types of conditioning inputs (vectors, sequences and spatial
conditionings, and all combinations thereof) in a single class: `GeneralConditioner`,
see `sgm/modules/encoders/modules.py`.
- We separate guiders (such as classifier-free guidance, see `sgm/modules/diffusionmodules/guiders.py`) from the
samplers (`sgm/modules/diffusionmodules/sampling.py`), and the samplers are independent of the model.
- We adopt the ["denoiser framework"](https://arxiv.org/abs/2206.00364) for both training and inference (most notable change is probably now the option to train continuous time models):
* Discrete times models (denoisers) are simply a special case of continuous time models (denoisers); see `sgm/modules/diffusionmodules/denoiser.py`.
* The following features are now independent: weighting of the diffusion loss function (`sgm/modules/diffusionmodules/denoiser_weighting.py`), preconditioning of the network (`sgm/modules/diffusionmodules/denoiser_scaling.py`), and sampling of noise levels during training (`sgm/modules/diffusionmodules/sigma_sampling.py`).
- We adopt the ["denoiser framework"](https://arxiv.org/abs/2206.00364) for both training and inference (most notable
change is probably now the option to train continuous time models):
* Discrete times models (denoisers) are simply a special case of continuous time models (denoisers);
see `sgm/modules/diffusionmodules/denoiser.py`.
* The following features are now independent: weighting of the diffusion loss
function (`sgm/modules/diffusionmodules/denoiser_weighting.py`), preconditioning of the
network (`sgm/modules/diffusionmodules/denoiser_scaling.py`), and sampling of noise levels during
training (`sgm/modules/diffusionmodules/sigma_sampling.py`).
- Autoencoding models have also been cleaned up.
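
To make the separation above concrete, the following is a schematic continuous-time training step in the spirit of the referenced denoiser framework. It is a sketch under assumed choices (EDM-style preconditioning, a log-normal sigma sampler, a generic `network(x, sigma, cond)` call signature), not the repo's actual `DiffusionEngine` code.

```python
import torch


def edm_scaling(sigma: torch.Tensor, sigma_data: float = 0.5):
    # EDM-style preconditioning: skip, output and input scalings as functions of sigma.
    c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
    c_out = sigma * sigma_data / (sigma**2 + sigma_data**2).sqrt()
    c_in = 1.0 / (sigma**2 + sigma_data**2).sqrt()
    return c_skip, c_out, c_in


def training_step(network, x, cond, sigma_data: float = 0.5):
    # 1) Sample a noise level per example (the role of sigma_sampling.py); log-normal proposal here.
    sigma = torch.exp(torch.randn(x.shape[0], device=x.device) * 1.2 - 1.2)
    sigma_b = sigma.view(-1, 1, 1, 1)

    # 2) Noise the data: continuous-time denoisers see x + sigma * eps.
    noised = x + sigma_b * torch.randn_like(x)

    # 3) Precondition the network (the role of denoiser_scaling.py) to obtain the denoised prediction.
    c_skip, c_out, c_in = edm_scaling(sigma_b, sigma_data)
    denoised = c_skip * noised + c_out * network(c_in * noised, sigma, cond)

    # 4) Weight the per-sigma reconstruction loss (the role of denoiser_weighting.py).
    weight = (sigma_b**2 + sigma_data**2) / (sigma_b * sigma_data) ** 2
    return (weight * (denoised - x) ** 2).mean()
```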

## Installation:

<a name="installation"></a>

#### 1. Clone the repo

#### 2. Setting up the virtualenv

This is assuming you have navigated to the `generative-models` root after cloning it.

**NOTE:** This is tested under `python3.10`. For other python versions, you might encounter version conflicts.

**PyTorch 2.0**


```shell
# install required packages from pypi
python3 -m venv .pt2
source .pt2/bin/activate
pip3 install -r requirements/pt2.txt
```


#### 3. Install `sgm`

```shell
pip3 install .
```

## Inference

We provide a [streamlit](https://streamlit.io/) demo for text-to-image and image-to-image sampling
in `scripts/demo/sampling.py`.
We provide file hashes for the complete file as well as for only the saved tensors in the file (
see [Model Spec](https://github.com/Stability-AI/ModelSpec) for a script to evaluate that).
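
The linked Model Spec repository contains the canonical hashing script; as a quick sanity check, a plain SHA-256 over a complete checkpoint file can be computed as below (the filename is illustrative, and the tensor-only hash requires the Model Spec tooling).

```python
import hashlib
from pathlib import Path


def sha256_of_file(path: str, chunk_size: int = 1 << 20) -> str:
    """Stream the file in chunks so multi-GB checkpoints do not need to fit in memory."""
    digest = hashlib.sha256()
    with Path(path).open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


print(sha256_of_file("checkpoints/my_checkpoint.safetensors"))  # compare against the published hash
```
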
The following models are currently supported:

- [SDXL-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
**Weights for SDXL**:

**SDXL-1.0:**
The weights of SDXL-1.0 are available (subject to
a [`CreativeML Open RAIL++-M` license](model_licenses/LICENSE-SDXL1.0)) here:

- base model: https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/
- refiner model: https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/


**SDXL-0.9:**
The weights of SDXL-0.9 are available and subject to a [research license](model_licenses/LICENSE-SDXL0.9).
If you would like to access these models for your research, please apply using one of the following links:
[SDXL-base-0.9 model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9),
and [SDXL-refiner-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).
This means that you can apply via either of the two links and, if granted, you will gain access to both.
Please log in to your Hugging Face Account with your organization email to request access.


After obtaining the weights, place them into `checkpoints/`.
Next, start the demo by running the streamlit script `scripts/demo/sampling.py`.

## Invisible Watermark detection

To run the script you need to either have a working installation as above or
try an _experimental_ import using only a minimal set of packages:

```bash
python -m venv .detect
source .detect/bin/activate
pip install --no-deps invisible-watermark
```
To run the script you need to have a working installation as above. The script
is then usable in the following ways (don't forget to activate your
virtual environment beforehand, e.g. `source .pt2/bin/activate`):

```bash
# test a single file
python scripts/demo/detect.py <your filename here>
```

## Training:

Example training configs are provided in `configs/example_training`. To launch a training, run

```bash
python main.py --base configs/example_training/toy/mnist_cond.yaml
```

**NOTE 1:** Using the non-toy-dataset
configs `configs/example_training/imagenet-f8_cond.yaml`, `configs/example_training/txt2img-clipl.yaml`
and `configs/example_training/txt2img-clipl-legacy-ucg-training.yaml` for training will require edits depending on the
dataset used (which is expected to be stored as tar files in
the [webdataset format](https://github.com/webdataset/webdataset)). To find the parts that have to be adapted, search
for comments containing `USER:` in the respective config.

**NOTE 2:** This repository supports both `pytorch1.13` and `pytorch2` for training generative models. However, for
autoencoder training, as e.g. in `configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml`,
only `pytorch1.13` is supported.

**NOTE 3:** Training latent generative models (as e.g. in `configs/example_training/imagenet-f8_cond.yaml`) requires
retrieving the checkpoint from [Hugging Face](https://huggingface.co/stabilityai/sdxl-vae/tree/main) and replacing
the `CKPT_PATH` placeholder in [this line](configs/example_training/imagenet-f8_cond.yaml#81). The same is to be done
for the provided text-to-image configs.

### Building New Diffusion Models

#### Conditioner

The `GeneralConditioner` is configured through the `conditioner_config`. Its only attribute is `emb_models`, a list of
different embedders (all inherited from `AbstractEmbModel`) that are used to condition the generative model.
All embedders should define whether or not they are trainable (`is_trainable`, default `False`), the classifier-free
guidance dropout rate to use (`ucg_rate`, default `0`), and an input key (`input_key`), for example, `txt` for
text-conditioning or `cls` for class-conditioning.
When computing conditionings, the embedder will get `batch[input_key]` as input.
We currently support two to four dimensional conditionings and conditionings of different embedders are concatenated
appropriately.
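
As a schematic illustration of this interface, a class-label embedder for the `cls` key could look roughly as follows; this is a self-contained sketch, not the repo's `AbstractEmbModel` base class or any of its built-in embedders.

```python
import torch
import torch.nn as nn


class ClassEmbedder(nn.Module):
    """Schematic embedder: maps batch['cls'] (integer labels) to vector conditionings."""

    def __init__(self, num_classes: int, embed_dim: int):
        super().__init__()
        self.is_trainable = True   # whether the optimizer should update this embedder
        self.ucg_rate = 0.1        # classifier-free-guidance dropout rate
        self.input_key = "cls"     # the conditioner feeds batch[self.input_key] to forward()
        self.embedding = nn.Embedding(num_classes, embed_dim)

    def forward(self, cls_labels: torch.Tensor) -> torch.Tensor:
        emb = self.embedding(cls_labels)
        if self.training and self.ucg_rate > 0:
            # Randomly drop conditionings so the model also learns the unconditional case.
            keep = (torch.rand(emb.shape[0], 1, device=emb.device) > self.ucg_rate).float()
            emb = keep * emb
        return emb  # a 2-D (batch, embed_dim) "vector" conditioning
```
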
#### Network

The neural network is set through the `network_config`. This used to be called `unet_config`, which is not general
enough as we plan to experiment with transformer-based diffusion backbones.

#### Loss

The loss is configured through `loss_config`. For standard diffusion model training, you will have to
set `sigma_sampler_config`.
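
As an illustration, a `loss_config` typically nests a `sigma_sampler_config`. The sketch below writes the config as the Python dict that `instantiate_from_config()` would receive; the class names are placeholders patterned on the repository's module layout, so check the example configs for the real ones.

```python
# A config fragment expressed as the Python dict that instantiate_from_config() would receive.
# Class names below are placeholders; take the actual targets and parameters from configs/example_training.
loss_config = {
    "target": "sgm.modules.diffusionmodules.loss.StandardDiffusionLoss",
    "params": {
        "sigma_sampler_config": {
            "target": "sgm.modules.diffusionmodules.sigma_sampling.EDMSampling",
            "params": {"p_mean": -1.2, "p_std": 1.2},
        },
    },
}
```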

#### Sampler config

This configures the sampler (see `sgm/modules/diffusionmodules/sampling.py`) and the guider wrapped around it, e.g.,
for classifier-free guidance.

### Dataset Handling


For large scale training we recommend using the data pipelines from
our [data pipelines](https://github.com/Stability-AI/datapipelines) project. The project is contained in the
requirements and automatically included when following the steps from the [Installation section](#installation).
Small map-style datasets should be defined here in the repository (e.g., MNIST, CIFAR-10, ...), and return a dict of
data keys/values,
e.g.,
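
A minimal sketch of such a sample dict is shown below (the keys are illustrative; they need to match the `input_key`s used in the configs, e.g. `jpg` or `cls`):

```python
import torch

# One item from a map-style dataset: a dict mapping data keys to values.
sample = {
    "jpg": torch.zeros(3, 256, 256),  # image tensor; key matches `input_key: jpg` in the configs
    "cls": 7,                         # class label for class-conditional training
}
```
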
Binary file added assets/test_image.png
Binary file added assets/tile.gif
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4]
num_res_blocks: 4
attn_resolutions: []
dropout: 0.0

decoder_config:
target: sgm.modules.diffusionmodules.model.Decoder
params: ${model.params.encoder_config.params}

data:
target: sgm.data.dataset.StableDataModuleFromConfig
params:
train:
datapipeline:
urls:
- "DATA-PATH"
- DATA-PATH
pipeline_config:
shardshuffle: 10000
sample_shuffle: 10000

decoders:
- "pil"
- pil

postprocessors:
- target: sdata.mappers.TorchVisionImageTransforms
params:
key: jpg
transforms:
- target: torchvision.transforms.Resize
params:
Expand Down
105 changes: 105 additions & 0 deletions configs/example_training/autoencoder/kl-f4/imagenet-kl_f8_8chn.yaml
model:
base_learning_rate: 4.5e-6
target: sgm.models.autoencoder.AutoencodingEngine
params:
input_key: jpg
monitor: val/loss/rec
disc_start_iter: 0

encoder_config:
target: sgm.modules.diffusionmodules.model.Encoder
params:
attn_type: vanilla-xformers
double_z: true
z_channels: 8
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0

decoder_config:
target: sgm.modules.diffusionmodules.model.Decoder
params: ${model.params.encoder_config.params}

regularizer_config:
target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer

loss_config:
target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator
params:
perceptual_weight: 0.25
disc_start: 20001
disc_weight: 0.5
learn_logvar: True

regularization_weights:
kl_loss: 1.0

data:
target: sgm.data.dataset.StableDataModuleFromConfig
params:
train:
datapipeline:
urls:
- DATA-PATH
pipeline_config:
shardshuffle: 10000
sample_shuffle: 10000

decoders:
- pil

postprocessors:
- target: sdata.mappers.TorchVisionImageTransforms
params:
key: jpg
transforms:
- target: torchvision.transforms.Resize
params:
size: 256
interpolation: 3
- target: torchvision.transforms.ToTensor
- target: sdata.mappers.Rescaler
- target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
params:
h_key: height
w_key: width

loader:
batch_size: 8
num_workers: 4


lightning:
strategy:
target: pytorch_lightning.strategies.DDPStrategy
params:
find_unused_parameters: True

modelcheckpoint:
params:
every_n_train_steps: 5000

callbacks:
metrics_over_trainsteps_checkpoint:
params:
every_n_train_steps: 50000

image_logger:
target: main.ImageLogger
params:
enable_autocast: False
batch_frequency: 1000
max_images: 8
increase_log_steps: True

trainer:
devices: 0,
limit_val_batches: 50
benchmark: True
accumulate_grad_batches: 1
val_check_interval: 10000

1 comment on commit 059d8e9

@moeadham

🙌🙌🙌
