implement LoRA+
feffy380 committed Apr 5, 2024
1 parent aba39c0 commit 3a28d57
Showing 1 changed file with 41 additions and 19 deletions.
networks/lora.py: 60 changes (41 additions, 19 deletions)
@@ -430,6 +430,10 @@ def create_network(
     if network_alpha is None:
         network_alpha = 1.0
 
+    lora_plus_ratio = kwargs.get("lora_plus_ratio", None)
+    if lora_plus_ratio is not None:
+        lora_plus_ratio = float(lora_plus_ratio)
+
     # extract dim/alpha for conv2d, and block dim
     conv_dim = kwargs.get("conv_dim", None)
     conv_alpha = kwargs.get("conv_alpha", None)
@@ -479,6 +483,7 @@ def create_network(
         multiplier=multiplier,
         lora_dim=network_dim,
         alpha=network_alpha,
+        lora_plus_ratio=lora_plus_ratio,
         dropout=neuron_dropout,
         rank_dropout=rank_dropout,
         module_dropout=module_dropout,
@@ -708,6 +713,10 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
     else:
         weights_sd = torch.load(file, map_location="cpu")
 
+    lora_plus_ratio = kwargs.get("lora_plus_ratio", None)
+    if lora_plus_ratio is not None:
+        lora_plus_ratio = float(lora_plus_ratio)
+
     # get dim/alpha mapping
     modules_dim = {}
     modules_alpha = {}
@@ -731,7 +740,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
     module_class = LoRAInfModule if for_inference else LoRAModule
 
     network = LoRANetwork(
-        text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
+        text_encoder, unet, multiplier=multiplier, lora_plus_ratio=lora_plus_ratio, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
     )
 
     # block lr
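
Note: the four hunks above only plumb the new option through. lora_plus_ratio arrives in the network kwargs as a string (in sd-scripts such kwargs are typically supplied via --network_args), is converted with float(), and is forwarded into LoRANetwork. Below is a minimal, self-contained sketch of the parsing and of what the ratio means for the two LoRA matrices; the example values and variable names are illustrative, not part of the commit.

# Hypothetical illustration of the kwargs plumbing added above; example values
# are not from the commit.  In sd-scripts, --network_args "lora_plus_ratio=16"
# would typically arrive in create_network()'s **kwargs as a string.
kwargs = {"lora_plus_ratio": "16"}

lora_plus_ratio = kwargs.get("lora_plus_ratio", None)
if lora_plus_ratio is not None:
    lora_plus_ratio = float(lora_plus_ratio)

# LoRA+ trains the lora_up ("B") weights with a larger learning rate than the
# lora_down ("A") weights; lora_plus_ratio is the ratio between the two.
unet_lr = 1e-4                                   # example base learning rate
lora_down_lr = unet_lr
lora_up_lr = unet_lr * (lora_plus_ratio if lora_plus_ratio is not None else 1.0)
print(lora_down_lr, lora_up_lr)                  # lora_down at 1e-4, lora_up at 1e-4 * 16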
@@ -762,6 +771,7 @@ def __init__(
         multiplier: float = 1.0,
         lora_dim: int = 4,
         alpha: float = 1,
+        lora_plus_ratio: Optional[float] = None,
         dropout: Optional[float] = None,
         rank_dropout: Optional[float] = None,
         module_dropout: Optional[float] = None,
@@ -793,6 +803,7 @@ def __init__(
 
         self.lora_dim = lora_dim
         self.alpha = alpha
+        self.lora_plus_ratio = lora_plus_ratio
         self.conv_lora_dim = conv_lora_dim
         self.conv_alpha = conv_alpha
         self.dropout = dropout
@@ -1043,17 +1054,28 @@ def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
         self.requires_grad_(True)
         all_params = []
 
-        def enumerate_params(loras):
-            params = []
+        def assemble_params(loras, lr):
+            param_groups = {"lora": {}, "plus": {}}
             for lora in loras:
-                params.extend(lora.parameters())
+                for name, param in lora.named_parameters():
+                    group = param_groups["lora"]
+                    if self.lora_plus_ratio is not None and "lora_up" in name:
+                        group = param_groups["plus"]
+                    group[f"{lora.lora_name}.{name}"] = param
+
+            params = []
+            for key, group in param_groups.items():
+                param_data = {"params": group.values()}
+                if lr is not None:
+                    param_data["lr"] = lr
+                if key == "plus":
+                    param_data["lr"] *= self.lora_plus_ratio
+                params.append(param_data)
             return params
 
         if self.text_encoder_loras and text_encoder_lr != 0.0:
-            param_data = {"params": enumerate_params(self.text_encoder_loras)}
-            if text_encoder_lr is not None:
-                param_data["lr"] = text_encoder_lr
-            all_params.append(param_data)
+            params = assemble_params(self.text_encoder_loras, text_encoder_lr)
+            all_params.extend(params)
 
         if self.unet_loras and unet_lr != 0.0:
             if self.block_lr:
@@ -1067,21 +1089,21 @@ def enumerate_params(loras):
 
                 # blockごとにパラメータを設定する
                 for idx, block_loras in block_idx_to_lora.items():
-                    param_data = {"params": enumerate_params(block_loras)}
-
                     if unet_lr is not None:
-                        param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
+                        params = assemble_params(block_loras, unet_lr * self.get_lr_weight(block_loras[0]))
                     elif default_lr is not None:
-                        param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
-                    if ("lr" in param_data) and (param_data["lr"] == 0):
-                        continue
-                    all_params.append(param_data)
+                        lr = default_lr * self.get_lr_weight(block_loras[0])
+                        if lr == 0:
+                            continue
+                        params = assemble_params(block_loras, default_lr * self.get_lr_weight(block_loras[0]))
+                    all_params.extend(params)
 
             else:
-                param_data = {"params": enumerate_params(self.unet_loras)}
-                if unet_lr is not None:
-                    param_data["lr"] = unet_lr
-                all_params.append(param_data)
+                lr = unet_lr
+                if lr is None:
+                    lr = default_lr
+                params = assemble_params(self.unet_loras, lr)
+                all_params.extend(params)
 
         return all_params
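
The last two hunks are the core of the change: prepare_optimizer_params replaces enumerate_params with assemble_params, which sorts each module's parameters into a base "lora" group and a "plus" group (any parameter whose name contains lora_up, i.e. the B matrix) and multiplies the plus group's learning rate by lora_plus_ratio. Below is a self-contained sketch of that grouping feeding a real optimizer; ToyLoRAModule, the dimensions, the ratio of 16 and the choice of AdamW are illustrative assumptions, not part of the commit.

import torch
import torch.nn as nn

class ToyLoRAModule(nn.Module):
    """Hypothetical stand-in for LoRAModule: one down ("A") / up ("B") pair."""
    def __init__(self, name, in_dim=64, rank=4, out_dim=64):
        super().__init__()
        self.lora_name = name
        self.lora_down = nn.Linear(in_dim, rank, bias=False)   # "A" matrix
        self.lora_up = nn.Linear(rank, out_dim, bias=False)    # "B" matrix

def assemble_params(loras, lr, lora_plus_ratio=None):
    # Mirror of the patch's grouping: lora_up weights go into the "plus" group.
    param_groups = {"lora": {}, "plus": {}}
    for lora in loras:
        for name, param in lora.named_parameters():
            group = param_groups["lora"]
            if lora_plus_ratio is not None and "lora_up" in name:
                group = param_groups["plus"]
            group[f"{lora.lora_name}.{name}"] = param

    params = []
    for key, group in param_groups.items():
        if not group:  # skip empty groups (the patch appends them unconditionally)
            continue
        param_data = {"params": list(group.values()), "lr": lr}
        if key == "plus":
            param_data["lr"] = lr * lora_plus_ratio
        params.append(param_data)
    return params

loras = [ToyLoRAModule(f"lora_unet_block{i}") for i in range(3)]
groups = assemble_params(loras, lr=1e-4, lora_plus_ratio=16)
optimizer = torch.optim.AdamW(groups)
for g in optimizer.param_groups:
    print(len(g["params"]), g["lr"])   # lora_down group at 1e-4, lora_up group at 16x that

Because the split happens at param-group assembly, any torch optimizer that honours per-group learning rates picks up the LoRA+ behaviour without further changes.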
