This is the official implementation of "U-DiTs: Downsample Tokens in U-Shaped Diffusion Transformers".
9/30/2024: U-DiT is cited by Playground V3!
9/26/2024: U-DiT is accepted to NeurIPS 2024!🎉🎉🎉 See you in Vancouver!
Outline
🤔 In this work, we rethink the question: "Could the U-Net architecture boost DiTs?"
😮 Self-attention with token downsampling cuts its cost by ~3/4, yet improves the performance of U-Net-style DiTs.
🥳 We develop a series of powerful U-DiTs.
🚀 U-DiT-B could outcompete DiT-XL/2 with only 1/6 of its FLOPs.
Please run pip install -r requirements.txt to install the required packages.
(Optional) Please download the VAE from this link. The VAE can also be downloaded automatically at runtime.
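If you would rather fetch the VAE programmatically, a minimal sketch with diffusers is given below; the `stabilityai/sd-vae-ft-ema` checkpoint name follows the original DiT codebase and is an assumption here, so swap in whichever VAE this repo expects.

```python
# Minimal sketch: download/cache the VAE via diffusers.
# "stabilityai/sd-vae-ft-ema" is an assumed checkpoint name (as used by DiT).
import torch
from diffusers.models import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema")
vae = vae.to("cuda" if torch.cuda.is_available() else "cpu").eval()
```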
Here we provide two ways to train a U-DiT model: 1. train on the original ImageNet dataset; 2. train on preprocessed VAE features (Recommended).
Training Data Preparation
Use the original ImageNet dataset together with the VAE encoder. First, download ImageNet and arrange it as follows:
imagenet/
├──train/
│ ├── n01440764
│ │ ├── n01440764_10026.JPEG
│ │ ├── n01440764_10027.JPEG
│ │ ├── ......
│ ├── ......
├──val/
│ ├── n01440764
│ │ ├── ILSVRC2012_val_00000293.JPEG
│ │ ├── ILSVRC2012_val_00002138.JPEG
│ │ ├── ......
│ ├── ......
Then run one of the following commands:
torchrun --nnodes=1 --nproc_per_node=8 train.py --data-path={path to imagenet/train} --image-size=256 --model={model name} --epochs={iteration//5000} # fp32 Training
accelerate launch --mixed_precision fp16 train_accelerate.py --data-path {path to imagenet/train} --image-size=256 --model={model name} --epochs={iteration//5000} # fp16 Training
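For intuition, below is a minimal sketch of the on-the-fly data path this mode implies (ImageFolder loading, center crop, VAE encoding into latents). The transform and the 0.18215 latent scaling follow the original DiT convention and are assumptions here; `train.py` remains the authoritative pipeline.

```python
# Sketch (not the repo's actual code): load ImageNet with torchvision and encode
# images into VAE latents the way DiT-style training consumes them.
import torch
from torchvision import transforms
from torchvision.datasets import ImageFolder
from diffusers.models import AutoencoderKL

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),  # map pixels to [-1, 1]
])
dataset = ImageFolder("imagenet/train", transform=transform)
loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").eval()
with torch.no_grad():
    images, labels = next(iter(loader))
    # 256x256 RGB images -> 4x32x32 latents; 0.18215 is the standard SD-VAE scaling
    latents = vae.encode(images).latent_dist.sample() * 0.18215
```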
Training Feature Preparation (RECOMMENDED)
Following Fast-DiT, it is recommended to load VAE features directly for faster training. You don't need to download the enormous ImageNet dataset (>100G); instead, a much smaller "VAE feature" dataset (~21G for ImageNet 256x256) is available here on HuggingFace and ModelScope. Please follow these steps:
- Download imagenet_feature.tar
- Unzip the tarball by running tar -xf imagenet_feature.tar, which produces:
imagenet_feature/
├── imagenet256_features/ # VAE features
└── imagenet256_labels/ # labels
- Append the argument --feature-path={path to imagenet_feature} to the training command.
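For reference, here is a minimal sketch of reading these pre-extracted features, following the Fast-DiT convention of paired per-sample `.npy` files; the exact file naming inside `imagenet256_features/` and `imagenet256_labels/` is an assumption, so consult the repo's dataset class for the actual loader.

```python
# Sketch (assumed layout): one .npy feature file and one .npy label file per sample.
import os
import numpy as np
import torch
from torch.utils.data import Dataset

class FeatureDataset(Dataset):
    def __init__(self, feature_dir, label_dir):
        self.feature_dir = feature_dir
        self.label_dir = label_dir
        self.feature_files = sorted(os.listdir(feature_dir))  # assumes matching order
        self.label_files = sorted(os.listdir(label_dir))

    def __len__(self):
        return len(self.feature_files)

    def __getitem__(self, idx):
        feature = np.load(os.path.join(self.feature_dir, self.feature_files[idx]))
        label = np.load(os.path.join(self.label_dir, self.label_files[idx]))
        return torch.from_numpy(feature), torch.from_numpy(label)

dataset = FeatureDataset("imagenet_feature/imagenet256_features",
                         "imagenet_feature/imagenet256_labels")
```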
🔥 We released our models via HuggingFace and ModelScope. Please feel free to download them!
Run the following command for parallel sampling:
torchrun --nnodes=1 --nproc_per_node=8 sample_ddp.py --ckpt={path to checkpoint} --image-size=256 --model={model name} --cfg-scale={cfg scale}
After sampling, an .npz file that contains 50000 images is automatically generated.
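If you want to sanity-check the batch before evaluation, a short sketch is below; it assumes the default `arr_0` key written by `np.savez` and a hypothetical output filename.

```python
# Sketch: inspect the generated sample batch ("samples.npz" is a placeholder path).
import numpy as np

samples = np.load("samples.npz")["arr_0"]
print(samples.shape, samples.dtype)  # expect roughly (50000, 256, 256, 3) uint8
```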
We borrow the FID evaluation code from here. Metrics, including FID, are computed from the generated .npz file. Before evaluation, make sure to download the reference batch for ImageNet 256x256, then run the following command:
python evaluator.py {path to reference batch} {path to generated .npz}
- Training code for U-DiTs
- Model weights
- ImageNet features from VAE for faster training
- Colab demos
- Outcomes from longer training
If you find this repo useful, please cite:
@misc{tian2024udits,
title={U-DiTs: Downsample Tokens in U-Shaped Diffusion Transformers},
author={Yuchuan Tian and Zhijun Tu and Hanting Chen and Jie Hu and Chao Xu and Yunhe Wang},
year={2024},
eprint={2405.02730},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
We acknowledge the authors of the following repos:
https://github.com/facebookresearch/DiT (Codebase)
https://github.com/chuanyangjin/fast-DiT (FP16 training; Training on features)
https://github.com/openai/guided-diffusion (Metric evaluation)