Skip to content

Commit

Permalink
support SD3.5L, fix final saving
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Oct 24, 2024
1 parent e3c43bd commit 0286114
Showing 1 changed file with 24 additions and 11 deletions.
35 changes: 24 additions & 11 deletions sd3_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def train(args):
accelerator.wait_for_everyone()

# now we can delete Text Encoders to free memory
if args.use_t5xxl_cache_only:
if not args.use_t5xxl_cache_only:
clip_l = None
clip_g = None
t5xxl = None
Expand All @@ -330,6 +330,7 @@ def train(args):

# load VAE for caching latents
if sd3_state_dict is None:
logger.info(f"load state dict for MMDiT and VAE from {args.pretrained_model_name_or_path}")
sd3_state_dict = utils.load_safetensors(
args.pretrained_model_name_or_path, "cpu", args.disable_mmap_load_safetensors, model_dtype
)
Expand Down Expand Up @@ -360,11 +361,6 @@ def train(args):
# attn_mode == "torch"
# ), f"attn_mode {attn_mode} is not supported yet. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。"

# SD3 state dict may contain multiple models, so we need to load it and extract one by one. annoying.
logger.info(f"Loading SD3 models from {args.pretrained_model_name_or_path}")
device_to_load = accelerator.device if args.lowram else "cpu"
sd3_state_dict = utils.load_safetensors(args.pretrained_model_name_or_path, device_to_load, args.disable_mmap_load_safetensors)

if args.gradient_checkpointing:
mmdit.enable_gradient_checkpointing()

Expand Down Expand Up @@ -555,7 +551,7 @@ def train(args):
# clip_l.text_model.encoder.layers[-1].requires_grad_(False)
# clip_l.text_model.final_layer_norm.requires_grad_(False)

# TextEncoderの出力をキャッシュするときには、すでに出力を取得済みなのでCPUへ移動する
# move Text Encoders to GPU if not caching outputs
if not args.cache_text_encoder_outputs:
# make sure Text Encoders are on GPU
# TODO support CPU for text encoders
Expand Down Expand Up @@ -817,6 +813,13 @@ def optimizer_hook(parameter: torch.Tensor):
# log empty object to commit the sample images to wandb
accelerator.log({}, step=0)

# show model device and dtype
logger.info(f"mmdit device: {mmdit.device}, dtype: {mmdit.dtype}" if mmdit else "mmdit is None")
logger.info(f"clip_l device: {clip_l.device}, dtype: {clip_l.dtype}" if clip_l else "clip_l is None")
logger.info(f"clip_g device: {clip_g.device}, dtype: {clip_g.dtype}" if clip_g else "clip_g is None")
logger.info(f"t5xxl device: {t5xxl.device}, dtype: {t5xxl.dtype}" if t5xxl else "t5xxl is None")
logger.info(f"vae device: {vae.device}, dtype: {vae.dtype}" if vae is not None else "vae is None")

loss_recorder = train_util.LossRecorder()
epoch = 0 # avoid error when max_train_steps is 0
for epoch in range(num_train_epochs):
Expand Down Expand Up @@ -1055,10 +1058,10 @@ def optimizer_hook(parameter: torch.Tensor):
save_dtype,
epoch,
global_step,
accelerator.unwrap_model(clip_l) if train_clip else None,
accelerator.unwrap_model(clip_g) if train_clip else None,
accelerator.unwrap_model(t5xxl) if train_t5xxl else None,
accelerator.unwrap_model(mmdit) if train_mmdit else None,
clip_l if train_clip else None,
clip_g if train_clip else None,
t5xxl if train_t5xxl else None,
mmdit if train_mmdit else None,
vae,
)
logger.info("model saved.")
Expand Down Expand Up @@ -1153,6 +1156,16 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください",
)
parser.add_argument(
"--blocks_to_swap",
type=int,
default=None,
help="[EXPERIMENTAL] "
"Sets the number of blocks (~640MB) to swap during the forward and backward passes."
"Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)."
" / 順伝播および逆伝播中にスワップするブロック(約640MB)の数を設定します。"
"この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。",
)
parser.add_argument(
"--num_last_block_to_freeze",
type=int,
Expand Down

0 comments on commit 0286114

Please sign in to comment.