From 5e32ee26a13394fdee77149c4e96b78c58eabc5e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 2 Nov 2024 15:32:16 +0900 Subject: [PATCH] fix crashing in DDP training closes #1751 --- sd3_train.py | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/sd3_train.py b/sd3_train.py index f64e2da2c..e03d1708b 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -838,11 +838,31 @@ def optimizer_hook(parameter: torch.Tensor): accelerator.log({}, step=0) # show model device and dtype - logger.info(f"mmdit device: {mmdit.device}, dtype: {mmdit.dtype}" if mmdit else "mmdit is None") - logger.info(f"clip_l device: {clip_l.device}, dtype: {clip_l.dtype}" if clip_l else "clip_l is None") - logger.info(f"clip_g device: {clip_g.device}, dtype: {clip_g.dtype}" if clip_g else "clip_g is None") - logger.info(f"t5xxl device: {t5xxl.device}, dtype: {t5xxl.dtype}" if t5xxl else "t5xxl is None") - logger.info(f"vae device: {vae.device}, dtype: {vae.dtype}" if vae is not None else "vae is None") + logger.info( + f"mmdit device: {accelerator.unwrap_model(mmdit).device}, dtype: {accelerator.unwrap_model(mmdit).dtype}" + if mmdit + else "mmdit is None" + ) + logger.info( + f"clip_l device: {accelerator.unwrap_model(clip_l).device}, dtype: {accelerator.unwrap_model(clip_l).dtype}" + if clip_l + else "clip_l is None" + ) + logger.info( + f"clip_g device: {accelerator.unwrap_model(clip_g).device}, dtype: {accelerator.unwrap_model(clip_g).dtype}" + if clip_g + else "clip_g is None" + ) + logger.info( + f"t5xxl device: {accelerator.unwrap_model(t5xxl).device}, dtype: {accelerator.unwrap_model(t5xxl).dtype}" + if t5xxl + else "t5xxl is None" + ) + logger.info( + f"vae device: {accelerator.unwrap_model(vae).device}, dtype: {accelerator.unwrap_model(vae).dtype}" + if vae is not None + else "vae is None" + ) loss_recorder = train_util.LossRecorder() epoch = 0 # avoid error when max_train_steps is 0