fix sdxl support
feffy380 committed Mar 23, 2024
1 parent 4bac4dc commit 1d62b98
Showing 2 changed files with 17 additions and 16 deletions.
2 changes: 0 additions & 2 deletions networks/lora.py
@@ -965,13 +965,11 @@ def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True)
if apply_text_encoder:
logger.info("enable LoRA for text encoder")
else:
del self.text_encoder_loras
self.text_encoder_loras = []

if apply_unet:
logger.info("enable LoRA for U-Net")
else:
del self.unet_loras
self.unet_loras = []

for lora in self.text_encoder_loras + self.unet_loras:
31 changes: 17 additions & 14 deletions train_network.py
@@ -235,14 +235,17 @@ def train(self, args):
data = torch.load(embeds_file, map_location="cpu")

token_string = Path(embeds_file).stem
embeds, _shape, num_vectors_per_token = self.create_embedding_from_data(data, token_string)
embeds_list, _shape, num_vectors_per_token = self.create_embedding_from_data(data, token_string)
if isinstance(embeds_list, dict) and "clip_l" in embeds_list and "clip_g" in embeds_list:
embeds_list = [embeds_list["clip_l"], embeds_list["clip_g"]]
else:
embeds_list = [embeds_list]
embedding_to_token_ids[token_string] = []

token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)]
logger.info(f"Loaded token strings {token_strings}")
for i, (tokenizer, text_encoder) in enumerate(zip(tokenizers, text_encoders)):
for i, (tokenizer, text_encoder, embeds) in enumerate(zip(tokenizers, text_encoders, embeds_list)):
num_added_tokens = tokenizer.add_tokens(token_strings)

assert (
num_added_tokens == num_vectors_per_token
), f"The tokenizer already contains {token_string}. Please pass a different word that is not already in the tokenizer. / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}"
@@ -259,10 +262,12 @@ def train(self, args):
# Resize the token embeddings as we are adding new special tokens to the tokenizer
text_encoder.resize_token_embeddings(len(tokenizer))

# Initialise the newly added token with the provided embeddings
token_embeds = text_encoder.get_input_embeddings().weight.data
for token_id, embed in zip(token_ids, embeds):
text_encoder.get_input_embeddings().weight.data[token_id] = embed
embeddings_map[token_string] = embeds
token_embeds[token_id] = embed
embedding_to_token_ids[token_string].append(token_ids)
embeddings_map[token_string] = embeds_list

# データセットを準備する
if args.dataset_class is None:
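Note: in the hunk above, each loaded vector is written into the embedding row of its newly added token id, and the ids are recorded per text encoder for the later freezing logic. A toy sketch of the assignment loop, using a hypothetical stand-in module rather than the trainer's real text encoder:

```python
import torch
from torch import nn

embedding = nn.Embedding(10, 8)                  # stands in for text_encoder.get_input_embeddings()
token_ids = [8, 9]                               # ids of the tokens added for this embedding
loaded_vectors = torch.randn(len(token_ids), 8)  # vectors taken from the embedding file

token_embeds = embedding.weight.data
for token_id, embed in zip(token_ids, loaded_vectors):
    token_embeds[token_id] = embed               # overwrite the row for each newly added token
```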
@@ -304,11 +309,11 @@ def train(self, args):
]
}

blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
# use arbitrary dataset class
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers)

current_epoch = Value("i", 0)
current_step = Value("i", 0)
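Note: the hunk above passes the full tokenizer list to dataset construction instead of a single tokenizer. A tiny illustrative sketch (toy stand-ins, not the real tokenizer objects) of why the full list matters: SDXL conditioning needs input ids from both tokenizers, so dataset code given only one can produce only one of the two id sequences.

```python
def make_input_ids(caption, tokenizers):
    # one id sequence per tokenizer, in the same order as the text encoders
    return [f"ids[{tok}]({caption})" for tok in tokenizers]

tokenizers = ["clip_l_tok", "clip_g_tok"]                 # placeholder names, not real tokenizers
print(make_input_ids("a photo of sks dog", tokenizers))   # two entries for SDXL, one for SD1.x/2.x
```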
@@ -543,13 +548,10 @@ def train(self, args):

# Build list of original embeddings to freeze all but the modified embeddings
if args.continue_inversion:
token_ids_list = []
for emb_name in embeddings_map.keys():
for i, sublist in enumerate(embedding_to_token_ids[emb_name]):
if i >= len(token_ids_list):
token_ids_list.append(sublist)
else:
token_ids_list[i].extend(sublist)
token_ids_list = [[] for _ in text_encoders]
for sublists in embedding_to_token_ids.values():
for i, sublist in enumerate(sublists):
token_ids_list[i].extend(sublist)

index_no_updates_list = []
index_updates_list = []
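Note: the hunk above regroups the recorded ids into one list per text encoder. A standalone sketch of that regrouping, using the same names as the diff but with illustrative data:

```python
# embedding_to_token_ids maps each embedding name to one id-list per text encoder;
# the fix collects them into token_ids_list[i] for text encoder i.
embedding_to_token_ids = {
    "emb_a": [[49408, 49409], [49408, 49409]],   # ids in tokenizer 0, ids in tokenizer 1
    "emb_b": [[49410], [49410]],
}
num_text_encoders = 2

token_ids_list = [[] for _ in range(num_text_encoders)]
for sublists in embedding_to_token_ids.values():
    for i, sublist in enumerate(sublists):
        token_ids_list[i].extend(sublist)

print(token_ids_list)  # [[49408, 49409, 49410], [49408, 49409, 49410]]
```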
@@ -559,6 +561,7 @@ def train(self, args):
index_no_updates_list.append(index_no_updates)
index_updates = ~index_no_updates
index_updates_list.append(index_updates)

orig_embeds_params = accelerator.unwrap_model(t_enc).get_input_embeddings().weight.data.detach().clone()
orig_embeds_params_list.append(orig_embeds_params)

