Documentation updates, memory use optimisations
bghira committed Oct 15, 2023
1 parent 4d76426 commit 09b0426
Showing 3 changed files with 11 additions and 8 deletions.
2 changes: 1 addition & 1 deletion documentation/DEEPSPEED.md
@@ -13,7 +13,7 @@ SimpleTuner v0.7 includes preliminary support for training SDXL using DeepSpeed
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA GeForce ... Off | 00000000:08:00.0 Off | Off |
- | 0% 43C P2 100W / 450W | 9237MiB / 24564MiB | 0% Default |
+ | 0% 43C P2 100W / 450W | 9237MiB / 24564MiB | 100% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
1 change: 1 addition & 0 deletions helpers/caching/sdxl_embeds.py
@@ -100,6 +100,7 @@ def encode_sdxl_prompt(
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)

# Clear out anything we moved to the text encoder device
+ text_input_ids.to('cpu')
del text_input_ids

prompt_embeds_list.append(prompt_embeds)
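The sdxl_embeds.py change aims to release text-encoder intermediates once the prompt embeddings have been computed. Below is a minimal, hypothetical sketch of that clean-up pattern; the function name and the `tokenizer`/`text_encoder` arguments are illustrative, not SimpleTuner's API. Note that `Tensor.to('cpu')` returns a copy rather than moving the tensor in place, so the GPU buffer is only reclaimed once the last reference to it is dropped.

```python
import torch

def encode_prompt_and_release(text_encoder, tokenizer, prompt, device="cuda"):
    # Tokenise on CPU, then move the ids to the text encoder's device.
    tokens = tokenizer(prompt, return_tensors="pt")
    text_input_ids = tokens.input_ids.to(device)
    with torch.no_grad():
        prompt_embeds = text_encoder(text_input_ids)[0]
    # `text_input_ids.to('cpu')` would return a new CPU tensor; the GPU
    # allocation is actually freed once the reference is dropped.
    del text_input_ids
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # hand cached blocks back to the driver
    return prompt_embeds
```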
16 changes: 9 additions & 7 deletions train_sdxl.py
@@ -817,13 +817,15 @@ def main():
    accelerator.register_load_state_pre_hook(model_hooks.load_model_hook)

    # Prepare everything with our `accelerator`.
-    logger.info(f"Loading our accelerator...")
-    unet, train_dataloader, lr_scheduler, optimizer = accelerator.prepare(
-        unet, train_dataloader, lr_scheduler, optimizer
-    )
-    if args.use_ema:
-        logger.info("Moving EMA model weights to accelerator...")
-        ema_unet.to(accelerator.device, dtype=weight_dtype)
+    disable_accelerator = os.environ.get('SIMPLETUNER_DISABLE_ACCELERATOR', False)
+    if not disable_accelerator:
+        logger.info(f"Loading our accelerator...")
+        unet, train_dataloader, lr_scheduler, optimizer = accelerator.prepare(
+            unet, train_dataloader, lr_scheduler, optimizer
+        )
+        if args.use_ema:
+            logger.info("Moving EMA model weights to accelerator...")
+            ema_unet.to(accelerator.device, dtype=weight_dtype)

    # Move vae, unet and text_encoder to device and cast to weight_dtype
    # The VAE is in float32 to avoid NaN losses.
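The new SIMPLETUNER_DISABLE_ACCELERATOR flag is read straight from the environment, so its truthiness follows ordinary Python string semantics: `os.environ.get` returns the raw string whenever the variable is set. A small sketch of how the check above evaluates, assuming the variable is exported before launching train_sdxl.py:

```python
import os

# With the variable exported (e.g. `SIMPLETUNER_DISABLE_ACCELERATOR=1`),
# os.environ.get returns the string "1", which is truthy, so the
# `if not disable_accelerator:` branch above is skipped.
os.environ["SIMPLETUNER_DISABLE_ACCELERATOR"] = "1"
print(bool(os.environ.get("SIMPLETUNER_DISABLE_ACCELERATOR", False)))  # True -> accelerator.prepare() skipped

# Any non-empty string is truthy, so even "0" or "false" disables it under this check.
os.environ["SIMPLETUNER_DISABLE_ACCELERATOR"] = "false"
print(bool(os.environ.get("SIMPLETUNER_DISABLE_ACCELERATOR", False)))  # True -> still skipped

# Unset the variable to restore the default behaviour.
del os.environ["SIMPLETUNER_DISABLE_ACCELERATOR"]
print(bool(os.environ.get("SIMPLETUNER_DISABLE_ACCELERATOR", False)))  # False -> accelerator.prepare() runs
```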
