Fix SD3 LoRA training to work (WIP)
kohya-ss committed Oct 27, 2024
1 parent db2b4d4 commit a1255d6
Showing 3 changed files with 38 additions and 17 deletions.
library/strategy_sd3.py: 10 additions & 10 deletions
@@ -111,13 +111,13 @@ def encode_tokens(
             lg_pooled = None
         else:
             assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None"
 
             drop_l = enable_dropout and (self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate)
             if drop_l:
-                l_pooled = torch.zeros((l_tokens.shape[0], 768), device=l_tokens.device, dtype=l_tokens.dtype)
-                l_out = torch.zeros((l_tokens.shape[0], l_tokens.shape[1], 768), device=l_tokens.device, dtype=l_tokens.dtype)
+                l_pooled = torch.zeros((l_tokens.shape[0], 768), device=clip_l.device, dtype=clip_l.dtype)
+                l_out = torch.zeros((l_tokens.shape[0], l_tokens.shape[1], 768), device=clip_l.device, dtype=clip_l.dtype)
                 if l_attn_mask is not None:
-                    l_attn_mask = torch.zeros_like(l_attn_mask)
+                    l_attn_mask = torch.zeros_like(l_attn_mask, device=clip_l.device)
             else:
                 l_attn_mask = l_attn_mask.to(clip_l.device) if l_attn_mask is not None else None
                 prompt_embeds = clip_l(l_tokens.to(clip_l.device), l_attn_mask, output_hidden_states=True)
@@ -126,10 +126,10 @@ def encode_tokens(
 
             drop_g = enable_dropout and (self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate)
             if drop_g:
-                g_pooled = torch.zeros((g_tokens.shape[0], 1280), device=g_tokens.device, dtype=g_tokens.dtype)
-                g_out = torch.zeros((g_tokens.shape[0], g_tokens.shape[1], 1280), device=g_tokens.device, dtype=g_tokens.dtype)
+                g_pooled = torch.zeros((g_tokens.shape[0], 1280), device=clip_g.device, dtype=clip_g.dtype)
+                g_out = torch.zeros((g_tokens.shape[0], g_tokens.shape[1], 1280), device=clip_g.device, dtype=clip_g.dtype)
                 if g_attn_mask is not None:
-                    g_attn_mask = torch.zeros_like(g_attn_mask)
+                    g_attn_mask = torch.zeros_like(g_attn_mask, device=clip_g.device)
             else:
                 g_attn_mask = g_attn_mask.to(clip_g.device) if g_attn_mask is not None else None
                 prompt_embeds = clip_g(g_tokens.to(clip_g.device), g_attn_mask, output_hidden_states=True)
@@ -144,9 +144,9 @@ def encode_tokens(
         else:
             drop_t5 = enable_dropout and (self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate)
             if drop_t5:
-                t5_out = torch.zeros((t5_tokens.shape[0], t5_tokens.shape[1], 4096), device=t5_tokens.device, dtype=t5_tokens.dtype)
+                t5_out = torch.zeros((t5_tokens.shape[0], t5_tokens.shape[1], 4096), device=t5xxl.device, dtype=t5xxl.dtype)
                 if t5_attn_mask is not None:
-                    t5_attn_mask = torch.zeros_like(t5_attn_mask)
+                    t5_attn_mask = torch.zeros_like(t5_attn_mask, device=t5xxl.device)
             else:
                 t5_attn_mask = t5_attn_mask.to(t5xxl.device) if t5_attn_mask is not None else None
                 t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), t5_attn_mask, return_dict=False, output_hidden_states=True)
@@ -187,7 +187,7 @@ def drop_cached_text_encoder_outputs(
                 if t5_attn_mask is not None:
                     t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i])
 
-        return lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask
+        return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
 
     def concat_encodings(
         self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor
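A note on the strategy_sd3.py changes above: the zero tensors that stand in for dropped text encoder outputs (caption dropout) now take their device and dtype from the corresponding encoder (clip_l / clip_g / t5xxl) instead of from the token ids. Token ids coming out of a tokenizer are integer tensors (typically int64, often still on CPU), so zeros built from their device/dtype would not match the floating-point embeddings that real encoder outputs have. The sketch below only illustrates that difference; the nn.Linear stand-in for clip_l and the use of .weight.device/.weight.dtype are assumptions made so the example stays self-contained.

import torch
import torch.nn as nn

# Stand-in for the CLIP-L text encoder; assume it is held in fp16 (illustrative only).
clip_l = nn.Linear(768, 768).to(dtype=torch.float16)

# Token ids as a tokenizer would produce them: integer dtype, batch of 2, sequence length 77.
l_tokens = torch.zeros((2, 77), dtype=torch.long)

# Old pattern: the "dropped caption" zeros inherit the token ids' integer dtype and device.
old_zeros = torch.zeros((l_tokens.shape[0], 768), device=l_tokens.device, dtype=l_tokens.dtype)

# New pattern: the zeros follow the encoder's device/dtype, matching real encoder outputs.
enc_device, enc_dtype = clip_l.weight.device, clip_l.weight.dtype
new_zeros = torch.zeros((l_tokens.shape[0], 768), device=enc_device, dtype=enc_dtype)

print(old_zeros.dtype)  # torch.int64  -- not a valid embedding dtype
print(new_zeros.dtype)  # torch.float16 -- same dtype as the encoder's real outputs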
sd3_train_network.py: 8 additions & 7 deletions
@@ -125,7 +125,7 @@ def get_text_encoding_strategy(self, args):
             args.apply_t5_attn_mask,
             args.clip_l_dropout_rate,
             args.clip_g_dropout_rate,
-            args.t5xxl_dropout_rate,
+            args.t5_dropout_rate,
         )
 
     def post_process_network(self, args, accelerator, network, text_encoders, unet):
@@ -415,12 +415,13 @@ def forward(hidden_states):
             prepare_fp8(text_encoder, weight_dtype)
 
     def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
-        # drop cached text encoder outputs
-        text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
-        if text_encoder_outputs_list is not None:
-            text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
-            text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(text_encoder_outputs_list)
-            batch["text_encoder_outputs_list"] = text_encoder_outputs_list
+        # # drop cached text encoder outputs
+        # text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
+        # if text_encoder_outputs_list is not None:
+        #     text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
+        #     text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
+        #     batch["text_encoder_outputs_list"] = text_encoder_outputs_list
+        pass
 
 
 def setup_parser() -> argparse.ArgumentParser:
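On the first sd3_train_network.py hunk: the text encoding strategy is now built from args.t5_dropout_rate rather than args.t5xxl_dropout_rate. If the argument parser defines only --t5_dropout_rate, as the rename suggests, the old attribute access would raise AttributeError as soon as the strategy is constructed. A minimal sketch, with a hypothetical Namespace standing in for the parsed options:

import argparse

# Hypothetical namespace mirroring the dropout options referenced in the diff.
args = argparse.Namespace(clip_l_dropout_rate=0.0, clip_g_dropout_rate=0.0, t5_dropout_rate=0.0)

print(hasattr(args, "t5_dropout_rate"))     # True  -> the renamed attribute resolves
print(hasattr(args, "t5xxl_dropout_rate"))  # False -> the old name would raise AttributeError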
train_network.py: 20 additions & 0 deletions
@@ -1151,6 +1151,17 @@ def remove_model(old_ckpt_name):
                 text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
                 if text_encoder_outputs_list is not None:
                     text_encoder_conds = text_encoder_outputs_list  # List of text encoder outputs
+
+                # if text_encoder_outputs_list is not None:
+                #     lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_outputs_list
+                #     for i in range(len(lg_out)):
+                #         print(
+                #             f"[{i}] cached L: {lg_out[i,:,:768].max()}, {lg_pooled[i][:768].max()}, cached G: {lg_out[i,:,768:].max()}, {lg_pooled[i][768:].max()}, "
+                #             f"cached T5: {t5_out[i].max()}, "
+                #             f"attn mask: {l_attn_mask[i].max() if l_attn_mask is not None else 0},"
+                #             f" {g_attn_mask[i].max() if g_attn_mask is not None else 0}, {t5_attn_mask[i].max() if t5_attn_mask is not None else 0}"
+                #         )
+
                 if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
                     # TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached'
                     with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
@@ -1182,6 +1193,15 @@ def remove_model(old_ckpt_name):
                         if encoded_text_encoder_conds[i] is not None:
                             text_encoder_conds[i] = encoded_text_encoder_conds[i]
 
+                # lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_conds
+                # for i in range(len(lg_out)):
+                #     print(
+                #         f"[{i}] train L: {lg_out[i,:,:768].max()}, {lg_pooled[i][:768].max()}, train G: {lg_out[i,:,768:].max()}, {lg_pooled[i][768:].max()}, "
+                #         f"train T5: {t5_out[i].max()}, "
+                #         f"attn mask: {l_attn_mask[i].max() if l_attn_mask is not None else 0},"
+                #         f" {g_attn_mask[i].max() if g_attn_mask is not None else 0}, {t5_attn_mask[i].max() if t5_attn_mask is not None else 0}"
+                #     )
+
                 # sample noise, call unet, get target
                 noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target(
                     args,
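On the drop_cached_text_encoder_outputs change in strategy_sd3.py (returning a list rather than a tuple): one plausible reason is that the cached outputs need to stay mutable, since the loop shown above in train_network.py overwrites individual entries with text_encoder_conds[i] = encoded_text_encoder_conds[i], which a tuple would reject. A minimal sketch of that difference, with placeholder strings standing in for the cached tensors:

# Placeholder strings stand in for the cached tensors (lg_out, t5_out, lg_pooled, ...).
cached_tuple = ("lg_out", "t5_out", "lg_pooled")
cached_list = ["lg_out", "t5_out", "lg_pooled"]

recomputed = "recomputed_lg_out"

cached_list[0] = recomputed       # fine: lists support per-item assignment
try:
    cached_tuple[0] = recomputed  # tuples do not
except TypeError as err:
    print(err)                    # 'tuple' object does not support item assignment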
