Add Stable Diffusion XL support (#56)
daemon authored Jan 7, 2024
1 parent b32e133 commit 09564e0
Showing 7 changed files with 74 additions and 29 deletions.
13 changes: 7 additions & 6 deletions README.md
@@ -4,7 +4,7 @@

![example image](example.jpg)

-### Updated to support Diffusers 0.16.1!
+### Updated to support Stable Diffusion XL (SDXL) and Diffusers 0.21.1!

I regularly update this codebase. Please submit an issue if you have any questions.

@@ -33,30 +33,31 @@ dog.heat_map.png running.heat_map.png prompt.txt
```
Your current working directory will now contain the generated image as `output.png` and a DAAM map for every word, as well as some auxiliary data.
You can see more options for `daam` by running `daam -h`.
+To use Stable Diffusion XL as the backend, run `daam --model xl-base-1.0 "Dog jumping"`.

### Using DAAM as a Library

Import and use DAAM as follows:

```python
from daam import trace, set_seed
-from diffusers import StableDiffusionPipeline
+from diffusers import DiffusionPipeline
from matplotlib import pyplot as plt
import torch


-model_id = 'stabilityai/stable-diffusion-2-base'
+model_id = 'stabilityai/stable-diffusion-xl-base-1.0'
device = 'cuda'

-pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True)
+pipe = DiffusionPipeline.from_pretrained(model_id, use_auth_token=True, torch_dtype=torch.float16, use_safetensors=True, variant='fp16')
pipe = pipe.to(device)

prompt = 'A dog runs across the field'
gen = set_seed(0) # for reproducibility

-with torch.cuda.amp.autocast(dtype=torch.float16), torch.no_grad():
+with torch.no_grad():
    with trace(pipe) as tc:
-        out = pipe(prompt, num_inference_steps=30, generator=gen)
+        out = pipe(prompt, num_inference_steps=50, generator=gen)
        heat_map = tc.compute_global_heat_map()
        heat_map = heat_map.compute_word_heat_map('dog')
        heat_map.plot_overlay(out.images[0])
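A usage note: `plot_overlay` appears to draw onto the current matplotlib figure, which is presumably why the example imports `pyplot`. A minimal continuation, reusing `heat_map` and `out` from the example above:

```python
from matplotlib import pyplot as plt

# Continuing the README example; `heat_map` and `out` come from the traced run above.
heat_map.plot_overlay(out.images[0])
plt.show()                       # display interactively, or
plt.savefig('dog.heat_map.png')  # save to disk, mirroring the CLI's per-word file naming
```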
2 changes: 1 addition & 1 deletion daam/_version.py
@@ -1 +1 @@
-__version__ = '0.1.0'
+__version__ = '0.2.0'
10 changes: 7 additions & 3 deletions daam/hook.py
@@ -55,9 +55,13 @@ def unhook(self):

        return self

-    def monkey_patch(self, fn_name, fn):
-        self.old_state[f'old_fn_{fn_name}'] = getattr(self.module, fn_name)
-        setattr(self.module, fn_name, functools.partial(fn, self.module))
+    def monkey_patch(self, fn_name, fn, strict: bool = True):
+        try:
+            self.old_state[f'old_fn_{fn_name}'] = getattr(self.module, fn_name)
+            setattr(self.module, fn_name, functools.partial(fn, self.module))
+        except AttributeError:
+            if strict:
+                raise

    def monkey_super(self, fn_name, *args, **kwargs):
        return self.old_state[f'old_fn_{fn_name}'](*args, **kwargs)
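For context on the new `strict` flag: patching now tolerates a missing target method when `strict=False`, which the tracer relies on because the SDXL pipeline has no `run_safety_checker`. A minimal sketch of the behavior, using toy classes rather than the real DAAM types:

```python
import functools

class Hooker:
    """Toy stand-in for the hooker above, showing only the patching logic."""
    def __init__(self, module):
        self.module = module
        self.old_state = {}

    def monkey_patch(self, fn_name, fn, strict: bool = True):
        try:
            self.old_state[f'old_fn_{fn_name}'] = getattr(self.module, fn_name)
            setattr(self.module, fn_name, functools.partial(fn, self.module))
        except AttributeError:  # the target method does not exist on this module
            if strict:
                raise

class ToyXLPipeline:
    """Toy pipeline with no run_safety_checker method, like SDXL."""

hook = Hooker(ToyXLPipeline())
hook.monkey_patch('run_safety_checker', lambda module, image: image, strict=False)  # skipped silently
# hook.monkey_patch('run_safety_checker', lambda module, image: image)  # default strict=True would raise
```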
19 changes: 15 additions & 4 deletions daam/run/generate.py
@@ -7,7 +7,7 @@
import time

import pandas as pd
-from diffusers import StableDiffusionPipeline
+from diffusers import StableDiffusionPipeline, DiffusionPipeline
from tqdm import tqdm
import inflect
import numpy as np
@@ -25,7 +25,8 @@ def main():
        'v2-base': 'stabilityai/stable-diffusion-2-base',
        'v2-large': 'stabilityai/stable-diffusion-2',
        'v2-1-base': 'stabilityai/stable-diffusion-2-1-base',
-        'v2-1-large': 'stabilityai/stable-diffusion-2-1'
+        'v2-1-large': 'stabilityai/stable-diffusion-2-1',
+        'xl-base-1.0': 'stabilityai/stable-diffusion-xl-base-1.0',
    }

    parser = argparse.ArgumentParser()
@@ -192,10 +193,20 @@ def main():
    prompts = new_prompts

    prompts = prompts[:args.gen_limit]
-    pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True)
+
+    if 'xl' in model_id:
+        pipe = DiffusionPipeline.from_pretrained(
+            model_id,
+            use_auth_token=True,
+            torch_dtype=torch.float16,
+            use_safetensors=True, variant='fp16'
+        )
+    else:
+        pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True)

    pipe = auto_device(pipe)

-    with auto_autocast(dtype=torch.float16), torch.no_grad():
+    with torch.no_grad():
        for gen_idx, (prompt_id, prompt) in enumerate(tqdm(prompts)):
            seed = int(time.time()) if args.random_seed else args.seed
            prompt = prompt.replace(',', ' ,').replace('.', ' .').strip()
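A sketch of the two loading paths above, under my reading that the `auto_autocast` wrapper became redundant once the XL branch loads fp16 weights directly via `torch_dtype` (model names are taken from the table above; the CUDA device is an assumption):

```python
import torch
from diffusers import DiffusionPipeline, StableDiffusionPipeline

# SDXL path: the weights themselves are half precision, so no autocast context is needed.
pipe_xl = DiffusionPipeline.from_pretrained(
    'stabilityai/stable-diffusion-xl-base-1.0',
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant='fp16',
).to('cuda')

# Pre-XL path: full-precision weights, as in the else-branch above.
pipe_v2 = StableDiffusionPipeline.from_pretrained('stabilityai/stable-diffusion-2-base').to('cuda')
```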
47 changes: 36 additions & 11 deletions daam/trace.py
@@ -2,7 +2,8 @@
from typing import List, Type, Any, Dict, Tuple, Union
import math

-from diffusers import StableDiffusionPipeline
+from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
+from diffusers.image_processor import VaeImageProcessor
from diffusers.models.attention_processor import Attention
import numpy as np
import PIL.Image as Image
@@ -21,16 +22,15 @@
class DiffusionHeatMapHooker(AggregateHooker):
    def __init__(
            self,
-            pipeline: StableDiffusionPipeline,
+            pipeline: Union[StableDiffusionPipeline, StableDiffusionXLPipeline],
            low_memory: bool = False,
            load_heads: bool = False,
            save_heads: bool = False,
            data_dir: str = None
    ):
        self.all_heat_maps = RawHeatMapCollection()
        h = (pipeline.unet.config.sample_size * pipeline.vae_scale_factor)
-        self.latent_hw = 4096 if h == 512 else 9216 # 64x64 or 96x96 depending on if it's 2.0-v or 2.0
+        self.latent_hw = 4096 if h == 512 or h == 1024 else 9216 # 64x64 or 96x96 depending on if it's 2.0-v or 2.0
        locate_middle = load_heads or save_heads
        self.locator = UNetCrossAttentionLocator(restrict={0} if low_memory else None, locate_middle_block=locate_middle)
        self.last_prompt: str = ''
@@ -52,6 +52,9 @@ def __init__(

        modules.append(PipelineHooker(pipeline, self))

+        if type(pipeline) == StableDiffusionXLPipeline:
+            modules.append(ImageProcessorHooker(pipeline.image_processor, self))
+
        super().__init__(modules)
        self.pipe = pipeline

@@ -129,6 +132,21 @@ def compute_global_heat_map(self, prompt=None, factors=None, head_idx=None, laye
        return GlobalHeatMap(self.pipe.tokenizer, prompt, maps)


+class ImageProcessorHooker(ObjectHooker[VaeImageProcessor]):
+    def __init__(self, processor: VaeImageProcessor, parent_trace: 'trace'):
+        super().__init__(processor)
+        self.parent_trace = parent_trace
+
+    def _hooked_postprocess(hk_self, _: VaeImageProcessor, *args, **kwargs):
+        images = hk_self.monkey_super('postprocess', *args, **kwargs)
+        hk_self.parent_trace.last_image = images[0]
+
+        return images
+
+    def _hook_impl(self):
+        self.monkey_patch('postprocess', self._hooked_postprocess)


class PipelineHooker(ObjectHooker[StableDiffusionPipeline]):
    def __init__(self, pipeline: StableDiffusionPipeline, parent_trace: 'trace'):
        super().__init__(pipeline)
@@ -137,12 +155,20 @@ def __init__(self, pipeline: StableDiffusionPipeline, parent_trace: 'trace'):

    def _hooked_run_safety_checker(hk_self, self: StableDiffusionPipeline, image, *args, **kwargs):
        image, has_nsfw = hk_self.monkey_super('run_safety_checker', image, *args, **kwargs)
-        pil_image = self.numpy_to_pil(image)
-        hk_self.parent_trace.last_image = pil_image[0]
+
+        if self.image_processor:
+            if torch.is_tensor(image):
+                images = self.image_processor.postprocess(image, output_type='pil')
+            else:
+                images = self.image_processor.numpy_to_pil(image)
+        else:
+            images = self.numpy_to_pil(image)
+
+        hk_self.parent_trace.last_image = images[len(images)-1]

        return image, has_nsfw

-    def _hooked_encode_prompt(hk_self, _: StableDiffusionPipeline, prompt: Union[str, List[str]], *args, **kwargs):
+    def _hooked_check_inputs(hk_self, _: StableDiffusionPipeline, prompt: Union[str, List[str]], *args, **kwargs):
        if not isinstance(prompt, str) and len(prompt) > 1:
            raise ValueError('Only single prompt generation is supported for heat map computation.')
        elif not isinstance(prompt, str):
@@ -152,13 +178,12 @@ def _hooked_encode_prompt(hk_self, _: StableDiffusionPipeline, prompt: Union[str

        hk_self.heat_maps.clear()
        hk_self.parent_trace.last_prompt = last_prompt
-        ret = hk_self.monkey_super('_encode_prompt', prompt, *args, **kwargs)
-
-        return ret
+        return hk_self.monkey_super('check_inputs', prompt, *args, **kwargs)

    def _hook_impl(self):
-        self.monkey_patch('run_safety_checker', self._hooked_run_safety_checker)
-        self.monkey_patch('_encode_prompt', self._hooked_encode_prompt)
+        self.monkey_patch('run_safety_checker', self._hooked_run_safety_checker, strict=False) # not present in SDXL
+        self.monkey_patch('check_inputs', self._hooked_check_inputs)


class UNetCrossAttentionHooker(ObjectHooker[Attention]):
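For context: SDXL has no safety checker to hook, so the new `ImageProcessorHooker` wraps `VaeImageProcessor.postprocess` instead, recording the decoded image on the trace (as `last_image`) before handing it back to the pipeline. A toy sketch of that capture pattern (hypothetical names, not the DAAM API):

```python
import functools

class ToyProcessor:
    """Stand-in for VaeImageProcessor; returns fake 'images'."""
    def postprocess(self, latents):
        return [f'image-from-{latents}']

def install_capture(processor, store):
    original = processor.postprocess  # keep the bound method, like old_state in the hooker

    @functools.wraps(original)
    def wrapped(*args, **kwargs):
        images = original(*args, **kwargs)
        store['last_image'] = images[0]  # record before handing back, like _hooked_postprocess
        return images

    processor.postprocess = wrapped  # shadow the method on this instance only

trace_state = {}
proc = ToyProcessor()
install_capture(proc, trace_state)
proc.postprocess('latents-0')
assert trace_state['last_image'] == 'image-from-latents-0'
```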
6 changes: 5 additions & 1 deletion daam/utils.py
@@ -73,12 +73,16 @@ def cache_dir() -> Path:
def compute_token_merge_indices(tokenizer, prompt: str, word: str, word_idx: int = None, offset_idx: int = 0):
    merge_idxs = []
    tokens = tokenizer.tokenize(prompt.lower())
+    tokens = [x.replace('</w>', '') for x in tokens]  # New tokenizer uses wordpiece markers.

    if word_idx is None:
        word = word.lower()
-        search_tokens = tokenizer.tokenize(word)
+        search_tokens = [x.replace('</w>', '') for x in tokenizer.tokenize(word)]  # New tokenizer uses wordpiece markers.
        start_indices = [x + offset_idx for x in range(len(tokens)) if tokens[x:x + len(search_tokens)] == search_tokens]

        for indice in start_indices:
            merge_idxs += [i + indice for i in range(0, len(search_tokens))]

    if not merge_idxs:
        raise ValueError(f'Search word {word} not found in prompt!')
    else:
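For context on the `</w>` stripping: CLIP's BPE tokenizer appends an end-of-word marker to word-final tokens, so prompt tokens and search tokens must be normalized identically before matching. A quick sketch, assuming the standard `openai/clip-vit-base-patch32` tokenizer (not the exact one this repo loads):

```python
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-base-patch32')

tokens = tokenizer.tokenize('a dog runs across the field')
# Roughly: ['a</w>', 'dog</w>', 'runs</w>', 'across</w>', 'the</w>', 'field</w>']

cleaned = [t.replace('</w>', '') for t in tokens]
print(cleaned)  # ['a', 'dog', 'runs', 'across', 'the', 'field'], so 'dog' now matches directly
```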
6 changes: 3 additions & 3 deletions requirements.txt
@@ -1,12 +1,12 @@
scikit-image
-diffusers==0.16.1
+diffusers==0.21.2
spacy
gradio
ftfy
-transformers==4.27.4
+transformers==4.30.2
pandas
numba
nltk
inflect
joblib
-accelerate==0.18.0
+accelerate==0.23.0
