Merge pull request #1715 from catboxanon/vpred-ztsnr-fixes
Update debiased estimation loss function to accommodate V-pred
kohya-ss authored Oct 25, 2024
2 parents 012e7e6 + 0e7c592 commit c632af8
Showing 9 changed files with 13 additions and 10 deletions.
2 changes: 1 addition & 1 deletion fine_tune.py
@@ -386,7 +386,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
     if args.scale_v_pred_loss_like_noise_pred:
         loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
     if args.debiased_estimation_loss:
-        loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
+        loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)

     loss = loss.mean()  # mean over batch dimension
 else:
7 changes: 5 additions & 2 deletions library/custom_train_functions.py
@@ -96,10 +96,13 @@ def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_los
     return loss


-def apply_debiased_estimation(loss, timesteps, noise_scheduler):
+def apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False):
     snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])  # batch_size
     snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000)  # if timestep is 0, snr_t is inf, so limit it to 1000
-    weight = 1 / torch.sqrt(snr_t)
+    if v_prediction:
+        weight = 1 / (snr_t + 1)
+    else:
+        weight = 1 / torch.sqrt(snr_t)
     loss = weight * loss
     return loss

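For readers following this diff outside the repository, here is a minimal, self-contained sketch of the weighting logic after this change: with v-parameterization enabled, the debiased-estimation weight becomes 1/(SNR_t + 1), while the original 1/sqrt(SNR_t) weight is kept for epsilon prediction. The _FakeScheduler class and its all_snr values below are illustrative stand-ins, not the repository's noise scheduler, which precomputes all_snr from its beta schedule.

import torch

def apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False):
    # per-sample SNR at each sampled timestep, shape (batch_size,)
    snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
    # SNR blows up at timestep 0, so cap it at 1000 as in the original function
    snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000)
    if v_prediction:
        weight = 1 / (snr_t + 1)        # branch added by this commit for v-prediction
    else:
        weight = 1 / torch.sqrt(snr_t)  # original weight for epsilon prediction
    return weight * loss

class _FakeScheduler:
    # stand-in with made-up SNR values; only the all_snr attribute is needed here
    all_snr = torch.tensor([1000.0, 100.0, 10.0, 1.0, 0.1])

per_sample_loss = torch.ones(3)      # loss already reduced to one value per sample
timesteps = torch.tensor([1, 2, 4])  # sampled diffusion timesteps for the batch
print(apply_debiased_estimation(per_sample_loss, timesteps, _FakeScheduler(), v_prediction=True))
print(apply_debiased_estimation(per_sample_loss, timesteps, _FakeScheduler(), v_prediction=False))

With v_prediction=True the three unit losses are scaled by 1/(100+1), 1/(10+1), and 1/(0.1+1); with v_prediction=False they are scaled by 1/sqrt(100), 1/sqrt(10), and 1/sqrt(0.1). The call-site changes in the surrounding files simply forward the existing args.v_parameterization flag so the weighting matches the prediction type the model was trained with.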
2 changes: 1 addition & 1 deletion sdxl_train.py
@@ -730,7 +730,7 @@ def optimizer_hook(parameter: torch.Tensor):
     if args.v_pred_like_loss:
         loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
     if args.debiased_estimation_loss:
-        loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
+        loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)

     loss = loss.mean()  # mean over batch dimension
 else:
2 changes: 1 addition & 1 deletion sdxl_train_control_net_lllite.py
@@ -479,7 +479,7 @@ def remove_model(old_ckpt_name):
     if args.v_pred_like_loss:
         loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
     if args.debiased_estimation_loss:
-        loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
+        loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)

     loss = loss.mean()  # mean, so no need to divide by batch_size

2 changes: 1 addition & 1 deletion sdxl_train_control_net_lllite_old.py
@@ -439,7 +439,7 @@ def remove_model(old_ckpt_name):
     if args.v_pred_like_loss:
         loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
     if args.debiased_estimation_loss:
-        loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
+        loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)

     loss = loss.mean()  # mean, so no need to divide by batch_size

2 changes: 1 addition & 1 deletion train_db.py
@@ -373,7 +373,7 @@ def train(args):
     if args.scale_v_pred_loss_like_noise_pred:
         loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
     if args.debiased_estimation_loss:
-        loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
+        loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)

     loss = loss.mean()  # mean, so no need to divide by batch_size

2 changes: 1 addition & 1 deletion train_network.py
@@ -998,7 +998,7 @@ def remove_model(old_ckpt_name):
     if args.v_pred_like_loss:
         loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
     if args.debiased_estimation_loss:
-        loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
+        loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)

     loss = loss.mean()  # mean, so no need to divide by batch_size

2 changes: 1 addition & 1 deletion train_textual_inversion.py
@@ -603,7 +603,7 @@ def remove_model(old_ckpt_name):
     if args.v_pred_like_loss:
         loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
     if args.debiased_estimation_loss:
-        loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
+        loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)

     loss = loss.mean()  # mean, so no need to divide by batch_size

2 changes: 1 addition & 1 deletion train_textual_inversion_XTI.py
@@ -486,7 +486,7 @@ def remove_model(old_ckpt_name):
     if args.scale_v_pred_loss_like_noise_pred:
         loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
     if args.debiased_estimation_loss:
-        loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
+        loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)

     loss = loss.mean()  # mean, so no need to divide by batch_size

