diff --git a/train_network.py b/train_network.py
index 784dd5535..24b4b8819 100644
--- a/train_network.py
+++ b/train_network.py
@@ -1012,12 +1012,11 @@ def update_embeddings_map(accelerator, text_encoders, embeddings_map, embedding_
         accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
         current_epoch.value = epoch + 1
 
-        if schedulefree:
-            optimizer.optimizer.train()
-
         if args.continue_inversion:
             for t_enc in text_encoders:
                 t_enc.train()
+        if schedulefree:
+            optimizer.optimizer.train()
 
         metadata["ss_epoch"] = str(epoch + 1)
 
@@ -1153,7 +1152,8 @@ def update_embeddings_map(accelerator, text_encoders, embeddings_map, embedding_
                             input_embeddings_weight.grad[index_no_updates] = 0
 
                 optimizer.step()
-                lr_scheduler.step()
+                if not schedulefree:
+                    lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
 
                 # normalize embeddings
@@ -1233,15 +1233,15 @@ def update_embeddings_map(accelerator, text_encoders, embeddings_map, embedding_
 
         accelerator.wait_for_everyone()
 
+        if schedulefree:
+            optimizer.optimizer.eval()
+        if args.continue_inversion:
+            update_embeddings_map(accelerator, text_encoders, embeddings_map, embedding_to_token_ids)
+
         # 指定エポックごとにモデルを保存
         if args.save_every_n_epochs is not None:
             saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
             if is_main_process and saving:
-                if schedulefree:
-                    optimizer.optimizer.eval()
-                if args.continue_inversion:
-                    update_embeddings_map(accelerator, text_encoders, embeddings_map, embedding_to_token_ids)
-
                 ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
                 save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1, embeddings_map=embeddings_map)
 
@@ -1269,11 +1269,6 @@ def update_embeddings_map(accelerator, text_encoders, embeddings_map, embedding_
         train_util.save_state_on_train_end(args, accelerator)
 
     if is_main_process:
-        if schedulefree:
-            optimizer.optimizer.eval()
-        if args.continue_inversion:
-            update_embeddings_map(accelerator, text_encoders, embeddings_map, embedding_to_token_ids)
-
         ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
         save_model(ckpt_name, network, global_step, num_train_epochs, embeddings_map=embeddings_map, force_sync_upload=True)
 
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index a66ea5639..7cab4a205 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -410,6 +410,7 @@ def train(self, args):
        for text_encoder in text_encoders:
            trainable_params += text_encoder.get_input_embeddings().parameters()
        _, _, optimizer = train_util.get_optimizer(args, trainable_params)
+        schedulefree = "schedulefree" in args.optimizer_type.lower()
 
        # dataloaderを準備する
        # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
@@ -582,6 +583,8 @@ def remove_model(old_ckpt_name):
 
            for text_encoder in text_encoders:
                text_encoder.train()
+            if schedulefree:
+                optimizer.optimizer.train()
 
            loss_total = 0
 
@@ -657,11 +660,12 @@ def remove_model(old_ckpt_name):
                                input_embeddings_weight.grad[index_no_updates] = 0
 
                    optimizer.step()
-                    lr_scheduler.step()
+                    if not schedulefree:
+                        lr_scheduler.step()
                    optimizer.zero_grad(set_to_none=True)
 
+                    # normalize embeddings
                    with torch.no_grad():
-                        # normalize embeddings
                        if args.clip_ti_decay:
                            for text_encoder, index_updates in zip(
                                text_encoders, index_updates_list
@@ -684,16 +688,6 @@ def remove_model(old_ckpt_name):
                                    pre_norm + lambda_ * (0.4 - pre_norm)
                                )
 
-                    # # Let's make sure we don't update any embedding weights besides the newly added token
-                    # for text_encoder, orig_embeds_params, index_no_updates in zip(
-                    #     text_encoders, orig_embeds_params_list, index_no_updates_list
-                    # ):
-                    #     # if full_fp16/bf16, input_embeddings_weight is fp16/bf16, orig_embeds_params is fp32
-                    #     input_embeddings_weight = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight
-                    #     input_embeddings_weight[index_no_updates] = orig_embeds_params.to(input_embeddings_weight.dtype)[
-                    #         index_no_updates
-                    #     ]
-
                # Checks if the accelerator has performed an optimization step behind the scenes
                if accelerator.sync_gradients:
                    progress_bar.update(1)
@@ -716,6 +710,8 @@ def remove_model(old_ckpt_name):
                    if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
                        accelerator.wait_for_everyone()
                        if accelerator.is_main_process:
+                            if schedulefree:
+                                optimizer.optimizer.eval()
                            updated_embs_list = []
                            for text_encoder, token_ids in zip(text_encoders, token_ids_list):
                                updated_embs = (
@@ -763,6 +759,8 @@ def remove_model(old_ckpt_name):
 
            accelerator.wait_for_everyone()
 
+            if schedulefree:
+                optimizer.optimizer.eval()
            updated_embs_list = []
            for text_encoder, token_ids in zip(text_encoders, token_ids_list):
                updated_embs = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
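
For reference, below is a minimal sketch of the train()/eval() protocol that schedule-free optimizers expect, which is what this patch wires in: the optimizer must be switched to train mode before optimization steps and to eval mode before the weights are read for validation, embedding export, or checkpointing, and lr_scheduler.step() is skipped because the schedule is built into the optimizer. The sketch uses the facebookresearch/schedulefree package directly; the double indirection in the patch (optimizer.optimizer.train()) presumably reaches the underlying optimizer through a wrapper such as Accelerate's prepared optimizer. The model, data, and loss names below are placeholders, not part of the patch.

# Minimal schedule-free loop sketch (assumes `pip install schedulefree`).
# Placeholder model/data; only the train()/eval() toggling mirrors the patch.
import torch
import schedulefree

model = torch.nn.Linear(8, 1)
optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3)
loss_fn = torch.nn.MSELoss()

for epoch in range(2):
    model.train()
    optimizer.train()  # must be active while optimizer.step() is called
    for _ in range(10):
        x, y = torch.randn(4, 8), torch.randn(4, 1)
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()  # no lr_scheduler.step(): the schedule is built in
        optimizer.zero_grad(set_to_none=True)

    model.eval()
    optimizer.eval()  # swap in the averaged weights before reading/saving them
    with torch.no_grad():
        val_loss = loss_fn(model(torch.randn(4, 8)), torch.randn(4, 1))
    # save checkpoints / export embeddings here, while in eval mode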