support SD3 #1374 (Draft)

kohya-ss wants to merge 231 commits into dev from the sd3 branch
Commits (231)
e526828
add sd3 models and inference script
kohya-ss Jun 15, 2024
a518e3c
Merge branch 'dev' into sd3
kohya-ss Jun 23, 2024
d53ea22
sd3 training
kohya-ss Jun 23, 2024
0fe4eaf
fix to use zero for initial latent
kohya-ss Jun 24, 2024
4802e4a
workaround for long caption ref #1382
kohya-ss Jun 24, 2024
8f2ba27
support text_encoder_batch_size for caching
kohya-ss Jun 26, 2024
828a581
fix assertion for experimental impl ref #1389
kohya-ss Jun 26, 2024
381598c
fix resolution in metadata for sd3
kohya-ss Jun 26, 2024
66cf435
re-fix assertion ref #1389
kohya-ss Jun 27, 2024
1908646
Fix fp16 mixed precision, model is in bf16 without full_bf16
kohya-ss Jun 29, 2024
ea18d5b
Fix to work full_bf16 and full_fp16.
kohya-ss Jun 29, 2024
50e3d62
fix to work T5XXL with fp16
kohya-ss Jul 8, 2024
c9de7c4
WIP: new latents caching
kohya-ss Jul 8, 2024
3ea4fce
load models one by one
kohya-ss Jul 8, 2024
9dc7997
fix typo
kohya-ss Jul 9, 2024
3d40292
WIP: update new latents caching
kohya-ss Jul 9, 2024
6f0e235
Fix shift value in SD3 inference.
kohya-ss Jul 10, 2024
b8896aa
update README
kohya-ss Jul 10, 2024
082f136
reduce peak GPU memory usage before training
kohya-ss Jul 12, 2024
41dee60
Refactor caching mechanism for latents and text encoder outputs, etc.
kohya-ss Jul 27, 2024
1a977e8
fix typos
kohya-ss Jul 27, 2024
002d751
sample images for training
kohya-ss Jul 29, 2024
231df19
Fix npz path for verification
kohya-ss Aug 5, 2024
da4d0fe
support attn mask for l+g/t5
kohya-ss Aug 5, 2024
36b2e6f
add FLUX.1 LoRA training
kohya-ss Aug 9, 2024
808d2d1
fix typos
kohya-ss Aug 9, 2024
358f13f
fix alpha is ignored
kohya-ss Aug 10, 2024
8a0f12d
update FLUX LoRA training
kohya-ss Aug 10, 2024
82314ac
update readme for ai toolkit settings
kohya-ss Aug 11, 2024
d25ae36
fix apply_t5_attn_mask to work
kohya-ss Aug 11, 2024
9e09a69
update README
kohya-ss Aug 11, 2024
4af36f9
update to work interactive mode
kohya-ss Aug 12, 2024
a7d5dab
Update readme
kohya-ss Aug 12, 2024
0415d20
update dependencies closes #1450
kohya-ss Aug 13, 2024
4cf42cc
Merge branch 'sd3' of https://github.com/kohya-ss/sd-scripts into sd3
kohya-ss Aug 13, 2024
f5ce754
Merge branch 'dev' into sd3
kohya-ss Aug 13, 2024
9711c96
update README
kohya-ss Aug 13, 2024
56d7651
add experimental split mode for FLUX
kohya-ss Aug 13, 2024
9760d09
Fix AttributeError: 'T5EncoderModel' object has no attribute 'text_mo…
fireicewolf Aug 14, 2024
7db4222
add sample image generation during training
kohya-ss Aug 14, 2024
e2d822c
Merge pull request #1452 from fireicewolf/sd3-devel
kohya-ss Aug 15, 2024
8aaa196
fix encoding latents closes #1456
kohya-ss Aug 15, 2024
35b6cb0
update for torchvision
kohya-ss Aug 15, 2024
08ef886
Fix AttributeError: 'FluxNetworkTrainer' object has no attribute 'sam…
fireicewolf Aug 16, 2024
739a896
Merge pull request #1461 from fireicewolf/sd3-devel
kohya-ss Aug 16, 2024
3921a4e
add t5xxl max token length, support schnell
kohya-ss Aug 16, 2024
e45d3f8
add merge LoRA script
kohya-ss Aug 16, 2024
7367584
fix sd3 training to work without caching TE outputs #1465
kohya-ss Aug 17, 2024
400955d
add fine tuning FLUX.1 (WIP)
kohya-ss Aug 17, 2024
25f77f6
fix flux fine tuning to work
kohya-ss Aug 17, 2024
7e68891
fix: fix the Flux LoRA merge feature
exveria1015 Aug 18, 2024
ef535ec
add memory efficient training for FLUX.1
kohya-ss Aug 18, 2024
a450488
update readme
kohya-ss Aug 18, 2024
d034032
update README fix option name
kohya-ss Aug 19, 2024
6e72a79
reduce peak VRAM usage by excluding some blocks to cuda
kohya-ss Aug 19, 2024
486fe8f
feat: reduce memory usage and add memory efficient option for model s…
kohya-ss Aug 19, 2024
9e72be0
Fix debug_dataset to work
kohya-ss Aug 19, 2024
c62c95e
update about multi-resolution training in FLUX.1
kohya-ss Aug 19, 2024
92b1f6d
Merge pull request #1469 from exveria1015/sd3
kohya-ss Aug 20, 2024
6f6faf9
fix to work with ai-toolkit LoRA
kohya-ss Aug 20, 2024
9381332
revert merge function and add option to use new func
kohya-ss Aug 20, 2024
dbed512
chore: formatting
kohya-ss Aug 20, 2024
388b3b4
Merge pull request #1482 from kohya-ss/flux-merge-lora
kohya-ss Aug 20, 2024
6ab48b0
feat: Support multi-resolution training with caching latents to disk
kohya-ss Aug 20, 2024
7e459c0
Update T5 attention mask handling in FLUX
kohya-ss Aug 20, 2024
e17c42c
Add BFL/Diffusers LoRA converter #1467 #1458 #1483
kohya-ss Aug 21, 2024
2b07a92
Fix error in applying mask in Attention and add LoRA converter script
kohya-ss Aug 21, 2024
e1cd19c
add stochastic rounding, fix single block
kohya-ss Aug 21, 2024
98c91a7
Fix bug in FLUX multi GPU training
kohya-ss Aug 22, 2024
a4d27a2
Fix --debug_dataset to work.
kohya-ss Aug 22, 2024
2d8fa33
Fix to remove zero pad for t5xxl output
kohya-ss Aug 22, 2024
b0a9808
added a script to extract LoRA
kohya-ss Aug 22, 2024
bf9f798
chore: fix typos, remove debug print
kohya-ss Aug 22, 2024
99744af
Merge branch 'dev' into sd3
kohya-ss Aug 22, 2024
81411a3
speed up getting image sizes
kohya-ss Aug 22, 2024
2e89cd2
Fix issue with attention mask not being applied in single blocks
kohya-ss Aug 24, 2024
cf689e7
feat: Add option to split projection layers and apply LoRA
kohya-ss Aug 24, 2024
5639c2a
fix typo
kohya-ss Aug 24, 2024
ea92426
Merge branch 'dev' into sd3
kohya-ss Aug 24, 2024
72287d3
feat: Add `shift` option to `--timestep_sampling` in FLUX.1 fine-tuni…
kohya-ss Aug 25, 2024
0087a46
FLUX.1 LoRA supports CLIP-L
kohya-ss Aug 27, 2024
3be712e
feat: Update direct loading fp8 ckpt for LoRA training
kohya-ss Aug 27, 2024
a61cf73
update readme
kohya-ss Aug 27, 2024
6c0e8a5
make guidance_scale keep float in args
Akegarasu Aug 29, 2024
a0cfb08
Cleaned up README
kohya-ss Aug 29, 2024
daa6ad5
Update README.md
kohya-ss Aug 29, 2024
930d709
Merge pull request #1525 from Akegarasu/sd3
kohya-ss Aug 29, 2024
8ecf0fc
Refactor code to ensure args.guidance_scale is always a float #1525
kohya-ss Aug 29, 2024
8fdfd8c
Update safetensors to version 0.4.4 in requirements.txt #1524
kohya-ss Aug 29, 2024
34f2315
fix: text_encoder_conds referenced before assignment
Akegarasu Aug 29, 2024
35882f8
fix
Akegarasu Aug 29, 2024
25c9040
Update flux_train_utils.py
sdbds Aug 30, 2024
928e0fc
Merge pull request #1529 from Akegarasu/sd3
kohya-ss Sep 1, 2024
ef510b3
Sd3 freeze x_block (#1417)
sdbds Sep 1, 2024
92e7600
Move freeze_blocks to sd3_train because it's only for sd3
kohya-ss Sep 1, 2024
1e30aa8
Merge pull request #1541 from sdbds/flux_shift
kohya-ss Sep 1, 2024
4f6d915
update help and README
kohya-ss Sep 1, 2024
6abacf0
update README
kohya-ss Sep 2, 2024
b65ae9b
T5XXL LoRA training, fp8 T5XXL support
kohya-ss Sep 4, 2024
b7cff0a
update README
kohya-ss Sep 4, 2024
56cb2fc
support T5XXL LoRA, reduce peak memory usage #1560
kohya-ss Sep 4, 2024
90ed2df
feat: Add support for merging CLIP-L and T5XXL LoRA models
kohya-ss Sep 4, 2024
d912952
set dtype before calling ae closes #1562
kohya-ss Sep 5, 2024
2889108
feat: Add --cpu_offload_checkpointing option to LoRA training
kohya-ss Sep 5, 2024
ce14447
Merge branch 'dev' into sd3
kohya-ss Sep 7, 2024
d29af14
add negative prompt for flux inference script
kohya-ss Sep 9, 2024
d10ff62
support individual LR for CLIP-L/T5XXL
kohya-ss Sep 10, 2024
65b8a06
update README
kohya-ss Sep 10, 2024
eaafa5c
Merge branch 'dev' into sd3
kohya-ss Sep 11, 2024
8311e88
typo fix
cocktailpeanut Sep 11, 2024
d83f2e9
Merge pull request #1592 from cocktailpeanut/sd3
kohya-ss Sep 11, 2024
a823fd9
Improve wandb logging (#1576)
p1atdev Sep 11, 2024
237317f
update README
kohya-ss Sep 11, 2024
cefe526
fix to work old notation for TE LR in .toml
kohya-ss Sep 12, 2024
f3ce80e
Merge branch 'dev' into sd3
kohya-ss Sep 13, 2024
c15a3a1
Merge branch 'dev' into sd3
kohya-ss Sep 13, 2024
0485f23
Merge branch 'dev' into sd3
kohya-ss Sep 13, 2024
2d8ee3c
OFT for FLUX.1
kohya-ss Sep 14, 2024
c9ff4de
Add support for specifying rank for each layer in FLUX.1
kohya-ss Sep 14, 2024
6445bb2
update README
kohya-ss Sep 14, 2024
9f44ef1
add diffusers to FLUX.1 conversion script
kohya-ss Sep 15, 2024
be078bd
fix typo
kohya-ss Sep 15, 2024
96c677b
fix to work linear/cosine lr scheduler closes #1602 ref #1393
kohya-ss Sep 16, 2024
d8d15f1
add support for specifying blocks in FLUX.1 LoRA training
kohya-ss Sep 16, 2024
0cbe95b
fix text_encoder_lr to work with int closes #1608
kohya-ss Sep 17, 2024
a2ad7e5
blocks_to_swap=0 means no swap
kohya-ss Sep 17, 2024
bbd160b
sd3 schedule free opt (#1605)
kohya-ss Sep 17, 2024
e745021
update README
kohya-ss Sep 17, 2024
1286e00
fix to call train/eval in schedulefree #1605
kohya-ss Sep 18, 2024
706a48d
Merge branch 'dev' into sd3
kohya-ss Sep 19, 2024
b844c70
Merge branch 'dev' into sd3
kohya-ss Sep 19, 2024
3957372
Retain alpha in `pil_resize`
emcmanus Sep 19, 2024
de4bb65
Update utils.py
emcmanus Sep 19, 2024
0535cd2
fix: backward compatibility for text_encoder_lr
Akegarasu Sep 20, 2024
24f8975
Merge pull request #1620 from Akegarasu/sd3
kohya-ss Sep 20, 2024
583d4a4
add compatibility for int LR (D-Adaptation etc.) #1620
kohya-ss Sep 20, 2024
95ff9db
Merge pull request #1619 from emcmanus/patch-1
kohya-ss Sep 20, 2024
fba7692
Merge branch 'dev' into sd3
kohya-ss Sep 23, 2024
65fb69f
Merge branch 'dev' into sd3
kohya-ss Sep 25, 2024
56a7bc1
new block swap for FLUX.1 fine tuning
kohya-ss Sep 25, 2024
da94fd9
fix typos
kohya-ss Sep 25, 2024
2cd6aa2
Merge branch 'dev' into sd3
kohya-ss Sep 26, 2024
392e8de
fix flip_aug, alpha_mask, random_crop issue in caching in caching str…
kohya-ss Sep 26, 2024
3ebb65f
Merge branch 'dev' into sd3
kohya-ss Sep 26, 2024
9249d00
experimental support for multi-gpus latents caching
kohya-ss Sep 26, 2024
24b1fdb
remove debug print
kohya-ss Sep 26, 2024
a9aa526
fix sample generation is not working in FLUX1 fine tuning #1647
kohya-ss Sep 28, 2024
822fe57
add workaround for 'Some tensors share memory' error #1614
kohya-ss Sep 28, 2024
1a0f5b0
re-fix sample generation is not working in FLUX1 split mode #1647
kohya-ss Sep 28, 2024
d050638
Merge branch 'dev' into sd3
kohya-ss Sep 29, 2024
e0c3630
Support Sdxl Controlnet (#1648)
sdbds Sep 29, 2024
56a63f0
Merge branch 'sd3' into multi-gpu-caching
kohya-ss Sep 29, 2024
8919b31
use original ControlNet instead of Diffusers
kohya-ss Sep 29, 2024
0243c65
fix typo
kohya-ss Sep 29, 2024
8bea039
Merge branch 'dev' into sd3
kohya-ss Sep 29, 2024
d78f6a7
Merge branch 'sd3' into sdxl-ctrl-net
kohya-ss Sep 29, 2024
793999d
sample generation in SDXL ControlNet training
kohya-ss Sep 30, 2024
33e942e
Merge branch 'sd3' into fast_image_sizes
kohya-ss Sep 30, 2024
c2440f9
fix cond image normalization, add independent LR for control
kohya-ss Oct 3, 2024
ba08a89
call optimizer eval/train for sample_at_first, also set train after r…
kohya-ss Oct 4, 2024
83e3048
load Diffusers format, check schnell/dev
kohya-ss Oct 6, 2024
126159f
Merge branch 'sd3' into sdxl-ctrl-net
kohya-ss Oct 7, 2024
886f753
support weighted captions for sdxl LoRA and fine tuning
kohya-ss Oct 9, 2024
3de42b6
fix: distributed training in windows
Akegarasu Oct 10, 2024
9f4dac5
torch 2.4
Akegarasu Oct 10, 2024
f2bc820
support weighted captions for SD/SDXL
kohya-ss Oct 10, 2024
035c4a8
update docs and help text
kohya-ss Oct 11, 2024
43bfeea
Merge pull request #1655 from kohya-ss/sdxl-ctrl-net
kohya-ss Oct 11, 2024
d005652
Merge pull request #1686 from Akegarasu/sd3
kohya-ss Oct 12, 2024
0d3058b
update README
kohya-ss Oct 12, 2024
ff4083b
Merge branch 'sd3' into multi-gpu-caching
kohya-ss Oct 12, 2024
c80c304
Refactor caching in train scripts
kohya-ss Oct 12, 2024
ecaea90
update README
kohya-ss Oct 12, 2024
e277b57
Update FLUX.1 support for compact models
kohya-ss Oct 12, 2024
5bb9f7f
Merge branch 'sd3' into multi-gpu-caching
kohya-ss Oct 13, 2024
74228c9
update cache_latents/text_encoder_outputs
kohya-ss Oct 13, 2024
c65cf38
Merge branch 'sd3' into fast_image_sizes
kohya-ss Oct 13, 2024
2244cf5
load images in parallel when caching latents
kohya-ss Oct 13, 2024
bfc3a65
fix to work cache latents/text encoder outputs
kohya-ss Oct 13, 2024
d02a6ef
Merge pull request #1660 from kohya-ss/fast_image_sizes
kohya-ss Oct 13, 2024
886ffb4
Merge branch 'sd3' into multi-gpu-caching
kohya-ss Oct 13, 2024
2d5f7fa
update README
kohya-ss Oct 13, 2024
1275e14
Merge pull request #1690 from kohya-ss/multi-gpu-caching
kohya-ss Oct 13, 2024
2500f5a
fix latents caching not working closes #1696
kohya-ss Oct 14, 2024
3cc5b8d
Diff Output Preserv loss for SDXL
kohya-ss Oct 18, 2024
d8d7142
fix to work caching latents #1696
kohya-ss Oct 18, 2024
ef70aa7
add FLUX.1 support
kohya-ss Oct 18, 2024
2c45d97
update README, remove unnecessary autocast
kohya-ss Oct 19, 2024
09b4d1e
Merge branch 'sd3' into diff_output_prsv
kohya-ss Oct 19, 2024
aa93242
Merge pull request #1710 from kohya-ss/diff_output_prsv
kohya-ss Oct 19, 2024
7fe8e16
fix to work ControlNetSubset with custom_attributes
kohya-ss Oct 19, 2024
138dac4
update README
kohya-ss Oct 20, 2024
623017f
refactor SD3 CLIP to transformers etc.
kohya-ss Oct 24, 2024
e3c43bd
reduce memory usage in sample image generation
kohya-ss Oct 24, 2024
0286114
support SD3.5L, fix final saving
kohya-ss Oct 24, 2024
f8c5146
support block swap with fused_optimizer_pass
kohya-ss Oct 24, 2024
5fba6f5
Merge branch 'dev' into sd3
kohya-ss Oct 25, 2024
f52fb66
Merge branch 'sd3' into sd3_5_support
kohya-ss Oct 25, 2024
d2c549d
support SD3 LoRA
kohya-ss Oct 25, 2024
0031d91
add latent scaling/shifting
kohya-ss Oct 25, 2024
56bf761
fix errors in SD3 LoRA training with Text Encoders close #1724
kohya-ss Oct 26, 2024
014064f
fix sample image generation without seed failed close #1726
kohya-ss Oct 26, 2024
8549669
Merge branch 'dev' into sd3
kohya-ss Oct 26, 2024
150579d
Merge branch 'sd3' into sd3_5_support
kohya-ss Oct 26, 2024
731664b
Merge branch 'dev' into sd3
kohya-ss Oct 27, 2024
b649bbf
Merge branch 'sd3' into sd3_5_support
kohya-ss Oct 27, 2024
db2b4d4
Add dropout rate arguments for CLIP-L, CLIP-G, and T5, fix Text Encod…
kohya-ss Oct 27, 2024
a1255d6
Fix SD3 LoRA training to work (WIP)
kohya-ss Oct 27, 2024
d4f7849
prevent unintended cast for disk cached TE outputs
kohya-ss Oct 27, 2024
1065dd1
Fix to work dropout_rate for TEs
kohya-ss Oct 27, 2024
af8e216
Fix sample image gen to work with block swap
kohya-ss Oct 28, 2024
7555486
Fix error on saving T5XXL
kohya-ss Oct 28, 2024
0af4edd
Fix split_qkv
kohya-ss Oct 29, 2024
d4e19fb
Support Lora
kohya-ss Oct 29, 2024
80bb3f4
Merge branch 'sd3_5_support' of https://github.com/kohya-ss/sd-script…
kohya-ss Oct 29, 2024
1e2f7b0
Support for checkpoint files with a mysterious prefix "model.diffusio…
kohya-ss Oct 29, 2024
ce5b532
Fix additional LoRA to work
kohya-ss Oct 29, 2024
c9a1417
Merge branch 'sd3' into sd3_5_support
kohya-ss Oct 29, 2024
b502f58
Fix emb_dim to work.
kohya-ss Oct 29, 2024
bdddc20
support SD3.5M
kohya-ss Oct 30, 2024
8c3c825
Merge branch 'sd3_5_support' of https://github.com/kohya-ss/sd-script…
kohya-ss Oct 30, 2024
70a179e
Fix to use SDPA instead of xformers
kohya-ss Oct 30, 2024
1434d85
Support SD3.5M multi resolutional training
kohya-ss Oct 31, 2024
9e23368
Update SD3 training
kohya-ss Oct 31, 2024
830df4a
Fix crashing if image is too tall or wide.
kohya-ss Oct 31, 2024
9aa6f52
Fix memory leak in latent caching. bmp failed to cache
kohya-ss Nov 1, 2024
82daa98
remove duplicate resolution for scaled pos embed
kohya-ss Nov 1, 2024
264328d
Merge pull request #1719 from kohya-ss/sd3_5_support
kohya-ss Nov 1, 2024
e0db596
update multi-res training in SD3.5M
kohya-ss Nov 2, 2024
5e32ee2
fix crashing in DDP training closes #1751
kohya-ss Nov 2, 2024
4384903
Fix to work without latent cache #1758
kohya-ss Nov 6, 2024
777 changes: 774 additions & 3 deletions README.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/train_lllite_README.md
@@ -185,7 +185,7 @@ for img_file in img_files:

### Creating a dataset configuration file

You can use the command line arguments of `sdxl_train_control_net_lllite.py` to specify the conditioning image directory. However, if you want to use a `.toml` file, specify the conditioning image directory in `conditioning_data_dir`.
You can use the command line argument `--conditioning_data_dir` of `sdxl_train_control_net_lllite.py` to specify the conditioning image directory. However, if you want to use a `.toml` file, specify the conditioning image directory in `conditioning_data_dir`.

```toml
[general]
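# Hypothetical example continuation (a hedged sketch, not this file's actual
# contents): the section names follow the sd-scripts dataset config layout,
# and the paths are placeholders.

[[datasets]]

[[datasets.subsets]]
image_dir = "path/to/image/dir"
caption_extension = ".txt"
conditioning_data_dir = "path/to/conditioning/image/dir"
```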
70 changes: 46 additions & 24 deletions fine_tune.py
@@ -10,7 +10,7 @@
from tqdm import tqdm

import torch
from library import deepspeed_utils
from library import deepspeed_utils, strategy_base
from library.device_utils import init_ipex, clean_memory_on_device

init_ipex()
@@ -39,6 +39,7 @@
scale_v_prediction_loss_like_noise_prediction,
apply_debiased_estimation,
)
import library.strategy_sd as strategy_sd


def train(args):
@@ -52,7 +53,15 @@ def train(args):
if args.seed is not None:
set_seed(args.seed) # 乱数系列を初期化する

tokenizer = train_util.load_tokenizer(args)
tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)

# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
if cache_latents:
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)

# データセットを準備する
if args.dataset_class is None:
@@ -81,10 +90,10 @@
]
}

blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
train_dataset_group = train_util.load_arbitrary_dataset(args)

current_epoch = Value("i", 0)
current_step = Value("i", 0)
@@ -167,8 +176,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)

train_dataset_group.new_cache_latents(vae, accelerator)

vae.to("cpu")
clean_memory_on_device(accelerator.device)

@@ -194,6 +204,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
else:
text_encoder.eval()

text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip)
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)

if not cache_latents:
vae.requires_grad_(False)
vae.eval()
@@ -216,7 +229,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
accelerator.print("prepare optimizer, data loader etc.")
_, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params)

# dataloaderを準備する
# prepare dataloader
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
# some strategies can be None
train_dataset_group.set_current_strategies()

# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
@@ -319,7 +336,12 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
)

# For --sample_at_first
train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
train_util.sample_images(
accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
)
if len(accelerator.trackers) > 0:
# log empty object to commit the sample images to wandb
accelerator.log({}, step=0)

loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
@@ -344,19 +366,17 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
with torch.set_grad_enabled(args.train_text_encoder):
# Get the text embedding for conditioning
if args.weighted_captions:
encoder_hidden_states = get_weighted_text_embeddings(
tokenizer,
text_encoder,
batch["captions"],
accelerator.device,
args.max_token_length // 75 if args.max_token_length else 1,
clip_skip=args.clip_skip,
)
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
encoder_hidden_states = text_encoding_strategy.encode_tokens_with_weights(
tokenize_strategy, [text_encoder], input_ids_list, weights_list
)[0]
else:
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
)
input_ids = batch["input_ids_list"][0].to(accelerator.device)
encoder_hidden_states = text_encoding_strategy.encode_tokens(
tokenize_strategy, [text_encoder], [input_ids]
)[0]
if args.full_fp16:
encoder_hidden_states = encoder_hidden_states.to(weight_dtype)

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
@@ -411,7 +431,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
global_step += 1

train_util.sample_images(
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
accelerator, args, None, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
)

# 指定ステップごとにモデルを保存
@@ -436,7 +456,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
)

current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss": current_loss}
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
accelerator.log(logs, step=global_step)
@@ -449,7 +469,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
if global_step >= args.max_train_steps:
break

if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)

@@ -474,7 +494,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
vae,
)

train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
train_util.sample_images(
accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
)

is_main_process = accelerator.is_main_process
if is_main_process:
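Taken together, these fine_tune.py hunks replace the direct tokenizer, text-encoder, and latent-cache calls with globally registered strategy objects. The sketch below condenses that flow using only the calls visible in the hunks above; the wrapper function names and parameters are hypothetical glue, and `args`, `train_dataset_group`, `vae`, `accelerator`, `text_encoder`, and `batch` are assumed to come from the rest of the script.

```python
# Condensed sketch of the strategy wiring introduced by this diff; only calls
# that appear in the hunks above are used. The wrapper functions are
# hypothetical and exist purely for illustration.
from library import strategy_base
import library.strategy_sd as strategy_sd


def setup_strategies(args, train_dataset_group, vae, accelerator, cache_latents):
    # Tokenization strategy is registered globally before the dataset is built.
    tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
    strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)

    # The latents caching strategy must also be set before dataset preparation,
    # because the dataset may use it for initialization.
    if cache_latents:
        latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
            False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
        )
        strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
        # In the real script the dataset is built before this point;
        # new_cache_latents replaces the old train_dataset_group.cache_latents(...) call.
        train_dataset_group.new_cache_latents(vae, accelerator)

    # The text encoding strategy replaces get_weighted_text_embeddings / get_hidden_states.
    text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip)
    strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)

    # Copy the strategies into the dataset before the DataLoader workers are spawned.
    train_dataset_group.set_current_strategies()
    return tokenize_strategy, text_encoding_strategy


def encode_captions(tokenize_strategy, text_encoding_strategy, text_encoder, batch, weighted, device):
    # Per-step caption encoding, mirroring the training-loop hunk.
    if weighted:
        input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
        return text_encoding_strategy.encode_tokens_with_weights(
            tokenize_strategy, [text_encoder], input_ids_list, weights_list
        )[0]
    input_ids = batch["input_ids_list"][0].to(device)
    return text_encoding_strategy.encode_tokens(tokenize_strategy, [text_encoder], [input_ids])[0]
```

The per-model strategy classes appear intended to let the SD/SDXL and SD3/FLUX trainers share the same dataset and caching code, with each script registering its own tokenize, text-encoding, and latents-caching implementations.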