diff --git a/sd3_train.py b/sd3_train.py index 6336b4cf9..d4ab13a34 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -321,7 +321,7 @@ def train(args): accelerator.wait_for_everyone() # now we can delete Text Encoders to free memory - if args.use_t5xxl_cache_only: + if not args.use_t5xxl_cache_only: clip_l = None clip_g = None t5xxl = None @@ -330,6 +330,7 @@ def train(args): # load VAE for caching latents if sd3_state_dict is None: + logger.info(f"load state dict for MMDiT and VAE from {args.pretrained_model_name_or_path}") sd3_state_dict = utils.load_safetensors( args.pretrained_model_name_or_path, "cpu", args.disable_mmap_load_safetensors, model_dtype ) @@ -360,11 +361,6 @@ def train(args): # attn_mode == "torch" # ), f"attn_mode {attn_mode} is not supported yet. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。" - # SD3 state dict may contain multiple models, so we need to load it and extract one by one. annoying. - logger.info(f"Loading SD3 models from {args.pretrained_model_name_or_path}") - device_to_load = accelerator.device if args.lowram else "cpu" - sd3_state_dict = utils.load_safetensors(args.pretrained_model_name_or_path, device_to_load, args.disable_mmap_load_safetensors) - if args.gradient_checkpointing: mmdit.enable_gradient_checkpointing() @@ -555,7 +551,7 @@ def train(args): # clip_l.text_model.encoder.layers[-1].requires_grad_(False) # clip_l.text_model.final_layer_norm.requires_grad_(False) - # TextEncoderの出力をキャッシュするときには、すでに出力を取得済みなのでCPUへ移動する + # move Text Encoders to GPU if not caching outputs if not args.cache_text_encoder_outputs: # make sure Text Encoders are on GPU # TODO support CPU for text encoders @@ -817,6 +813,13 @@ def optimizer_hook(parameter: torch.Tensor): # log empty object to commit the sample images to wandb accelerator.log({}, step=0) + # show model device and dtype + logger.info(f"mmdit device: {mmdit.device}, dtype: {mmdit.dtype}" if mmdit is not None else "mmdit is None") 
+ logger.info(f"clip_l device: {clip_l.device}, dtype: {clip_l.dtype}" if clip_l is not None else "clip_l is None") + logger.info(f"clip_g device: {clip_g.device}, dtype: {clip_g.dtype}" if clip_g is not None else "clip_g is None") + logger.info(f"t5xxl device: {t5xxl.device}, dtype: {t5xxl.dtype}" if t5xxl is not None else "t5xxl is None") + logger.info(f"vae device: {vae.device}, dtype: {vae.dtype}" if vae is not None else "vae is None") + loss_recorder = train_util.LossRecorder() epoch = 0 # avoid error when max_train_steps is 0 for epoch in range(num_train_epochs): @@ -1055,10 +1058,10 @@ def optimizer_hook(parameter: torch.Tensor): save_dtype, epoch, global_step, - accelerator.unwrap_model(clip_l) if train_clip else None, - accelerator.unwrap_model(clip_g) if train_clip else None, - accelerator.unwrap_model(t5xxl) if train_t5xxl else None, - accelerator.unwrap_model(mmdit) if train_mmdit else None, + clip_l if train_clip else None, + clip_g if train_clip else None, + t5xxl if train_t5xxl else None, + mmdit if train_mmdit else None, vae, ) logger.info("model saved.") @@ -1153,6 +1156,16 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください", ) + parser.add_argument( + "--blocks_to_swap", + type=int, + default=None, + help="[EXPERIMENTAL] " + "Sets the number of blocks (~640MB) to swap during the forward and backward passes." + " Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)." + " / 順伝播および逆伝播中にスワップするブロック(約640MB)の数を設定します。" + "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", + ) parser.add_argument( "--num_last_block_to_freeze", type=int,