Merge pull request #818 from bghira/main
merge
bghira authored Aug 19, 2024
2 parents 3172256 + 719c0a4 commit 4f3d545
Showing 2 changed files with 6 additions and 5 deletions.
configure.py (4 changes: 2 additions & 2 deletions)

@@ -40,7 +40,7 @@
     "terminus": 8.0,
     "sd3": 5.0,
 }
 lycoris_algos = ["lokr"]

 lora_ranks = [1, 16, 64, 128, 256]
 learning_rates_by_rank = {
     1: "3e-4",
@@ -84,7 +84,7 @@ def configure_lycoris():
     print("6. DyLoRA - Dynamic updates, efficient with large dims. (algo=dylora)")
     print("7. Diag-OFT - Fast convergence with orthogonal fine-tuning. (algo=diag-oft)")
     print("8. BOFT - Advanced version of Diag-OFT with more flexibility. (algo=boft)")
-    print("9. GLoRA/GLoKr - Generalized, still in development. (algo=glora/glokr)\n")
+    print("9. GLoRA - Generalized LoRA. (algo=glora)\n")

     # Prompt user to select an algorithm
     algo = prompt_user(
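
The menu above is how configure.py turns a numbered choice into the LyCORIS algo string used for training. As a minimal sketch of that flow, assuming a simple dict-based dispatch (the LYCORIS_ALGOS mapping and prompt_for_algo helper below are hypothetical illustrations, not the repository's prompt_user; only the option labels and algo names come from the diff):

# Hypothetical sketch: map menu numbers to LyCORIS algo strings.
# Options 1-5 are omitted here; choice 9 now resolves to plain "glora"
# because the experimental "glokr" option was dropped in this commit.
LYCORIS_ALGOS = {
    "6": "dylora",
    "7": "diag-oft",
    "8": "boft",
    "9": "glora",
}

def prompt_for_algo() -> str:
    """Ask until the user enters a listed option number (illustrative)."""
    while True:
        choice = input("Select an algorithm: ").strip()
        if choice in LYCORIS_ALGOS:
            return LYCORIS_ALGOS[choice]
        print(f"Unrecognized option: {choice!r}")
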
helpers/training/save_hooks.py (7 changes: 4 additions & 3 deletions)

@@ -143,8 +143,9 @@ def __init__(
             self.pipeline_class = FluxPipeline
         elif args.flux and args.flux_attention_masked_training:
             from helpers.models.flux.transformer import (
-                FluxTransformer2DModelWithMasking
+                FluxTransformer2DModelWithMasking,
             )
+
             self.denoiser_class = FluxTransformer2DModelWithMasking
             self.pipeline_class = FluxPipeline
         elif hasattr(args, "hunyuan_dit") and args.hunyuan_dit:
@@ -313,7 +314,7 @@ def save_model_hook(self, models, weights, output_dir):
             StateTracker.save_training_state(
                 os.path.join(output_dir, "training_state.json")
             )
-        if "lora" in self.args.model_type and self.args.lora_type == "Standard":
+        if "lora" in self.args.model_type and self.args.lora_type == "standard":
             self._save_lora(models=models, weights=weights, output_dir=output_dir)
             return
         elif "lora" in self.args.model_type and self.args.lora_type == "lycoris":
@@ -461,7 +462,7 @@ def load_model_hook(self, models, input_dir):
                 f"Could not find training_state.json in checkpoint dir {input_dir}"
             )

-        if "lora" in self.args.model_type and self.args.lora_type == "Standard":
+        if "lora" in self.args.model_type and self.args.lora_type == "standard":
             self._load_lora(models=models, input_dir=input_dir)
         elif "lora" in self.args.model_type and self.args.lora_type == "lycoris":
             self._load_lycoris(models=models, input_dir=input_dir)
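
The two "Standard" to "standard" hunks fix a case mismatch: the lycoris branch already compares against a lowercase value, so the capitalized "Standard" check presumably never matched and standard LoRA checkpoints fell through the dispatch. As a sketch of a guard that would rule out this class of bug by normalizing once before comparing (dispatch_lora_hook and its callback parameters are hypothetical names, not the repository's API):

# Hypothetical sketch, not the repository's API: normalizing lora_type
# once makes "Standard" vs "standard" mismatches impossible at the
# comparison sites that these hunks correct.
from typing import Callable

def dispatch_lora_hook(
    model_type: str,
    lora_type: str,
    on_standard: Callable[[], None],
    on_lycoris: Callable[[], None],
) -> None:
    """Route a save/load hook by LoRA flavour, case-insensitively."""
    if "lora" not in model_type:
        return
    flavour = lora_type.strip().lower()
    if flavour == "standard":
        on_standard()
    elif flavour == "lycoris":
        on_lycoris()
    else:
        raise ValueError(f"Unknown lora_type: {lora_type!r}")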
