From 025080218fa77442ac382c666047d66eb48156c5 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sat, 23 Nov 2024 17:31:01 +0900 Subject: [PATCH] fix missing infotext cased by conda cache some generation params such as TI hashes or Emphasis is added in sd_hijack / sd_hijack_clip if conda are fetche from cache sd_hijack_clip will not be executed and it won't have a chance to to add generation params the generation params will also be missing if in non low-vram mode because the hijack.extra_generation_params was never read after calculate_hr_conds --- modules/processing.py | 33 +++++++++++++++++++++++++++------ modules/sd_hijack.py | 10 +++++++++- modules/sd_hijack_clip.py | 34 ++++++++++++++++++++++++++++------ modules/util.py | 38 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 102 insertions(+), 13 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 92c3582cc66..0c747601f7e 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -16,7 +16,7 @@ from typing import Any import modules.sd_hijack -from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng, profiling +from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng, profiling, util from modules.rng import slerp # noqa: F401 from modules.sd_hijack import model_hijack from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes @@ -457,6 +457,14 @@ def cached_params(self, required_prompts, steps, extra_network_data, hires_steps opts.emphasis, ) + def apply_generation_params_states(self, generation_params_states): + """add and apply generation_params_states to self.extra_generation_params""" + for key, value in generation_params_states.items(): + if key in self.extra_generation_params and isinstance(current_value := self.extra_generation_params[key], util.GenerationParametersList): + self.extra_generation_params[key] = current_value + value + else: + self.extra_generation_params[key] = value + def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None): """ Returns the result of calling function(shared.sd_model, required_prompts, steps) @@ -480,6 +488,10 @@ def get_conds_with_caching(self, function, required_prompts, steps, caches, extr for cache in caches: if cache[0] is not None and cached_params == cache[0]: + if len(cache) == 3: + generation_params_states, cached_cached_params = cache[2] + if cached_params == cached_cached_params: + self.apply_generation_params_states(generation_params_states) return cache[1] cache = caches[0] @@ -487,6 +499,13 @@ def get_conds_with_caching(self, function, required_prompts, steps, caches, extr with devices.autocast(): cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling) + generation_params_states = model_hijack.extract_generation_params_states() + self.apply_generation_params_states(generation_params_states) + if len(cache) == 2: + cache.append((generation_params_states, cached_params)) + else: + cache[2] = (generation_params_states, cached_params) + cache[0] = cached_params return cache[1] @@ -502,6 +521,8 @@ def setup_conds(self): self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, total_steps, [self.cached_uc], self.extra_network_data) self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, total_steps, [self.cached_c], self.extra_network_data) + self.extra_generation_params.update(model_hijack.extra_generation_params) + def get_conds(self): return self.c, self.uc @@ -801,10 +822,10 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter for key, value in generation_params.items(): try: - if isinstance(value, list): - generation_params[key] = value[index] - elif callable(value): + if callable(value): generation_params[key] = value(**locals()) + elif isinstance(value, list): + generation_params[key] = value[index] except Exception: errors.report(f'Error creating infotext for key "{key}"', exc_info=True) generation_params[key] = None @@ -965,8 +986,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: p.setup_conds() - p.extra_generation_params.update(model_hijack.extra_generation_params) - # params.txt should be saved after scripts.process_batch, since the # infotext could be modified by that callback # Example: a wildcard processed by process_batch sets an extra model @@ -1513,6 +1532,8 @@ def calculate_hr_conds(self): self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.firstpass_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data, total_steps) self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.firstpass_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data, total_steps) + self.extra_generation_params.update(model_hijack.extra_generation_params) + def setup_conds(self): if self.is_hr_pass: # if we are in hr pass right now, the call is being made from the refiner, and we don't need to setup firstpass cons or switch model diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 0de83054186..4ac22ec53bc 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -2,7 +2,7 @@ from torch.nn.functional import silu from types import MethodType -from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches +from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches, util from modules.hypernetworks import hypernetwork from modules.shared import cmd_opts from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18 @@ -321,6 +321,14 @@ def clear_comments(self): self.comments = [] self.extra_generation_params = {} + def extract_generation_params_states(self): + """Extracts GenerationParametersList so that they can be cached and restored later""" + states = {} + for key in list(self.extra_generation_params): + if isinstance(self.extra_generation_params[key], util.GenerationParametersList): + states[key] = self.extra_generation_params.pop(key) + return states + def get_prompt_lengths(self, text): if self.clip is None: return "-", "-" diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index a479148fc21..62c632f82f5 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -3,7 +3,7 @@ import torch -from modules import prompt_parser, devices, sd_hijack, sd_emphasis +from modules import prompt_parser, devices, sd_hijack, sd_emphasis, util from modules.shared import opts @@ -27,6 +27,30 @@ def __init__(self): are applied by sd_hijack.EmbeddingsWithFixes's forward function.""" +class EmphasisMode(util.GenerationParametersList): + def __init__(self, emphasis_mode:str = None): + super().__init__() + self.emphasis_mode = emphasis_mode + + def __call__(self, *args, **kwargs): + return self.emphasis_mode + + def __add__(self, other): + if isinstance(other, EmphasisMode): + return self if self.emphasis_mode else other + elif isinstance(other, str): + return self.__str__() + other + return NotImplemented + + def __radd__(self, other): + if isinstance(other, str): + return other + self.__str__() + return NotImplemented + + def __str__(self): + return self.emphasis_mode if self.emphasis_mode else '' + + class TextConditionalModel(torch.nn.Module): def __init__(self): super().__init__() @@ -238,12 +262,10 @@ def forward(self, texts): hashes.append(f"{name}: {shorthash}") if hashes: - if self.hijack.extra_generation_params.get("TI hashes"): - hashes.append(self.hijack.extra_generation_params.get("TI hashes")) - self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes) + self.hijack.extra_generation_params["TI hashes"] = util.GenerationParametersList(hashes) - if any(x for x in texts if "(" in x or "[" in x) and opts.emphasis != "Original": - self.hijack.extra_generation_params["Emphasis"] = opts.emphasis + if opts.emphasis != 'Original' and any(x for x in texts if '(' in x or '[' in x): + self.hijack.extra_generation_params["Emphasis"] = EmphasisMode(opts.emphasis) if self.return_pooled: return torch.hstack(zs), zs[0].pooled diff --git a/modules/util.py b/modules/util.py index baeba2fa271..1aef93cfef6 100644 --- a/modules/util.py +++ b/modules/util.py @@ -288,3 +288,41 @@ def compare_sha256(file_path: str, hash_prefix: str) -> bool: for chunk in iter(lambda: f.read(blksize), b""): hash_sha256.update(chunk) return hash_sha256.hexdigest().startswith(hash_prefix.strip().lower()) + + +class GenerationParametersList(list): + """A special object used in sd_hijack.StableDiffusionModelHijack for setting extra_generation_params + due to StableDiffusionProcessing.get_conds_with_caching + extra_generation_params set in StableDiffusionModelHijack will be lost when cached is used + + When an extra_generation_params is set in StableDiffusionModelHijack using this object, + the params will be extracted by StableDiffusionModelHijack.extract_generation_params_states + the extracted params will be cached in StableDiffusionProcessing.get_conds_with_caching + and applyed to StableDiffusionProcessing.extra_generation_params by StableDiffusionProcessing.apply_generation_params_states + + Example see modules.sd_hijack_clip.TextConditionalModel.hijack.extra_generation_params 'TI hashes' 'Emphasis' + + Depending on the use case the methods can be overwritten. + In general __call__ method should return str or None, as normally it's called in modules.processing.create_infotext. + When called by create_infotext it will access to the locals() of the caller, + if return str, the value will be written to infotext, if return None will be ignored. + """ + + def __call__(self, *args, **kwargs): + return ', '.join(sorted(set(self), key=natural_sort_key)) + + def __add__(self, other): + if isinstance(other, GenerationParametersList): + return self.__class__([*self, *other]) + elif isinstance(other, str): + return self.__str__() + other + return NotImplemented + + def __radd__(self, other): + if isinstance(other, str): + return other + self.__str__() + return NotImplemented + + def __str__(self): + return self.__call__() +