Skip to content

Commit

Permalink
clear GenerationParametersList before batch
Browse files Browse the repository at this point in the history
clears any generation parameters that are with the attribute tag_to_be_cleared_before_batch = True
prevent buildup of some parameters
  • Loading branch information
w-e-w committed Nov 24, 2024
1 parent 0250802 commit 541d5e8
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
13 changes: 10 additions & 3 deletions modules/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,14 +457,20 @@ def cached_params(self, required_prompts, steps, extra_network_data, hires_steps
opts.emphasis,
)

def apply_generation_params_states(self, generation_params_states):
def apply_generation_params_list(self, generation_params_states):
"""add and apply generation_params_states to self.extra_generation_params"""
for key, value in generation_params_states.items():
if key in self.extra_generation_params and isinstance(current_value := self.extra_generation_params[key], util.GenerationParametersList):
self.extra_generation_params[key] = current_value + value
else:
self.extra_generation_params[key] = value

def clear_marked_generation_params(self):
"""clears any generation parameters that are with the attribute tag_to_be_cleared_before_batch = True"""
for key, value in list(self.extra_generation_params.items()):
if getattr(value, 'tag_to_be_cleared_before_batch', False):
self.extra_generation_params.pop(key)

def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
"""
Returns the result of calling function(shared.sd_model, required_prompts, steps)
Expand All @@ -491,7 +497,7 @@ def get_conds_with_caching(self, function, required_prompts, steps, caches, extr
if len(cache) == 3:
generation_params_states, cached_cached_params = cache[2]
if cached_params == cached_cached_params:
self.apply_generation_params_states(generation_params_states)
self.apply_generation_params_list(generation_params_states)
return cache[1]

cache = caches[0]
Expand All @@ -500,7 +506,7 @@ def get_conds_with_caching(self, function, required_prompts, steps, caches, extr
cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)

generation_params_states = model_hijack.extract_generation_params_states()
self.apply_generation_params_states(generation_params_states)
self.apply_generation_params_list(generation_params_states)
if len(cache) == 2:
cache.append((generation_params_states, cached_params))
else:
Expand Down Expand Up @@ -959,6 +965,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if state.interrupted or state.stopping_generation:
break

p.clear_marked_generation_params() # clean up some generation params are tagged to be cleared before batch
sd_models.reload_model_weights() # model can be changed for example by refiner

p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
Expand Down
8 changes: 8 additions & 0 deletions modules/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,9 +308,17 @@ class GenerationParametersList(list):
if return str, the value will be written to infotext, if return None will be ignored.
"""

def __init__(self, *args, to_be_clear_before_batch=True, **kwargs):
super().__init__(*args, **kwargs)
self._to_be_clear_before_batch = to_be_clear_before_batch

def __call__(self, *args, **kwargs):
return ', '.join(sorted(set(self), key=natural_sort_key))

@property
def to_be_clear_before_batch(self):
return self._to_be_clear_before_batch

def __add__(self, other):
if isinstance(other, GenerationParametersList):
return self.__class__([*self, *other])
Expand Down

0 comments on commit 541d5e8

Please sign in to comment.