Skip to content

Commit

Permalink
fix missing infotext caused by conds cache
Browse files Browse the repository at this point in the history
some generation params, such as TI hashes or Emphasis, are added in sd_hijack / sd_hijack_clip
if conds are fetched from the cache, sd_hijack_clip will not be executed and it won't have a chance to add generation params

the generation params will also be missing in non-low-vram mode, because hijack.extra_generation_params was never read after calculate_hr_conds
  • Loading branch information
w-e-w committed Nov 24, 2024
1 parent 023454b commit 0250802
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 13 deletions.
33 changes: 27 additions & 6 deletions modules/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from typing import Any

import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng, profiling
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng, profiling, util
from modules.rng import slerp # noqa: F401
from modules.sd_hijack import model_hijack
from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
Expand Down Expand Up @@ -457,6 +457,14 @@ def cached_params(self, required_prompts, steps, extra_network_data, hires_steps
opts.emphasis,
)

def apply_generation_params_states(self, generation_params_states):
    """Merge previously extracted generation-param states into self.extra_generation_params.

    Entries whose existing value is a util.GenerationParametersList are
    concatenated with the incoming value; all other entries are overwritten.
    """
    for name, state in generation_params_states.items():
        existing = self.extra_generation_params.get(name)
        if isinstance(existing, util.GenerationParametersList):
            self.extra_generation_params[name] = existing + state
        else:
            self.extra_generation_params[name] = state

def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
"""
Returns the result of calling function(shared.sd_model, required_prompts, steps)
Expand All @@ -480,13 +488,24 @@ def get_conds_with_caching(self, function, required_prompts, steps, caches, extr

for cache in caches:
if cache[0] is not None and cached_params == cache[0]:
if len(cache) == 3:
generation_params_states, cached_cached_params = cache[2]
if cached_params == cached_cached_params:
self.apply_generation_params_states(generation_params_states)
return cache[1]

cache = caches[0]

with devices.autocast():
cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)

generation_params_states = model_hijack.extract_generation_params_states()
self.apply_generation_params_states(generation_params_states)
if len(cache) == 2:
cache.append((generation_params_states, cached_params))
else:
cache[2] = (generation_params_states, cached_params)

cache[0] = cached_params
return cache[1]

Expand All @@ -502,6 +521,8 @@ def setup_conds(self):
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, total_steps, [self.cached_uc], self.extra_network_data)
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, total_steps, [self.cached_c], self.extra_network_data)

self.extra_generation_params.update(model_hijack.extra_generation_params)

def get_conds(self):
return self.c, self.uc

Expand Down Expand Up @@ -801,10 +822,10 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter

for key, value in generation_params.items():
try:
if isinstance(value, list):
generation_params[key] = value[index]
elif callable(value):
if callable(value):
generation_params[key] = value(**locals())
elif isinstance(value, list):
generation_params[key] = value[index]
except Exception:
errors.report(f'Error creating infotext for key "{key}"', exc_info=True)
generation_params[key] = None
Expand Down Expand Up @@ -965,8 +986,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:

p.setup_conds()

p.extra_generation_params.update(model_hijack.extra_generation_params)

# params.txt should be saved after scripts.process_batch, since the
# infotext could be modified by that callback
# Example: a wildcard processed by process_batch sets an extra model
Expand Down Expand Up @@ -1513,6 +1532,8 @@ def calculate_hr_conds(self):
self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.firstpass_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data, total_steps)
self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.firstpass_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data, total_steps)

self.extra_generation_params.update(model_hijack.extra_generation_params)

def setup_conds(self):
if self.is_hr_pass:
# if we are in hr pass right now, the call is being made from the refiner, and we don't need to setup firstpass cons or switch model
Expand Down
10 changes: 9 additions & 1 deletion modules/sd_hijack.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from torch.nn.functional import silu
from types import MethodType

from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches, util
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18
Expand Down Expand Up @@ -321,6 +321,14 @@ def clear_comments(self):
self.comments = []
self.extra_generation_params = {}

def extract_generation_params_states(self):
    """Remove and return every GenerationParametersList entry from extra_generation_params.

    The returned mapping can be cached alongside conds and restored later
    (see StableDiffusionProcessing.apply_generation_params_states).
    """
    states = {}
    for key, value in list(self.extra_generation_params.items()):
        if isinstance(value, util.GenerationParametersList):
            states[key] = value
            del self.extra_generation_params[key]
    return states

def get_prompt_lengths(self, text):
if self.clip is None:
return "-", "-"
Expand Down
34 changes: 28 additions & 6 deletions modules/sd_hijack_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import torch

from modules import prompt_parser, devices, sd_hijack, sd_emphasis
from modules import prompt_parser, devices, sd_hijack, sd_emphasis, util
from modules.shared import opts


Expand All @@ -27,6 +27,30 @@ def __init__(self):
are applied by sd_hijack.EmbeddingsWithFixes's forward function."""


class EmphasisMode(util.GenerationParametersList):
    """Records the emphasis interpretation mode used for a prompt so that it can
    be cached together with the conds and later restored into the infotext.

    When two instances are merged (cache restore), the one that actually holds
    a mode wins; string concatenation renders the mode's name.
    """

    def __init__(self, emphasis_mode: str = None):
        super().__init__()
        self.emphasis_mode = emphasis_mode

    def __call__(self, *args, **kwargs):
        # Invoked by create_infotext; a None result omits the entry entirely.
        return self.emphasis_mode

    def __add__(self, other):
        if isinstance(other, EmphasisMode):
            # prefer our own mode if we have one, otherwise keep the other's
            return self if self.emphasis_mode else other
        if isinstance(other, str):
            return str(self) + other
        return NotImplemented

    def __radd__(self, other):
        if isinstance(other, str):
            return other + str(self)
        return NotImplemented

    def __str__(self):
        return self.emphasis_mode or ''


class TextConditionalModel(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -238,12 +262,10 @@ def forward(self, texts):
hashes.append(f"{name}: {shorthash}")

if hashes:
if self.hijack.extra_generation_params.get("TI hashes"):
hashes.append(self.hijack.extra_generation_params.get("TI hashes"))
self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)
self.hijack.extra_generation_params["TI hashes"] = util.GenerationParametersList(hashes)

if any(x for x in texts if "(" in x or "[" in x) and opts.emphasis != "Original":
self.hijack.extra_generation_params["Emphasis"] = opts.emphasis
if opts.emphasis != 'Original' and any(x for x in texts if '(' in x or '[' in x):
self.hijack.extra_generation_params["Emphasis"] = EmphasisMode(opts.emphasis)

if self.return_pooled:
return torch.hstack(zs), zs[0].pooled
Expand Down
38 changes: 38 additions & 0 deletions modules/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,3 +288,41 @@ def compare_sha256(file_path: str, hash_prefix: str) -> bool:
for chunk in iter(lambda: f.read(blksize), b""):
hash_sha256.update(chunk)
return hash_sha256.hexdigest().startswith(hash_prefix.strip().lower())


class GenerationParametersList(list):
    """Container for extra_generation_params values that must survive cond caching.

    Params set on sd_hijack.StableDiffusionModelHijack while conds are computed are
    normally lost when StableDiffusionProcessing.get_conds_with_caching reuses a
    cached result. Values stored as this type are pulled out by
    StableDiffusionModelHijack.extract_generation_params_states, cached alongside
    the conds, and later re-applied to StableDiffusionProcessing.extra_generation_params
    by StableDiffusionProcessing.apply_generation_params_states.

    For example usage see 'TI hashes' and 'Emphasis' in
    modules.sd_hijack_clip.TextConditionalModel.

    Subclasses may override these methods as needed. In general __call__ should
    return str or None, since it is normally invoked from
    modules.processing.create_infotext with access to the caller's locals():
    a str result is written to the infotext, a None result is skipped.
    """

    def __call__(self, *args, **kwargs):
        return ', '.join(sorted(set(self), key=natural_sort_key))

    def __add__(self, other):
        if isinstance(other, GenerationParametersList):
            return self.__class__([*self, *other])
        if isinstance(other, str):
            return str(self) + other
        return NotImplemented

    def __radd__(self, other):
        if isinstance(other, str):
            return other + str(self)
        return NotImplemented

    def __str__(self):
        return self()

0 comments on commit 0250802

Please sign in to comment.