Skip to content

Commit

Permalink
add sd3 model adapter support
Browse files Browse the repository at this point in the history
  • Loading branch information
bghira committed Oct 24, 2024
1 parent 464f251 commit 72b5382
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,27 @@ def clear_adapters(self):
self.loaded_adapters[clean_adapter_name] = None
self.pipeline.unload_lora_weights()
self.loaded_adapters = {}

def apply_adapters(self, user_config: dict, model_prefix: str = "model"):
    """Apply user-configured adapters (e.g. LoRAs) to the pipeline, one at a time.

    Reads up to ten config keys named ``{model_prefix}_adapter_{i}`` (i = 1..10).
    Each value is a ``:``-separated spec:
        "<path>"                    -> "lora" adapter at strength 1.0
        "<type>:<path>"             -> explicit adapter type, strength 1.0
        "<type>:<path>:<strength>"  -> explicit type and numeric strength

    Previously loaded adapters are cleared first. A failure on one adapter is
    logged and skipped so the remaining adapters are still applied.

    Args:
        user_config: per-user settings dictionary.
        model_prefix: config-key prefix, e.g. "flux" or "model".
    """
    self.clear_adapters()
    for i in range(1, 11):
        user_adapter = user_config.get(f"{model_prefix}_adapter_{i}")
        # Skip unset or empty slots.
        if user_adapter is None or user_adapter == "":
            continue
        pieces = str(user_adapter).split(":")
        adapter_type = "lora"
        adapter_strength = 1.0
        if len(pieces) == 1:
            adapter_path = pieces[0]
        elif len(pieces) == 2:
            adapter_type, adapter_path = pieces
        elif len(pieces) == 3:
            adapter_type, adapter_path = pieces[0], pieces[1]
            # The strength arrives as a string from the config; convert it so
            # load_adapter always receives a numeric value (the 1-piece and
            # 2-piece paths already pass a number).
            try:
                adapter_strength = float(pieces[2])
            except ValueError:
                logging.error(
                    f"Invalid adapter strength {pieces[2]!r} for adapter {adapter_path}; using 1.0"
                )
        else:
            # Previously, specs with >3 pieces left adapter_path unbound and the
            # error handler below then raised an uncatchable NameError.
            logging.error(
                f"Malformed adapter spec {user_adapter!r}; expected <type>:<path>:<strength>"
            )
            continue
        try:
            self.load_adapter(
                adapter_type, adapter_path, adapter_strength, fuse_adapter=False
            )
        except Exception as e:
            import traceback
            logging.error(f"Failed to download adapter {adapter_path}: {e}, {traceback.format_exc()}")
            continue
Original file line number Diff line number Diff line change
Expand Up @@ -40,28 +40,7 @@ def __call__(self, **args):
args["guidance_scale_real"] = float(args["guidance_scale"])
args["guidance_scale"] = float(user_config.get("flux_guidance_scale", 4.0))

# we will apply user LoRAs one at a time. the lora name can be split optionally with a : at the end so that <lora_path>:<strength> are set.
self.clear_adapters()
for i in range(1, 11, 1):
user_adapter = user_config.get(f"flux_adapter_{i}", None)
if user_adapter is not None and user_adapter != "":
pieces = user_adapter.split(":")
adapter_strength = 1
adapter_type = "lora"
if len(pieces) == 1:
adapter_path = pieces[0]
elif len(pieces) == 2:
adapter_type, adapter_path = pieces
elif len(pieces) == 3:
adapter_type, adapter_path, adapter_strength = pieces
try:
self.load_adapter(
adapter_type, adapter_path, adapter_strength, fuse_adapter=False
)
except Exception as e:
import traceback
logging.error(f"Failed to download adapter {adapter_path}: {e}, {traceback.format_exc()}")
continue
self.apply_adapters(user_config, model_prefix="flux")

# Call the pipeline with arguments and return the images
return self.pipeline(**args).images
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,7 @@ def __call__(self, **args):
if "guidance_scale" in args:
args["guidance_scale"] = float(args["guidance_scale"])

self.apply_adapters(user_config)

# Call the pipeline with arguments and return the images
return self.pipeline(**args).images
9 changes: 6 additions & 3 deletions discord_tron_client/message/discord.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,12 @@ def print_prompt(payload, execute_duration="unknown", attributes: Dict = None):
stage1_guidance = f"\n**Stage 2 Guidance**: `!settings refiner_guidance {refiner_guidance}`"

flux_adapter = user_config.get('flux_adapter_1')
flux_adapter_text = ""
model_adapter = user_config.get('model_adapter_1')
model_adapter_text = ""
if "black-forest-labs" in model_id or 'flux' in model_id.lower() and flux_adapter:
flux_adapter_text = f"\n**Flux Adapter**: `!settings flux_adapter_1 {flux_adapter}`\n"
model_adapter_text = f"\n**Flux Adapter**: `!settings flux_adapter_1 {flux_adapter}`\n"
elif model_adapter is not None and model_adapter != "":
model_adapter_text = f"\n**Model Adapter**: `!settings model_adapter_1 {model_adapter}`\n"

guidance_rescale = user_config.get("guidance_rescale")
if latent_refiner == "On":
Expand Down Expand Up @@ -156,7 +159,7 @@ def print_prompt(payload, execute_duration="unknown", attributes: Dict = None):
f"<@{author_id}>\n"
f"**Prompt**: {prompt[:255]}{truncate_suffix}\n"
f"**Settings**: `!seed {seed}`, `!guidance {user_config['guidance_scaling']}`, `!guidance_rescale {guidance_rescale}`, `!steps {steps}`, `!strength {strength}`, `!resolution {resolution_string}`{stage1_guidance}\n"
f"**Model**: `{model_id}` (`{latest_hash}` {last_modified})\n{refiner_status}{flux_adapter_text}"
f"**Model**: `{model_id}` (`{latest_hash}` {last_modified})\n{refiner_status}{model_adapter_text}"
f"**{HardwareInfo.get_identifier()}**: {payload['gpu_power_consumption']}W power used in {execute_time} seconds via {system_hw['gpu_type']} ({vmem}G)\n" # , on a {system_hw['cpu_type']} with {system_hw['memory_amount']}G RAM\n"
# f"**Job ID:** `{payload['job_id']}`\n"
)
Expand Down

0 comments on commit 72b5382

Please sign in to comment.