add schedulefree to TI
feffy380 committed Apr 10, 2024
1 parent 317d335 commit 8f6d4d7
Showing 2 changed files with 19 additions and 26 deletions.
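This commit wires schedule-free optimizers (detected via "schedulefree" in `args.optimizer_type`) into the textual-inversion paths. Schedule-free optimizers carry their own averaging-based schedule, so the scripts skip `lr_scheduler.step()` and instead toggle the optimizer between its training and evaluation parameter states. A minimal sketch of that usage pattern, assuming the `schedulefree` package's `AdamWScheduleFree` (the commit only shows the optimizer-type check, not the exact class):

import torch
import schedulefree  # assumed package; provides AdamWScheduleFree / SGDScheduleFree

model = torch.nn.Linear(16, 16)
optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3)

for epoch in range(2):
    optimizer.train()  # switch to the training parameter state before any step
    for _ in range(10):
        loss = model(torch.randn(4, 16)).pow(2).mean()
        loss.backward()
        optimizer.step()
        # no lr_scheduler.step(): the schedule is internal to the optimizer
        optimizer.zero_grad(set_to_none=True)

    optimizer.eval()  # swap in the averaged weights before evaluating or saving
    torch.save(model.state_dict(), f"epoch{epoch + 1}.pt")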
23 changes: 9 additions & 14 deletions train_network.py
@@ -1012,12 +1012,11 @@ def update_embeddings_map(accelerator, text_encoders, embeddings_map, embedding_
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1

if schedulefree:
optimizer.optimizer.train()

if args.continue_inversion:
for t_enc in text_encoders:
t_enc.train()
if schedulefree:
optimizer.optimizer.train()

metadata["ss_epoch"] = str(epoch + 1)

@@ -1153,7 +1152,8 @@ def update_embeddings_map(accelerator, text_encoders, embeddings_map, embedding_
input_embeddings_weight.grad[index_no_updates] = 0

optimizer.step()
lr_scheduler.step()
if not schedulefree:
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)

# normalize embeddings
@@ -1233,15 +1233,15 @@ def update_embeddings_map(accelerator, text_encoders, embeddings_map, embedding_

accelerator.wait_for_everyone()

if schedulefree:
optimizer.optimizer.eval()
if args.continue_inversion:
update_embeddings_map(accelerator, text_encoders, embeddings_map, embedding_to_token_ids)

# save the model every specified number of epochs
if args.save_every_n_epochs is not None:
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
if is_main_process and saving:
if schedulefree:
optimizer.optimizer.eval()
if args.continue_inversion:
update_embeddings_map(accelerator, text_encoders, embeddings_map, embedding_to_token_ids)

ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1, embeddings_map=embeddings_map)

@@ -1269,11 +1269,6 @@ def update_embeddings_map(accelerator, text_encoders, embeddings_map, embedding_
train_util.save_state_on_train_end(args, accelerator)

if is_main_process:
if schedulefree:
optimizer.optimizer.eval()
if args.continue_inversion:
update_embeddings_map(accelerator, text_encoders, embeddings_map, embedding_to_token_ids)

ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, network, global_step, num_train_epochs, embeddings_map=embeddings_map, force_sync_upload=True)

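In both files the diff reaches through `optimizer.optimizer` rather than calling `train()`/`eval()` on `optimizer` directly. A plausible reason, sketched below under the assumption that the scripts step an optimizer wrapped by `accelerate` (the wrapping itself is not shown in this commit), is that `accelerator.prepare()` returns an `AcceleratedOptimizer` whose underlying schedule-free object lives at `.optimizer`; names such as `model` here are placeholders:

from accelerate import Accelerator
import schedulefree
import torch

accelerator = Accelerator()
model = torch.nn.Linear(8, 8)
inner_optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3)
model, optimizer = accelerator.prepare(model, inner_optimizer)

# `optimizer` is now an accelerate AcceleratedOptimizer; the schedule-free
# object it wraps is reachable at `optimizer.optimizer`, matching the diff.
optimizer.optimizer.train()
loss = model(torch.randn(2, 8)).pow(2).mean()
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad(set_to_none=True)
optimizer.optimizer.eval()  # before saving or evaluating, as in the epoch-end hunks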
22 changes: 10 additions & 12 deletions train_textual_inversion.py
@@ -410,6 +410,7 @@ def train(self, args):
for text_encoder in text_encoders:
trainable_params += text_encoder.get_input_embeddings().parameters()
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
schedulefree = "schedulefree" in args.optimizer_type.lower()

# prepare the dataloader
# note: persistent_workers cannot be used when the DataLoader worker count is 0
@@ -582,6 +583,8 @@ def remove_model(old_ckpt_name):

for text_encoder in text_encoders:
text_encoder.train()
if schedulefree:
optimizer.optimizer.train()

loss_total = 0

@@ -657,11 +660,12 @@ def remove_model(old_ckpt_name):
input_embeddings_weight.grad[index_no_updates] = 0

optimizer.step()
lr_scheduler.step()
if not schedulefree:
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)

# normalize embeddings
with torch.no_grad():
# normalize embeddings
if args.clip_ti_decay:
for text_encoder, index_updates in zip(
text_encoders, index_updates_list
@@ -684,16 +688,6 @@ def remove_model(old_ckpt_name):
pre_norm + lambda_ * (0.4 - pre_norm)
)

# # Let's make sure we don't update any embedding weights besides the newly added token
# for text_encoder, orig_embeds_params, index_no_updates in zip(
# text_encoders, orig_embeds_params_list, index_no_updates_list
# ):
# # if full_fp16/bf16, input_embeddings_weight is fp16/bf16, orig_embeds_params is fp32
# input_embeddings_weight = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight
# input_embeddings_weight[index_no_updates] = orig_embeds_params.to(input_embeddings_weight.dtype)[
# index_no_updates
# ]

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
@@ -716,6 +710,8 @@ def remove_model(old_ckpt_name):
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
if schedulefree:
optimizer.optimizer.eval()
updated_embs_list = []
for text_encoder, token_ids in zip(text_encoders, token_ids_list):
updated_embs = (
@@ -763,6 +759,8 @@ def remove_model(old_ckpt_name):

accelerator.wait_for_everyone()

if schedulefree:
optimizer.optimizer.eval()
updated_embs_list = []
for text_encoder, token_ids in zip(text_encoders, token_ids_list):
updated_embs = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
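The textual-inversion script saves embeddings by reading them straight out of the text encoders' input-embedding weights, which is why the diff calls `optimizer.optimizer.eval()` before each save and `train()` again at the start of the next epoch: for schedule-free optimizers, `eval()` swaps the parameters to the averaged weights intended for inference, while `train()` restores the training-state weights. A toy sketch of that effect, assuming the `schedulefree` package (the embedding layer and token id are placeholders):

import torch
import schedulefree  # assumed package

embeddings = torch.nn.Embedding(10, 4)  # toy stand-in for a text encoder's input embeddings
optimizer = schedulefree.AdamWScheduleFree(embeddings.parameters(), lr=1e-2)

optimizer.train()
loss = embeddings(torch.tensor([3])).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)

during_training = embeddings.weight[3].detach().clone()  # what a save without eval() would read
optimizer.eval()                                          # swap to the averaged weights, as the diff does
saved_embedding = embeddings.weight[3].detach().clone()   # what the checkpoint should contain
optimizer.train()                                         # back to training state for the next epoch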
