Skip to content

Commit

Permalink
Fix crash in DDP training (closes #1751)
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Nov 2, 2024
1 parent e0db596 commit 5e32ee2
Showing 1 changed file with 25 additions and 5 deletions.
30 changes: 25 additions & 5 deletions sd3_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,11 +838,31 @@ def optimizer_hook(parameter: torch.Tensor):
accelerator.log({}, step=0)

# show model device and dtype
logger.info(f"mmdit device: {mmdit.device}, dtype: {mmdit.dtype}" if mmdit else "mmdit is None")
logger.info(f"clip_l device: {clip_l.device}, dtype: {clip_l.dtype}" if clip_l else "clip_l is None")
logger.info(f"clip_g device: {clip_g.device}, dtype: {clip_g.dtype}" if clip_g else "clip_g is None")
logger.info(f"t5xxl device: {t5xxl.device}, dtype: {t5xxl.dtype}" if t5xxl else "t5xxl is None")
logger.info(f"vae device: {vae.device}, dtype: {vae.dtype}" if vae is not None else "vae is None")
logger.info(
f"mmdit device: {accelerator.unwrap_model(mmdit).device}, dtype: {accelerator.unwrap_model(mmdit).dtype}"
if mmdit
else "mmdit is None"
)
logger.info(
f"clip_l device: {accelerator.unwrap_model(clip_l).device}, dtype: {accelerator.unwrap_model(clip_l).dtype}"
if clip_l
else "clip_l is None"
)
logger.info(
f"clip_g device: {accelerator.unwrap_model(clip_g).device}, dtype: {accelerator.unwrap_model(clip_g).dtype}"
if clip_g
else "clip_g is None"
)
logger.info(
f"t5xxl device: {accelerator.unwrap_model(t5xxl).device}, dtype: {accelerator.unwrap_model(t5xxl).dtype}"
if t5xxl
else "t5xxl is None"
)
logger.info(
f"vae device: {accelerator.unwrap_model(vae).device}, dtype: {accelerator.unwrap_model(vae).dtype}"
if vae is not None
else "vae is None"
)

loss_recorder = train_util.LossRecorder()
epoch = 0 # avoid error when max_train_steps is 0
Expand Down

0 comments on commit 5e32ee2

Please sign in to comment.