From 541d5e86028189ec37f83e463a399dda209b5c21 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Sun, 24 Nov 2024 20:07:00 +0900
Subject: [PATCH] clear GenerationParametersList before batch

clears any generation parameters that have the attribute
to_be_clear_before_batch = True, preventing the buildup of some parameters
across batch iterations
---
 modules/processing.py | 13 ++++++++++---
 modules/util.py       |  8 ++++++++
 2 files changed, 18 insertions(+), 3 deletions(-)

diff --git a/modules/processing.py b/modules/processing.py
index 0c747601f7e..a58f6a028d8 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -457,7 +457,7 @@ def cached_params(self, required_prompts, steps, extra_network_data, hires_steps
             opts.emphasis,
         )
 
-    def apply_generation_params_states(self, generation_params_states):
+    def apply_generation_params_list(self, generation_params_states):
         """add and apply generation_params_states to self.extra_generation_params"""
         for key, value in generation_params_states.items():
             if key in self.extra_generation_params and isinstance(current_value := self.extra_generation_params[key], util.GenerationParametersList):
@@ -465,6 +465,12 @@ def apply_generation_params_states(self, generation_params_states):
             else:
                 self.extra_generation_params[key] = value
 
+    def clear_marked_generation_params(self):
+        """clears any generation parameters whose to_be_clear_before_batch attribute is True"""
+        for key, value in list(self.extra_generation_params.items()):
+            if getattr(value, 'to_be_clear_before_batch', False):
+                self.extra_generation_params.pop(key)
+
     def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
         """
         Returns the result of calling function(shared.sd_model, required_prompts, steps)
@@ -491,7 +497,7 @@ def get_conds_with_caching(self, function, required_prompts, steps, caches, extr
             if len(cache) == 3:
                 generation_params_states, cached_cached_params = cache[2]
                 if cached_params == cached_cached_params:
-                    self.apply_generation_params_states(generation_params_states)
+                    self.apply_generation_params_list(generation_params_states)
                     return cache[1]
 
         cache = caches[0]
@@ -500,7 +506,7 @@ def get_conds_with_caching(self, function, required_prompts, steps, caches, extr
 
         cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)
         generation_params_states = model_hijack.extract_generation_params_states()
-        self.apply_generation_params_states(generation_params_states)
+        self.apply_generation_params_list(generation_params_states)
         if len(cache) == 2:
             cache.append((generation_params_states, cached_params))
         else:
@@ -959,6 +965,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             if state.interrupted or state.stopping_generation:
                 break
 
+            p.clear_marked_generation_params()  # clean up generation params that are tagged to be cleared before batch
             sd_models.reload_model_weights()  # model can be changed for example by refiner
 
             p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
diff --git a/modules/util.py b/modules/util.py
index 1aef93cfef6..4fe7da1109a 100644
--- a/modules/util.py
+++ b/modules/util.py
@@ -308,9 +308,17 @@ class GenerationParametersList(list):
     if return str, the value will be written to infotext, if return None will be ignored.
""" + def __init__(self, *args, to_be_clear_before_batch=True, **kwargs): + super().__init__(*args, **kwargs) + self._to_be_clear_before_batch = to_be_clear_before_batch + def __call__(self, *args, **kwargs): return ', '.join(sorted(set(self), key=natural_sort_key)) + @property + def to_be_clear_before_batch(self): + return self._to_be_clear_before_batch + def __add__(self, other): if isinstance(other, GenerationParametersList): return self.__class__([*self, *other])