Skip to content

Commit

Permalink
Merge pull request #1719 from kohya-ss/sd3_5_support
Browse files Browse the repository at this point in the history
SD3.5 Large support
  • Loading branch information
kohya-ss authored Nov 1, 2024
2 parents 1e2f7b0 + 82daa98 commit 264328d
Show file tree
Hide file tree
Showing 19 changed files with 3,292 additions and 2,240 deletions.
195 changes: 163 additions & 32 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
This repository contains training, generation and utility scripts for Stable Diffusion.

## FLUX.1 training (WIP)
## FLUX.1 and SD3 training (WIP)

This feature is experimental. The options and the training script may change in the future. Please let us know if you have any idea to improve the training.

Expand All @@ -9,8 +9,15 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv
The command to install PyTorch is as follows:
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`

- [FLUX.1 training](#flux1-training)
- [SD3 training](#sd3-training)

### Recent Updates

Oct 31, 2024:

- Added support for SD3.5L/M training. See [SD3 training](#sd3-training) for details.

Oct 19, 2024:

- Added an implementation of Differential Output Preservation (temporary name) for SDXL/FLUX.1 LoRA training. SD1/2 is not tested yet. This is an experimental feature.
Expand Down Expand Up @@ -139,7 +146,7 @@ Sep 1, 2024:
Aug 29, 2024:
Please update `safetensors` to `0.4.4` to fix the error when using `--resume`. `requirements.txt` is updated.

### Contents
## FLUX.1 training

- [FLUX.1 LoRA training](#flux1-lora-training)
- [Key Options for FLUX.1 LoRA training](#key-options-for-flux1-lora-training)
Expand Down Expand Up @@ -586,53 +593,177 @@ python tools/convert_diffusers_to_flux.py --diffusers_path path/to/diffusers_fol

## SD3 training

SD3 training is done with `sd3_train.py`.
SD3.5L/M training is now available.

### SD3 LoRA training

The script is `sd3_train_network.py`. See `--help` for options.

SD3 model, CLIP-L, CLIP-G, and T5XXL models are recommended to be in float/fp16 format. If you specify `--fp8_base`, you can use fp8 models for SD3. The fp8 model is only compatible with `float8_e4m3fn` format.

Sample command is below. It will work with 16GB VRAM GPUs (SD3.5L).

```
accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 sd3_train_network.py
--pretrained_model_name_or_path path/to/sd3.5_large.safetensors --clip_l sd3/clip_l.safetensors --clip_g sd3/clip_g.safetensors --t5xxl sd3/t5xxl_fp16.safetensors
--cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers
--max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16
--network_module networks.lora_sd3 --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4
--cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base
--highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml
--output_dir path/to/output/dir --output_name sd3-lora-name
```
(The command is multi-line for readability. Please combine it into one line.)

The training can be done with 12GB VRAM GPUs with Adafactor optimizer. Please use settings like below:

```
--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --lr_scheduler constant_with_warmup --max_grad_norm 0.0
```

`--cpu_offload_checkpointing` and `--split_mode` are not available for SD3 LoRA training.

__Sep 1, 2024__:
- `--num_last_block_to_freeze` is added to `sd3_train.py`. This option is to freeze the last n blocks of the MMDiT. See [#1417](https://github.com/kohya-ss/sd-scripts/pull/1417) for details. Thanks to sdbds!
We also not sure how many epochs are needed for convergence, and how the learning rate should be adjusted.

__Jul 27, 2024__:
- Latents and text encoder outputs caching mechanism is refactored significantly.
- Existing cache files for SD3 need to be recreated. Please delete the previous cache files.
- With this change, dataset initialization is significantly faster, especially for large datasets.
The trained LoRA model can be used with ComfyUI.

- Architecture-dependent parts are extracted from the dataset (`train_util.py`). This is expected to make it easier to add future architectures.
#### Key Options for SD3 LoRA training

- Architecture-dependent parts including the cache mechanism for SD1/2/SDXL are also extracted. The basic operation of SD1/2/SDXL training on the sd3 branch has been confirmed, but there may be bugs. Please use the main or dev branch for SD1/2/SDXL training.
Here are the arguments. The arguments and sample settings are still experimental and may change in the future. Feedback on the settings is welcome.

---
- `--network_module` is the module for LoRA training. Specify `networks.lora_sd3` for SD3 LoRA training.
- `--pretrained_model_name_or_path` is the path to the pretrained model (SD3/3.5). If you specify `--fp8_base`, you can use fp8 models for SD3/3.5. The fp8 model is only compatible with `float8_e4m3fn` format.
- `--clip_l` is the path to the CLIP-L model.
- `--clip_g` is the path to the CLIP-G model.
- `--t5xxl` is the path to the T5XXL model. If you specify `--fp8_base`, you can use fp8 (float8_e4m3fn) models for T5XXL. However, it is recommended to use fp16 models for caching.
- `--vae` is the path to the autoencoder model. __This option is not necessary for SD3.__ VAE is included in the standard SD3 model.
- `--disable_mmap_load_safetensors` is to disable memory mapping when loading safetensors. __This option significantly reduces the memory usage when loading models for Windows users.__
- `--clip_l_dropout_rate`, `--clip_g_dropout_rate` and `--t5_dropout_rate` are the dropout rates for the embeddings of CLIP-L, CLIP-G, and T5XXL, described in [SAI research papre](http://arxiv.org/pdf/2403.03206). The default is 0.0. For LoRA training, it is seems to be better to set 0.0.
- `--pos_emb_random_crop_rate` is the rate of random cropping of positional embeddings, described in [SD3.5M model card](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium). The default is 0. It is seems to be better to set 0.0 for LoRA training.
- `--enable_scaled_pos_embed` is to enable the scaled positional embeddings. The default is False. This option is an experimental feature for SD3.5M. Details are described below.

`fp16` and `bf16` are available for mixed precision training. We are not sure which is better.
Other options are described below.

`optimizer_type = "adafactor"` is recommended for 24GB VRAM GPUs. `cache_text_encoder_outputs_to_disk` and `cache_latents_to_disk` are necessary currently.
#### Key Features for SD3 LoRA training

`clip_l`, `clip_g` and `t5xxl` can be specified if the checkpoint does not include them.
1. CLIP-L, G and T5XXL LoRA Support:
- SD3 LoRA training now supports CLIP-L, CLIP-G and T5XXL LoRA training.
- Remove `--network_train_unet_only` from your command.
- Add `train_t5xxl=True` to `--network_args` to train T5XXL LoRA. CLIP-L and G is also trained at the same time.
- T5XXL output can be cached for CLIP-L and G LoRA training. So, `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available.
- The learning rates for CLIP-L, CLIP-G and T5XXL can be specified separately. Multiple numbers can be specified in `--text_encoder_lr`. For example, `--text_encoder_lr 1e-4 1e-5 5e-6`. The first value is the learning rate for CLIP-L, the second value is for CLIP-G, and the third value is for T5XXL. If you specify only one, the learning rates for CLIP-L, CLIP-G and T5XXL will be the same. If the third value is not specified, the second value is used for T5XXL. If `--text_encoder_lr` is not specified, the default learning rate `--learning_rate` is used for both CLIP-L and T5XXL.
- The trained LoRA can be used with ComfyUI.

t5xxl works with `fp16` now.
| trained LoRA|option|network_args|cache_text_encoder_outputs (*1)|
|---|---|---|---|
|MMDiT|`--network_train_unet_only`|-|o|
|MMDiT + CLIP-L + CLIP-G|-|-|o (*2)|
|MMDiT + CLIP-L + CLIP-G + T5XXL|-|`train_t5xxl=True`|-|
|CLIP-L + CLIP-G (*3)|`--network_train_text_encoder_only`|-|o (*2)|
|CLIP-L + CLIP-G + T5XXL (*3)|`--network_train_text_encoder_only`|`train_t5xxl=True`|-|

There are `t5xxl_device` and `t5xxl_dtype` options for `t5xxl` device and dtype.
- *1: `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available.
- *2: T5XXL output can be cached for CLIP-L and G LoRA training.
- *3: Not tested yet.

2. Experimental FP8/FP16 mixed training:
- `--fp8_base_unet` enables training with fp8 for MMDiT and bf16/fp16 for CLIP-L/G/T5XXL.
- When specifying this option, the `--fp8_base` option is automatically enabled.

`text_encoder_batch_size` is added experimentally for caching faster.
3. Split Q/K/V Projection Layers (Experimental):
- Same as FLUX.1.

4. CLIP-L/G and T5 Attention Mask Application:
- This function is planned to be implemented in the future.

5. Multi-resolution Training Support:
- Only for SD3.5M.
- Same as FLUX.1 for data preparation.
- If you train with multiple resolutions, specify `--enable_scaled_pos_embed` to enable the scaled positional embeddings. The default is False. This option is an experimental feature for SD3.5M.

```toml
learning_rate = 1e-6 # seems to depend on the batch size
optimizer_type = "adafactor"
optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ]
cache_text_encoder_outputs = true
cache_text_encoder_outputs_to_disk = true
vae_batch_size = 1
text_encoder_batch_size = 4
cache_latents = true
cache_latents_to_disk = true

Technical details of multi-resolution training for SD3.5M:

The values of the positional embeddings must be the same for each resolution. That is, the same value must be in the same position for 512x512, 768x768, and 1024x1024. To achieve this, the positional embeddings for each resolution are calculated in advance and switched according to the resolution of the training data. This feature is enabled by `--enable_scaled_pos_embed`.

This idea and the code for calculating scaled positional embeddings are contributed by KohakuBlueleaf. Thanks to KohakuBlueleaf!


#### Specify rank for each layer in SD3 LoRA

You can specify the rank for each layer in SD3 by specifying the following network_args. If you specify `0`, LoRA will not be applied to that layer.

When network_args is not specified, the default value (`network_dim`) is applied, same as before.

|network_args|target layer|
|---|---|
|context_attn_dim|attn in context_block|
|context_mlp_dim|mlp in context_block|
|context_mod_dim|adaLN_modulation in context_block|
|x_attn_dim|attn in x_block|
|x_mlp_dim|mlp in x_block|
|x_mod_dim|adaLN_modulation in x_block|

`"verbose=True"` is also available for debugging. It shows the rank of each layer.

example:
```
--network_args "context_attn_dim=2" "context_mlp_dim=3" "context_mod_dim=4" "x_attn_dim=5" "x_mlp_dim=6" "x_mod_dim=7" "verbose=True"
```

You can apply LoRA to the conditioning layers of SD3 by specifying `emb_dims` in network_args. When specifying, be sure to specify 6 numbers in `[]` as a comma-separated list.

example:
```
--network_args "emb_dims=[2,3,4,5,6,7]"
```

Each number corresponds to `context_embedder`, `t_embedder`, `x_embedder`, `y_embedder`, `final_layer_adaLN_modulation`, `final_layer_linear`. The above example applies LoRA to all conditioning layers, with rank 2 for `context_embedder`, 3 for `t_embedder`, 4 for `context_embedder`, 5 for `y_embedder`, 6 for `final_layer_adaLN_modulation`, and 7 for `final_layer_linear`.

If you specify `0`, LoRA will not be applied to that layer. For example, `[4,0,0,4,0,0]` applies LoRA only to `context_embedder` and `y_embedder`.

#### Specify blocks to train in SD3 LoRA training

You can specify the blocks to train in SD3 LoRA training by specifying `train_block_indices` in network_args. The indices are 0-based. The default (when omitted) is to train all blocks. The indices are specified as a list of integers or a range of integers, like `0,1,5,8` or `0,1,4-5,7`.

The number of blocks depends on the model. The valid range is 0-(the number of blocks - 1). `all` is also available to train all blocks, `none` is also available to train no blocks.

example:
```
--network_args "train_block_indices=1,2,6-8"
```

### Inference for SD3 with LoRA model

The inference script is also available. The script is `sd3_minimal_inference.py`. See `--help` for options.

### SD3 fine-tuning

Documentation is not available yet. Please refer to the FLUX.1 fine-tuning guide for now. The major difference are following:

- `--clip_g` is also available for SD3 fine-tuning.
- `--timestep_sampling` `--discrete_flow_shift``--model_prediction_type` --guidance_scale` are not necessary for SD3 fine-tuning.
- Use `--vae` instead of `--ae` if necessary. __This option is not necessary for SD3.__ VAE is included in the standard SD3 model.
- `--disable_mmap_load_safetensors` is available. __This option significantly reduces the memory usage when loading models for Windows users.__
- `--cpu_offload_checkpointing` is not available for SD3 fine-tuning.
- `--clip_l_dropout_rate`, `--clip_g_dropout_rate` and `--t5_dropout_rate` are available same as LoRA training.
- `--pos_emb_random_crop_rate` and `--enable_scaled_pos_embed` are available for SD3.5M fine-tuning.
- Training text encoders is available with `--train_text_encoder` option, similar to SDXL training.
- CLIP-L and G can be trained with `--train_text_encoder` option. Training T5XXL needs `--train_t5xxl` option.
- If you use the cached text encoder outputs for T5XXL with training CLIP-L and G, specify `--use_t5xxl_cache_only`. This option enables to use the cached text encoder outputs for T5XXL only.
- The learning rates for CLIP-L, CLIP-G and T5XXL can be specified separately. `--text_encoder_lr1`, `--text_encoder_lr2` and `--text_encoder_lr3` are available.

### Extract LoRA from SD3 Models

Not available yet.

__2024/7/27:__
### Convert SD3 LoRA

Latents およびテキストエンコーダ出力のキャッシュの仕組みを大きくリファクタリングしました。SD3 用の既存のキャッシュファイルの再作成が必要になりますが、ご了承ください(以前のキャッシュファイルは削除してください)。これにより、特にデータセットの規模が大きい場合のデータセット初期化が大幅に高速化されます。
Not available yet.

データセット (`train_util.py`) からアーキテクチャ依存の部分を切り出しました。これにより将来的なアーキテクチャ追加が容易になると期待しています。
### Merge LoRA to SD3 checkpoint

SD1/2/SDXL のキャッシュ機構を含むアーキテクチャ依存の部分も切り出しました。sd3 ブランチの SD1/2/SDXL 学習について、基本的な動作は確認していますが、不具合があるかもしれません。SD1/2/SDXL の学習には main または dev ブランチをお使いください。
Not available yet.

---

Expand Down
4 changes: 2 additions & 2 deletions flux_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

from accelerate.utils import set_seed
from library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux
from library.sd3_train_utils import load_prompts, FlowMatchEulerDiscreteScheduler
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler

import library.train_util as train_util

Expand Down Expand Up @@ -241,7 +241,7 @@ def train(args):

text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()

prompts = load_prompts(args.sample_prompts)
prompts = train_util.load_prompts(args.sample_prompts)
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
with accelerator.autocast(), torch.no_grad():
for prompt_dict in prompts:
Expand Down
4 changes: 2 additions & 2 deletions flux_train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def cache_text_encoder_outputs_if_needed(
tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()

prompts = sd3_train_utils.load_prompts(args.sample_prompts)
prompts = train_util.load_prompts(args.sample_prompts)
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
with accelerator.autocast(), torch.no_grad():
for prompt_dict in prompts:
Expand Down Expand Up @@ -363,7 +363,7 @@ def get_noise_pred_and_target(
if args.gradient_checkpointing:
noisy_model_input.requires_grad_(True)
for t in text_encoder_conds:
if t.dtype.is_floating_point:
if t is not None and t.dtype.is_floating_point:
t.requires_grad_(True)
img_ids.requires_grad_(True)
guidance_vec.requires_grad_(True)
Expand Down
3 changes: 1 addition & 2 deletions library/flux_train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from safetensors.torch import save_file

from library import flux_models, flux_utils, strategy_base, train_util
from library.sd3_train_utils import load_prompts
from library.device_utils import init_ipex, clean_memory_on_device

init_ipex()
Expand Down Expand Up @@ -70,7 +69,7 @@ def sample_images(
text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
# print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])

prompts = load_prompts(args.sample_prompts)
prompts = train_util.load_prompts(args.sample_prompts)

save_dir = args.output_dir + "/sample"
os.makedirs(save_dir, exist_ok=True)
Expand Down
Loading

0 comments on commit 264328d

Please sign in to comment.