Commit

fix sdxl pti, support loading lora before textual inversion
feffy380 committed Apr 1, 2024
1 parent 50e8c69 commit ef4f5ea
Showing 3 changed files with 69 additions and 25 deletions.
4 changes: 2 additions & 2 deletions networks/lora.py
@@ -1049,13 +1049,13 @@ def enumerate_params(loras):
params.extend(lora.parameters())
return params

- if self.text_encoder_loras:
+ if self.text_encoder_loras and text_encoder_lr != 0.0:
param_data = {"params": enumerate_params(self.text_encoder_loras)}
if text_encoder_lr is not None:
param_data["lr"] = text_encoder_lr
all_params.append(param_data)

- if self.unet_loras:
+ if self.unet_loras and unet_lr != 0.0:
if self.block_lr:
# we want per-block learning-rate graphs, so group the LoRA modules by block
block_idx_to_lora = {}
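What the networks/lora.py change accomplishes: when the text encoder or U-Net learning rate is explicitly 0.0, the corresponding LoRA parameters are no longer handed to the optimizer at all, so no optimizer state (e.g. Adam moments) is allocated for the frozen half of the network. A minimal standalone sketch of that pattern, using hypothetical stand-in modules rather than the repo's LoRA classes:

import torch

def build_param_groups(te_loras, unet_loras, text_encoder_lr, unet_lr):
    """Collect optimizer param groups, skipping any group whose LR is exactly 0.0."""
    groups = []
    if te_loras and text_encoder_lr != 0.0:
        group = {"params": [p for m in te_loras for p in m.parameters()]}
        if text_encoder_lr is not None:
            group["lr"] = text_encoder_lr
        groups.append(group)
    if unet_loras and unet_lr != 0.0:
        group = {"params": [p for m in unet_loras for p in m.parameters()]}
        if unet_lr is not None:
            group["lr"] = unet_lr
        groups.append(group)
    return groups

# Example: text encoder frozen (lr=0), so only the U-Net group reaches the optimizer.
te = [torch.nn.Linear(4, 4)]
un = [torch.nn.Linear(4, 4)]
params = build_param_groups(te, un, text_encoder_lr=0.0, unet_lr=1e-4)
optimizer = torch.optim.AdamW(params, lr=1e-4)
assert len(optimizer.param_groups) == 1

With text_encoder_lr=0.0 the text-encoder group is dropped entirely, which is what makes resuming a LoRA with a frozen text encoder (or frozen U-Net) cheaper in memory.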
58 changes: 35 additions & 23 deletions train_network.py
@@ -244,7 +244,7 @@ def train(self, args):

token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)]
logger.info(f"Loaded token strings {token_strings}")
- for i, (tokenizer, text_encoder, embeds) in enumerate(zip(tokenizers, text_encoders, embeds_list)):
+ for i, (tokenizer, t_enc, embeds) in enumerate(zip(tokenizers, text_encoders, embeds_list)):
num_added_tokens = tokenizer.add_tokens(token_strings)
assert (
num_added_tokens == num_vectors_per_token
@@ -260,10 +260,10 @@ def train(self, args):
), f"token ids is not end of tokenize: tokenizer {i+1}, {token_ids}, {len(tokenizer)}"

# Resize the token embeddings as we are adding new special tokens to the tokenizer
- text_encoder.resize_token_embeddings(len(tokenizer))
+ t_enc.resize_token_embeddings(len(tokenizer))

# Initialise the newly added token with the provided embeddings
- token_embeds = text_encoder.get_input_embeddings().weight.data
+ token_embeds = t_enc.get_input_embeddings().weight.data
for token_id, embed in zip(token_ids, embeds):
token_embeds[token_id] = embed
embedding_to_token_ids[token_string].append(token_ids)
@@ -428,14 +428,18 @@ def train(self, args):
)
args.scale_weight_norms = False

- train_unet = not args.network_train_text_encoder_only
+ train_unet = not args.network_train_text_encoder_only
train_text_encoder = self.is_train_text_encoder(args)
network.apply_to(text_encoder, unet, train_text_encoder, train_unet)

if args.network_weights is not None:
info = network.load_weights(args.network_weights)
accelerator.print(f"load network weights from {args.network_weights}: {info}")

+ # disable training if LR is 0 to save memory. useful for resuming LoRA with frozen TE or Unet
+ train_unet = train_unet and args.unet_lr != 0.0
+ train_text_encoder = train_text_encoder and args.text_encoder_lr != 0.0

if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
for t_enc in text_encoders:
@@ -452,9 +456,9 @@
# Add embeddings params when continuing the inversion
if args.continue_inversion:
# TODO: might be good to add the embedding to the LoRA module directly to continue training ("bundle_emb.{emb_name}.string_to_param.*")
- for text_encoder in text_encoders:
+ for t_enc in text_encoders:
trainable_params.append({
- "params": text_encoder.get_input_embeddings().parameters(),
+ "params": t_enc.get_input_embeddings().parameters(),
"lr": args.embedding_lr or args.text_encoder_lr or args.learning_rate,
})
except TypeError:
@@ -535,7 +539,7 @@ def train(self, args):
unet = accelerator.prepare(unet)
else:
unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator
- if train_text_encoder:
+ if train_text_encoder or args.continue_inversion:
if len(text_encoders) > 1:
text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
else:
@@ -580,7 +584,7 @@ def train(self, args):
t_enc.train()

# set top parameter requires_grad = True for gradient checkpointing works
- if train_text_encoder:
+ if train_text_encoder or args.continue_inversion:
t_enc.text_model.embeddings.requires_grad_(True)
else:
unet.eval()
@@ -892,10 +896,18 @@ def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, embeddings_map, force_s
if len(embeddings_map.keys()) > 0:
# Bundle embeddings in LoRA state dict
state_dict = unwrapped_nw.state_dict()
+ is_sdxl = len(next(iter(embeddings_map.values()))) == 2
for emb_name in embeddings_map.keys():
accelerator.print(f"Bundling embedding: {emb_name}")
- key = f"bundle_emb.{emb_name}.string_to_param.*"
- state_dict[key] = embeddings_map[emb_name]
+ if is_sdxl:
+     embs = embeddings_map[emb_name]
+     key1 = f"bundle_emb.{emb_name}.clip_l"
+     state_dict[key1] = embs[0]
+     key2 = f"bundle_emb.{emb_name}.clip_g"
+     state_dict[key2] = embs[1]
+ else:
+     key = f"bundle_emb.{emb_name}.string_to_param.*"
+     state_dict[key] = embeddings_map[emb_name]

if metadata_to_save is not None and len(metadata_to_save) == 0:
metadata_to_save = None
@@ -943,8 +955,8 @@ def remove_model(old_ckpt_name):
current_epoch.value = epoch + 1

if args.continue_inversion:
- for text_encoder in text_encoders:
-     text_encoder.train()
+ for t_enc in text_encoders:
+     t_enc.train()

metadata["ss_epoch"] = str(epoch + 1)

@@ -1061,8 +1073,8 @@ def remove_model(old_ckpt_name):

# zero out gradients for all tokens we aren't training
if args.continue_inversion:
- for text_encoder, index_no_updates in zip(text_encoders, index_no_updates_list):
-     input_embeddings_weight = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight
+ for t_enc, index_no_updates in zip(text_encoders, index_no_updates_list):
+     input_embeddings_weight = accelerator.unwrap_model(t_enc).get_input_embeddings().weight
input_embeddings_weight.grad[index_no_updates] = 0

optimizer.step()
@@ -1074,35 +1086,35 @@ def remove_model(old_ckpt_name):
with torch.no_grad():
# normalize embeddings
if args.clip_ti_decay:
- for text_encoder, index_updates in zip(text_encoders, index_updates_list):
+ for t_enc, index_updates in zip(text_encoders, index_updates_list):
pre_norm = (
- text_encoder.get_input_embeddings()
+ t_enc.get_input_embeddings()
.weight[index_updates, :]
.norm(dim=-1, keepdim=True)
)
lambda_ = min(1.0, 100 * lr_scheduler.get_last_lr()[0])
- text_encoder.get_input_embeddings().weight[
+ t_enc.get_input_embeddings().weight[
index_updates
] = torch.nn.functional.normalize(
- text_encoder.get_input_embeddings().weight[index_updates, :],
+ t_enc.get_input_embeddings().weight[index_updates, :],
dim=-1,
) * (
pre_norm + lambda_ * (0.4 - pre_norm)
)

# # Let's make sure we don't update any embedding weights besides the newly added token
- # for text_encoder, orig_embeds_params, index_no_updates in zip(
+ # for t_enc, orig_embeds_params, index_no_updates in zip(
# text_encoders, orig_embeds_params_list, index_no_updates_list
# ):
- # input_embeddings_weight = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight
+ # input_embeddings_weight = accelerator.unwrap_model(t_enc).get_input_embeddings().weight
# input_embeddings_weight[index_no_updates] = orig_embeds_params[index_no_updates]

# Update embeddings map (for saving)
# TODO: this is not optimal, might need to be refactored
for emb_name in embeddings_map.keys():
- emb_token_ids = embedding_to_token_ids[emb_name]
- updated_embs = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[emb_token_ids].data.detach().clone()
- embeddings_map[emb_name] = updated_embs
+ for i, (t_enc, emb_token_ids) in enumerate(zip(text_encoders, embedding_to_token_ids[emb_name])):
+     updated_embs = accelerator.unwrap_model(t_enc).get_input_embeddings().weight[emb_token_ids].data.detach().clone()
+     embeddings_map[emb_name][i] = updated_embs

if args.scale_weight_norms:
keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
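The save_model change above stores SDXL bundled embeddings under one key per text encoder (clip_l for the first CLIP encoder, clip_g for the second) instead of the single SD1.x-style string_to_param.* key. As a rough illustration of what a consumer of the saved file would see, assuming the checkpoint is written as safetensors (the helper below is illustrative only, not part of the commit):

from safetensors.torch import load_file

def read_bundled_embeddings(lora_path):
    # Collect any bundled textual inversion tensors from a saved LoRA file,
    # grouped by embedding name.
    state_dict = load_file(lora_path)
    bundles = {}
    for key, tensor in state_dict.items():
        if not key.startswith("bundle_emb."):
            continue
        # SDXL keys: "bundle_emb.<name>.clip_l" / "bundle_emb.<name>.clip_g"
        # SD1.x/2.x key: "bundle_emb.<name>.string_to_param.*"
        name = key[len("bundle_emb."):].split(".")[0]
        bundles.setdefault(name, {})[key] = tensor
    return bundles

Storing one tensor per encoder mirrors how SDXL textual inversion embeddings are shipped as separate CLIP-L and CLIP-G vectors.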
32 changes: 32 additions & 0 deletions train_textual_inversion.py
@@ -314,6 +314,26 @@ def train(self, args):

self.assert_extra_args(args, train_dataset_group)

+ # merge network before training
+ if args.base_weights is not None:
+     import sys, importlib
+     # load the network module for additional (difference) training
+     sys.path.append(os.path.dirname(__file__))
+     accelerator.print("import network module:", args.network_module)
+     network_module = importlib.import_module(args.network_module)
+
+     # if base_weights is specified, load the specified weights and merge them in
+     for weight_path in args.base_weights:
+         multiplier = 1.0
+         accelerator.print(f"merging module: {weight_path}")
+
+         module, weights_sd = network_module.create_network_from_weights(
+             multiplier, weight_path, vae, text_encoders, unet, for_inference=True
+         )
+         module.merge_to(text_encoders, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu")
+
+     accelerator.print(f"all weights merged: {', '.join(args.base_weights)}")
+
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
@@ -830,6 +850,18 @@ def setup_parser() -> argparse.ArgumentParser:
help="Keep the norm of the textual inversion intact",
)

+ # for merging network before training
+ parser.add_argument(
+     "--network_module", type=str, default="networks.lora", help="network module to train / 学習対象のネットワークのモジュール"
+ )
+ parser.add_argument(
+     "--base_weights",
+     type=str,
+     default=None,
+     nargs="*",
+     help="network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みファイル",
+ )
+
return parser


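The new --network_module and --base_weights options let one or more previously trained LoRA files be merged into the text encoders and U-Net before textual inversion starts, so the embedding is optimized against the LoRA-adapted model. Condensed into a standalone sketch of the flow (same helper signatures as in the diff above; argument handling is simplified here):

import importlib, os, sys

def merge_base_weights(args, accelerator, vae, text_encoders, unet, weight_dtype):
    # Sketch of the merge-before-training step; see train_textual_inversion.py for the real code.
    if args.base_weights is None:
        return
    sys.path.append(os.path.dirname(__file__))
    network_module = importlib.import_module(args.network_module)  # e.g. "networks.lora"
    for weight_path in args.base_weights:
        module, weights_sd = network_module.create_network_from_weights(
            1.0, weight_path, vae, text_encoders, unet, for_inference=True
        )
        # Bake the LoRA deltas into the TE/U-Net weights before TI training begins.
        module.merge_to(text_encoders, unet, weights_sd, weight_dtype,
                        accelerator.device if args.lowram else "cpu")

On the command line this corresponds to adding, for example, --network_module networks.lora --base_weights some_lora.safetensors to the usual train_textual_inversion.py invocation.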
