Skip to content

Commit

Permalink
Reduce memory usage when applying DORA (#3557)
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed May 25, 2024
1 parent 58c9838 commit efa5a71
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions comfy/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from comfy.types import UnetWrapperFunction


def apply_weight_decompose(dora_scale, weight):
def weight_decompose_scale(dora_scale, weight):
weight_norm = (
weight.transpose(0, 1)
.reshape(weight.shape[1], -1)
Expand All @@ -18,7 +18,7 @@ def apply_weight_decompose(dora_scale, weight):
.transpose(0, 1)
)

return weight * (dora_scale / weight_norm).type(weight.dtype)
return (dora_scale / weight_norm).type(weight.dtype)

def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
to = model_options["transformer_options"].copy()
Expand Down Expand Up @@ -365,7 +365,7 @@ def calculate_weight(self, patches, weight, key):
try:
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
if dora_scale is not None:
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
weight *= weight_decompose_scale(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "lokr":
Expand Down Expand Up @@ -407,7 +407,7 @@ def calculate_weight(self, patches, weight, key):
try:
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
if dora_scale is not None:
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
weight *= weight_decompose_scale(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "loha":
Expand Down Expand Up @@ -439,7 +439,7 @@ def calculate_weight(self, patches, weight, key):
try:
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
if dora_scale is not None:
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
weight *= weight_decompose_scale(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "glora":
Expand All @@ -456,7 +456,7 @@ def calculate_weight(self, patches, weight, key):
try:
weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
if dora_scale is not None:
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
weight *= weight_decompose_scale(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
else:
Expand Down

0 comments on commit efa5a71

Please sign in to comment.