From 45feb6cb9cd7c327a79f77bb5da2ed1b0e886078 Mon Sep 17 00:00:00 2001
From: Stephan Auerhahn
Date: Wed, 2 Aug 2023 23:14:30 +0000
Subject: [PATCH 01/56] Use wrapper correctly in refiner helper

---
 sgm/inference/api.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/sgm/inference/api.py b/sgm/inference/api.py
index 12efc064..64c1b022 100644
--- a/sgm/inference/api.py
+++ b/sgm/inference/api.py
@@ -268,6 +268,12 @@ def refiner(
             "negative_aesthetic_score": 2.5,
         }

+        if params.img2img_strength < 1.0:
+            sampler.discretization = Img2ImgDiscretizationWrapper(
+                sampler.discretization,
+                strength=params.img2img_strength,
+            )
+
         return do_img2img(
             image,
             self.model,

From 853adb402252d2162adffe806381e9636d93d8e5 Mon Sep 17 00:00:00 2001
From: Stephan Auerhahn
Date: Thu, 3 Aug 2023 12:50:23 -0700
Subject: [PATCH 02/56] Add defaults to refiner function

---
 sgm/inference/api.py              | 8 +++++---
 tests/inference/test_inference.py | 4 ++--
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/sgm/inference/api.py b/sgm/inference/api.py
index 64c1b022..be4b2455 100644
--- a/sgm/inference/api.py
+++ b/sgm/inference/api.py
@@ -55,7 +55,7 @@ class Thresholder(str, Enum):
 class SamplingParams:
     width: int = 1024
     height: int = 1024
-    steps: int = 50
+    steps: int = 40
     sampler: Sampler = Sampler.DPMPP2M
     discretization: Discretization = Discretization.LEGACY_DDPM
     guider: Guider = Guider.VANILLA
@@ -247,10 +247,12 @@ def image_to_image(

     def refiner(
         self,
-        params: SamplingParams,
         image,
         prompt: str,
-        negative_prompt: Optional[str] = None,
+        negative_prompt: str = "",
+        params: SamplingParams = SamplingParams(
+            sampler=Sampler.EULER_EDM, steps=40, img2img_strength=0.2
+        ),
         samples: int = 1,
         return_latents: bool = False,
     ):

diff --git a/tests/inference/test_inference.py b/tests/inference/test_inference.py
index 2b2af11e..ae6f3550 100644
--- a/tests/inference/test_inference.py
+++ b/tests/inference/test_inference.py
@@ -102,10 +102,10 @@ def test_sdxl_with_refiner(
     samples, samples_z = output
     assert samples is not None
     assert samples_z is not None
-    refiner_pipeline.refiner(
-        params=SamplingParams(sampler=sampler_enum.value, steps=10),
+    refiner_pipeline.refiner(
         image=samples_z,
         prompt="A professional photograph of an astronaut riding a pig",
+        params=SamplingParams(sampler=sampler_enum.value, steps=40, img2img_strength=0.20),
         negative_prompt="",
         samples=1,
     )
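Usage sketch (not part of the patch): with the defaults above, a refinement pass no longer needs a hand-built SamplingParams. Assuming `refiner_pipeline` is an already-constructed SamplingPipeline and `samples_z` a latent batch from a base pass:

    # params now defaults to
    # SamplingParams(sampler=Sampler.EULER_EDM, steps=40, img2img_strength=0.2)
    refined = refiner_pipeline.refiner(
        image=samples_z,
        prompt="A professional photograph of an astronaut riding a pig",
    )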
From 73287ec3a30109020e57996612399c99ca759b39 Mon Sep 17 00:00:00 2001
From: Stephan Auerhahn
Date: Thu, 3 Aug 2023 23:42:11 +0000
Subject: [PATCH 03/56] Extract method for img2img wrapper

---
 sgm/inference/api.py | 25 +++++++++++++++----------
 1 file changed, 15 insertions(+), 10 deletions(-)

diff --git a/sgm/inference/api.py b/sgm/inference/api.py
index be4b2455..ec17dfe6 100644
--- a/sgm/inference/api.py
+++ b/sgm/inference/api.py
@@ -223,11 +223,10 @@ def image_to_image(
     ):
         sampler = get_sampler_config(params)

-        if params.img2img_strength < 1.0:
-            sampler.discretization = Img2ImgDiscretizationWrapper(
-                sampler.discretization,
-                strength=params.img2img_strength,
-            )
+        sampler.discretization = self.wrap_discretization(
+            sampler.discretization, strength=params.img2img_strength
+        )
+
         height, width = image.shape[2], image.shape[3]
         value_dict = asdict(params)
         value_dict["prompt"] = prompt
@@ -245,6 +244,14 @@ def image_to_image(
             filter=None,
         )

+    def wrap_discretization(self, discretization, strength=1.0):
+        if (
+            not isinstance(discretization, Img2ImgDiscretizationWrapper)
+            and strength < 1.0
+        ):
+            return Img2ImgDiscretizationWrapper(discretization, strength=strength)
+        return discretization
+
     def refiner(
         self,
         image,
@@ -270,11 +277,9 @@ def refiner(
             "negative_aesthetic_score": 2.5,
         }

-        if params.img2img_strength < 1.0:
-            sampler.discretization = Img2ImgDiscretizationWrapper(
-                sampler.discretization,
-                strength=params.img2img_strength,
-            )
+        sampler.discretization = self.wrap_discretization(
+            sampler.discretization, strength=params.img2img_strength
+        )

         return do_img2img(
             image,
             self.model,
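A consequence of the extracted helper worth noting: it is a no-op at full strength and refuses to double-wrap, so callers can apply it unconditionally. A sketch of that behavior, assuming `pipeline` is a constructed SamplingPipeline and `sampler` came from get_sampler_config:

    wrapped = pipeline.wrap_discretization(sampler.discretization, strength=0.5)
    # A second call passes an existing wrapper through instead of nesting it.
    assert pipeline.wrap_discretization(wrapped, strength=0.5) is wrapped
    # Full strength (the default) means full sampling, so nothing is wrapped.
    assert pipeline.wrap_discretization(sampler.discretization) is sampler.discretization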
From 44943df4f218cda2265a0f7bdf4db4ffa501a504 Mon Sep 17 00:00:00 2001
From: Stephan Auerhahn
Date: Thu, 3 Aug 2023 23:59:42 +0000
Subject: [PATCH 04/56] Allow loading custom models and improve path logic

---
 sgm/inference/api.py | 34 +++++++++++++++++++++++++---------
 1 file changed, 25 insertions(+), 9 deletions(-)

diff --git a/sgm/inference/api.py b/sgm/inference/api.py
index ec17dfe6..182e8cfe 100644
--- a/sgm/inference/api.py
+++ b/sgm/inference/api.py
@@ -1,6 +1,7 @@
 from dataclasses import dataclass, asdict
 from enum import Enum
 from omegaconf import OmegaConf
+import os
 import pathlib
 from sgm.inference.helpers import (
     do_sample,
@@ -158,18 +159,33 @@ class SamplingSpec:
 class SamplingPipeline:
     def __init__(
         self,
-        model_id: ModelArchitecture,
-        model_path="checkpoints",
-        config_path="configs/inference",
+        model_id: Optional[ModelArchitecture] = None,
+        model_spec: Optional[SamplingSpec] = None,
+        model_path=None,
+        config_path=None,
         device="cuda",
         use_fp16=True,
     ) -> None:
-        if model_id not in model_specs:
-            raise ValueError(f"Model {model_id} not supported")
         self.model_id = model_id
-        self.specs = model_specs[self.model_id]
-        self.config = str(pathlib.Path(config_path, self.specs.config))
-        self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt))
+        if model_spec is not None:
+            self.specs = model_spec
+        elif model_id is not None:
+            if model_id not in model_specs:
+                raise ValueError(f"Model {model_id} not supported")
+            self.specs = model_specs[model_id]
+        else:
+            raise ValueError("Either model_id or model_spec should be provided")
+
+        if model_path is None:
+            model_path = pathlib.Path(__file__).parent.parent.resolve() / "checkpoints"
+        if config_path is None:
+            config_path = pathlib.Path(__file__).parent.parent.resolve() / "configs/inference"
+        self.config = str(config_path / self.specs.config)
+        self.ckpt = str(model_path / self.specs.ckpt)
+        if not os.path.exists(self.config):
+            raise ValueError(f"Config {self.config} not found, check model spec or config_path")
+        if not os.path.exists(self.ckpt):
+            raise ValueError(f"Checkpoint {self.ckpt} not found, check model spec or config_path")
         self.device = device
         self.model = self._load_model(device=device, use_fp16=use_fp16)

From baf79d2d79262330ff09ec8923f3ed5d864562d4 Mon Sep 17 00:00:00 2001
From: Stephan Auerhahn
Date: Fri, 4 Aug 2023 00:00:51 +0000
Subject: [PATCH 05/56] black

---
 sgm/inference/api.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/sgm/inference/api.py b/sgm/inference/api.py
index 182e8cfe..f5ce36f3 100644
--- a/sgm/inference/api.py
+++ b/sgm/inference/api.py
@@ -165,7 +165,7 @@ def __init__(
         config_path=None,
         device="cuda",
         use_fp16=True,
-    ) -> None: 
+    ) -> None:
         self.model_id = model_id
         if model_spec is not None:
             self.specs = model_spec
         if model_path is None:
             model_path = pathlib.Path(__file__).parent.parent.resolve() / "checkpoints"
         if config_path is None:
-            config_path = pathlib.Path(__file__).parent.parent.resolve() / "configs/inference"
+            config_path = (
+                pathlib.Path(__file__).parent.parent.resolve() / "configs/inference"
+            )
         self.config = str(config_path / self.specs.config)
         self.ckpt = str(model_path / self.specs.ckpt)
         if not os.path.exists(self.config):
-            raise ValueError(f"Config {self.config} not found, check model spec or config_path")
+            raise ValueError(
+                f"Config {self.config} not found, check model spec or config_path"
+            )
         if not os.path.exists(self.ckpt):
-            raise ValueError(f"Checkpoint {self.ckpt} not found, check model spec or config_path")
+            raise ValueError(
+                f"Checkpoint {self.ckpt} not found, check model spec or config_path"
+            )
         self.device = device
         self.model = self._load_model(device=device, use_fp16=use_fp16)

From 4e2236f67d976664852e41110428f4823f642c54 Mon Sep 17 00:00:00 2001
From: Stephan Auerhahn
Date: Fri, 4 Aug 2023 00:15:22 +0000
Subject: [PATCH 06/56] Fix path logic for development installs

---
 sgm/inference/api.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/sgm/inference/api.py b/sgm/inference/api.py
index f5ce36f3..f81e7903 100644
--- a/sgm/inference/api.py
+++ b/sgm/inference/api.py
@@ -185,6 +185,11 @@ def __init__(
         self.config = str(config_path / self.specs.config)
         self.ckpt = str(model_path / self.specs.ckpt)
         if not os.path.exists(self.config):
+            # This supports development installs where configs is root level of the repo
+            if config_path is None:
+                config_path = (
+                    pathlib.Path(__file__).parent.parent.parent.resolve() / "configs/inference"
+                )
             raise ValueError(
                 f"Config {self.config} not found, check model spec or config_path"
             )

From 19fa4da3de76d705a49ba565125950da87ed225e Mon Sep 17 00:00:00 2001
From: Stephan Auerhahn
Date: Fri, 4 Aug 2023 00:16:29 +0000
Subject: [PATCH 07/56] run black again

---
 sgm/inference/api.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/sgm/inference/api.py b/sgm/inference/api.py
index f81e7903..285fe9e3 100644
--- a/sgm/inference/api.py
+++ b/sgm/inference/api.py
@@ -188,7 +188,8 @@ def __init__(
             # This supports development installs where configs is root level of the repo
             if config_path is None:
                 config_path = (
-                    pathlib.Path(__file__).parent.parent.parent.resolve() / "configs/inference"
+                    pathlib.Path(__file__).parent.parent.parent.resolve()
+                    / "configs/inference"
                 )
             raise ValueError(
                 f"Config {self.config} not found, check model spec or config_path"
             )

From 84d3a7f6f5b56264f8ae5737e21c15ea78b4b6bc Mon Sep 17 00:00:00 2001
From: Stephan Auerhahn
Date: Thu, 3 Aug 2023 17:50:10 -0700
Subject: [PATCH 08/56] fix fallback logic for config path

---
 sgm/inference/api.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/sgm/inference/api.py b/sgm/inference/api.py
index 285fe9e3..a3205003 100644
--- a/sgm/inference/api.py
+++ b/sgm/inference/api.py
@@ -182,15 +182,12 @@ def __init__(
             config_path = (
                 pathlib.Path(__file__).parent.parent.resolve() / "configs/inference"
             )
+            if not os.path.exists(config_path):
+                # This supports development installs where configs is root level of the repo
+                config_path = pathlib.Path(__file__).parent.parent.parent.resolve() / "configs/inference"
         self.config = str(config_path / self.specs.config)
         self.ckpt = str(model_path / self.specs.ckpt)
         if not os.path.exists(self.config):
-            # This supports development installs where configs is root level of the repo
-            if config_path is None:
-                config_path = (
-                    pathlib.Path(__file__).parent.parent.parent.resolve()
-                    / "configs/inference"
-                )
             raise ValueError(
                 f"Config {self.config} not found, check model spec or config_path"
             )
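After patches 04-08 the lookup order is: explicit argument first, then the package-relative directory, then the repository root (for development installs), with a clear ValueError when the resolved file is missing. A standalone sketch of that order; the function name `resolve` is illustrative, not part of the API:

    import os
    import pathlib

    def resolve(explicit, package_dir: pathlib.Path, repo_dir: pathlib.Path, name: str) -> str:
        # Mirrors SamplingPipeline.__init__: explicit path wins, then the
        # package-relative directory, then the repo root as a dev-install fallback.
        base = pathlib.Path(explicit) if explicit is not None else (
            package_dir if os.path.exists(package_dir) else repo_dir
        )
        target = base / name
        if not os.path.exists(target):
            raise ValueError(f"{target} not found, check model spec or config_path")
        return str(target)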
From 4aea6fa2a47c55752137a9a9ecaa46eebf58e366 Mon Sep 17 00:00:00 2001
From: Stephan Auerhahn
Date: Thu, 3 Aug 2023 17:56:24 -0700
Subject: [PATCH 09/56] Fix checkpoint loading too

---
 sgm/inference/api.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/sgm/inference/api.py b/sgm/inference/api.py
index a3205003..2b5500ca 100644
--- a/sgm/inference/api.py
+++ b/sgm/inference/api.py
@@ -178,6 +178,9 @@ def __init__(

         if model_path is None:
             model_path = pathlib.Path(__file__).parent.parent.resolve() / "checkpoints"
+            if not os.path.exists(model_path):
+                # This supports development installs where checkpoints is root level of the repo
+                model_path = pathlib.Path(__file__).parent.parent.parent.resolve() / "checkpoints"

From 77d0e27747464c465df8b57ee78a599080c1106b Mon Sep 17 00:00:00 2001
From: Stephan Auerhahn
Date: Thu, 3 Aug 2023 17:57:55 -0700
Subject: [PATCH 10/56] format

---
 sgm/inference/api.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/sgm/inference/api.py b/sgm/inference/api.py
index 2b5500ca..4f7c78bf 100644
--- a/sgm/inference/api.py
+++ b/sgm/inference/api.py
@@ -180,14 +180,20 @@ def __init__(
             model_path = pathlib.Path(__file__).parent.parent.resolve() / "checkpoints"
             if not os.path.exists(model_path):
                 # This supports development installs where checkpoints is root level of the repo
-                model_path = pathlib.Path(__file__).parent.parent.parent.resolve() / "checkpoints"
+                model_path = (
+                    pathlib.Path(__file__).parent.parent.parent.resolve()
+                    / "checkpoints"
+                )
         if config_path is None:
             config_path = (
                 pathlib.Path(__file__).parent.parent.resolve() / "configs/inference"
             )
             if not os.path.exists(config_path):
                 # This supports development installs where configs is root level of the repo
-                config_path = pathlib.Path(__file__).parent.parent.parent.resolve() / "configs/inference"
+                config_path = (
+                    pathlib.Path(__file__).parent.parent.parent.resolve()
+                    / "configs/inference"
+                )

From b216934b7e34aa8b34b589a9408c1a0ae25e283f Mon Sep 17 00:00:00 2001
From: Stephan Auerhahn
Date: Sun, 6 Aug 2023 11:20:22 +0000
Subject: [PATCH 11/56] align with streamlit helpers and re-de-duplicate

---
 scripts/demo/sampling.py          |  24 +-
 scripts/demo/streamlit_helpers.py | 430 ++----------------------------
 sgm/inference/api.py              |  35 +--
 sgm/inference/helpers.py          | 119 ++++++---
 4 files changed, 140 insertions(+), 468 deletions(-)

diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py
index 2984dbf7..2d9a62ef 100644
--- a/scripts/demo/sampling.py
+++ b/scripts/demo/sampling.py
@@ -1,5 +1,11 @@
 from pytorch_lightning import seed_everything

+from sgm.inference.helpers import (
+    do_img2img,
+    do_sample,
+    get_unique_embedder_keys_from_conditioner,
+    perform_save_locally,
+)
 from scripts.demo.streamlit_helpers import *

 SAVE_PATH = "outputs/demo/txt2img/"
@@ -99,9 +105,7 @@ def load_img(display=True, key=None, device="cuda"):
         st.image(image)
     w, h = image.size
     print(f"loaded input image of size ({w}, {h})")
-    width, height = map(
-        lambda x: x - x % 64, (w, h)
-    )  # resize to integer multiple of 64
+    width, height = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 64
     image = image.resize((width, height))
     image = np.array(image.convert("RGB"))
     image = image[None].transpose(0, 3, 1, 2)
@@ -143,6 +147,8 @@ def run_txt2img(
     if st.button("Sample"):
         st.write(f"**Model I:** {version}")
+        outputs = st.empty()
+        st.text("Sampling")
         out = do_sample(
             state["model"],
             sampler,
             value_dict,
             num_samples,
             H,
             W,
             C,
             F,
             force_uc_zero_embeddings=["txt"] if not is_legacy else [],
             return_latents=return_latents,
             filter=filter,
         )
+
+        show_samples(out, outputs)
+
         return out
@@ -184,9 +193,7 @@ def run_img2img(
         prompt=prompt,
         negative_prompt=negative_prompt,
     )
-    strength = st.number_input(
-        "**Img2Img Strength**", value=0.75, min_value=0.0, max_value=1.0
-    )
+    strength = st.number_input("**Img2Img Strength**", value=0.75, min_value=0.0, max_value=1.0)
     sampler, num_rows, num_cols = init_sampling(
         img2img_strength=strength,
         stage2strength=stage2strength,
     )
@@ -194,6 +201,8 @@ def run_img2img(
     num_samples = num_rows * num_cols

     if st.button("Sample"):
+        outputs = st.empty()
+        st.text("Sampling")
         out = do_img2img(
             repeat(img, "1 ... -> n ...", n=num_samples),
             state["model"],
@@ -204,6 +213,7 @@ def run_img2img(
             return_latents=return_latents,
             filter=filter,
         )
+        show_samples(out, outputs)
         return out
@@ -342,6 +352,7 @@ def apply_refiner(
     samples_z = None

     if add_pipeline and samples_z is not None:
+        outputs = st.empty()
         st.write("**Running Refinement Stage**")
         samples = apply_refiner(
             samples_z,
@@ -353,6 +364,7 @@ def apply_refiner(
             filter=state.get("filter"),
             finish_denoising=finish_denoising,
         )
+        show_samples(samples, outputs)

     if save_locally and samples is not None:
         perform_save_locally(save_path, samples)

diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py
index 82b7fb9c..fa104b9c 100644
--- a/scripts/demo/streamlit_helpers.py
+++ b/scripts/demo/streamlit_helpers.py
@@ -1,18 +1,13 @@
-import math
 import os
-from typing import List, Union

 import numpy as np
 import streamlit as st
 import torch
 from einops import rearrange, repeat
-from imwatermark import WatermarkEncoder
-from omegaconf import ListConfig, OmegaConf
+from omegaconf import OmegaConf
 from PIL import Image
-from safetensors.torch import load_file as load_safetensors
-from torch import autocast
 from torchvision import transforms
-from torchvision.utils import make_grid
+
 from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
 from sgm.modules.diffusionmodules.sampling import (
@@ -23,52 +18,12 @@
     HeunEDMSampler,
     LinearMultistepSampler,
 )
-from sgm.util import append_dims, instantiate_from_config
-
-
-class WatermarkEmbedder:
-    def __init__(self, watermark):
-        self.watermark = watermark
-        self.num_bits = len(WATERMARK_BITS)
-        self.encoder = WatermarkEncoder()
-        self.encoder.set_watermark("bits", self.watermark)
-
-    def __call__(self, image: torch.Tensor):
-        """
-        Adds a predefined watermark to the input image
-
-        Args:
-            image: ([N,] B, C, H, W) in range [0, 1]
-
-        Returns:
-            same as input but watermarked
-        """
-        # watermarking libary expects input as cv2 BGR format
-        squeeze = len(image.shape) == 4
-        if squeeze:
-            image = image[None, ...]
-        n = image.shape[0]
-        image_np = rearrange(
-            (255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
-        ).numpy()[:, :, :, ::-1]
-        # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
-        for k in range(image_np.shape[0]):
-            image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
-        image = torch.from_numpy(
-            rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
-        ).to(image.device)
-        image = torch.clamp(image / 255, min=0.0, max=1.0)
-        if squeeze:
-            image = image[0]
-        return image
-
-
-# A fixed 48-bit message that was choosen at random
-# WATERMARK_MESSAGE = 0xB3EC907BB19E
-WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
-# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
-WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
-embed_watemark = WatermarkEmbedder(WATERMARK_BITS)
+from sgm.inference.helpers import (
+    Img2ImgDiscretizationWrapper,
+    Txt2NoisyDiscretizationWrapper,
+    embed_watermark,
+)
+from sgm.util import load_model_from_config


 @st.cache_resource()
 def init_st(version_dict, load_ckpt=True, load_filter=True):
     ckpt = version_dict["ckpt"]

     config = OmegaConf.load(config)
-    model, msg = load_model_from_config(config, ckpt if load_ckpt else None)
+    model = load_model_from_config(config, ckpt if load_ckpt else None, freeze=False)

-    state["msg"] = msg
     state["model"] = model
     state["ckpt"] = ckpt if load_ckpt else None
     state["config"] = config
     return state


-def load_model(model):
-    model.cuda()
-
-
 lowvram_mode = False
 def initial_model_load(model):
     return model


-def unload_model(model):
-    global lowvram_mode
-    if lowvram_mode:
-        model.cpu()
-        torch.cuda.empty_cache()
-
-
-def load_model_from_config(config, ckpt=None, verbose=True):
-    model = instantiate_from_config(config.model)
-
-    if ckpt is not None:
-        print(f"Loading model from {ckpt}")
-        if ckpt.endswith("ckpt"):
-            pl_sd = torch.load(ckpt, map_location="cpu")
-            if "global_step" in pl_sd:
-                global_step = pl_sd["global_step"]
-                st.info(f"loaded ckpt from global step {global_step}")
-                print(f"Global Step: {pl_sd['global_step']}")
-            sd = pl_sd["state_dict"]
-        elif ckpt.endswith("safetensors"):
-            sd = load_safetensors(ckpt)
-        else:
-            raise NotImplementedError
-
-        msg = None
-
-        m, u = model.load_state_dict(sd, strict=False)
-
-        if len(m) > 0 and verbose:
-            print("missing keys:")
-            print(m)
-        if len(u) > 0 and verbose:
-            print("unexpected keys:")
-            print(u)
-    else:
-        msg = None
-
-    model = initial_model_load(model)
-    model.eval()
-    return model, msg
-
-
 def get_unique_embedder_keys_from_conditioner(conditioner):
     return list(set([x.input_key for x in conditioner.embedders]))
 def perform_save_locally(save_path, samples):
     os.makedirs(os.path.join(save_path), exist_ok=True)
     base_count = len(os.listdir(os.path.join(save_path)))
-    samples = embed_watemark(samples)
+    samples = embed_watermark(samples)
     for sample in samples:
         sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
         Image.fromarray(sample.astype(np.uint8)).save(
 def init_save_locally(_dir, init_value: bool = False):
     return save_locally, save_path


-class Img2ImgDiscretizationWrapper:
-    """
-    wraps a discretizer, and prunes the sigmas
-    params:
-        strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
-    """
-
-    def __init__(self, discretization, strength: float = 1.0):
-        self.discretization = discretization
-        self.strength = strength
-        assert 0.0 <= self.strength <= 1.0
-
-    def __call__(self, *args, **kwargs):
-        # sigmas start large first, and decrease then
-        sigmas = self.discretization(*args, **kwargs)
-        print(f"sigmas after discretization, before pruning img2img: ", sigmas)
-        sigmas = torch.flip(sigmas, (0,))
-        sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
-        print("prune index:", max(int(self.strength * len(sigmas)), 1))
-        sigmas = torch.flip(sigmas, (0,))
-        print(f"sigmas after pruning: ", sigmas)
-        return sigmas
-
-
-class Txt2NoisyDiscretizationWrapper:
-    """
-    wraps a discretizer, and prunes the sigmas
-    params:
-        strength: float between 0.0 and 1.0. 0.0 means full sampling (all sigmas are returned)
-    """
-
-    def __init__(self, discretization, strength: float = 0.0, original_steps=None):
-        self.discretization = discretization
-        self.strength = strength
-        self.original_steps = original_steps
-        assert 0.0 <= self.strength <= 1.0
-
-    def __call__(self, *args, **kwargs):
-        # sigmas start large first, and decrease then
-        sigmas = self.discretization(*args, **kwargs)
-        print(f"sigmas after discretization, before pruning img2img: ", sigmas)
-        sigmas = torch.flip(sigmas, (0,))
-        if self.original_steps is None:
-            steps = len(sigmas)
-        else:
-            steps = self.original_steps + 1
-        prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0)
-        sigmas = sigmas[prune_index:]
-        print("prune index:", prune_index)
-        sigmas = torch.flip(sigmas, (0,))
-        print(f"sigmas after pruning: ", sigmas)
-        return sigmas
+def show_samples(samples, outputs):
+    if isinstance(samples, tuple):
+        samples, _ = samples
+    grid = embed_watermark(torch.stack([samples]))
+    grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
+    outputs.image(grid.cpu().numpy())


 def get_guider(key):
     )

     if guider == "IdentityGuider":
-        guider_config = {
-            "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
-        }
+        guider_config = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}
     elif guider == "VanillaCFG":
-        scale = st.number_input(
-            f"cfg-scale #{key}", value=5.0, min_value=0.0, max_value=100.0
-        )
+        scale = st.number_input(f"cfg-scale #{key}", value=5.0, min_value=0.0, max_value=100.0)

         thresholder = st.sidebar.selectbox(
             f"Thresholder #{key}",
 def init_sampling(
 ):
     num_rows, num_cols = 1, 1
     if specify_num_samples:
-        num_cols = st.number_input(
-            f"num cols #{key}", value=2, min_value=1, max_value=10
-        )
+        num_cols = st.number_input(f"num cols #{key}", value=2, min_value=1, max_value=10)

-    steps = st.sidebar.number_input(
-        f"steps #{key}", value=40, min_value=1, max_value=1000
-    )
+    steps = st.sidebar.number_input(f"steps #{key}", value=40, min_value=1, max_value=1000)
     sampler = st.sidebar.selectbox(
         f"Sampler #{key}",
         [
     sampler = get_sampler(sampler, steps, discretization_config, guider_config, key=key)
     if img2img_strength < 1.0:
-        st.warning(
-            f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper"
-        )
+        st.warning(f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper")
         sampler.discretization = Img2ImgDiscretizationWrapper(
             sampler.discretization, strength=img2img_strength
         )
 def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1
             s_noise=s_noise,
             verbose=True,
         )
-    elif (
-        sampler_name == "EulerAncestralSampler"
-        or sampler_name == "DPMPP2SAncestralSampler"
-    ):
+    elif sampler_name == "EulerAncestralSampler" or sampler_name == "DPMPP2SAncestralSampler":
         s_noise = st.sidebar.number_input("s_noise", value=1.0, min_value=0.0)
         eta = st.sidebar.number_input("eta", value=1.0, min_value=0.0)
 def get_init_img(batch_size=1, key=None):
     init_image = load_img(key=key).cuda()
     init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
     return init_image
-
-
-def do_sample(
-    model,
-    sampler,
-    value_dict,
-    num_samples,
-    H,
-    W,
-    C,
-    F,
-    force_uc_zero_embeddings: List = None,
-    batch2model_input: List = None,
-    return_latents=False,
-    filter=None,
-):
-    if force_uc_zero_embeddings is None:
-        force_uc_zero_embeddings = []
-    if batch2model_input is None:
-        batch2model_input = []
-
-    st.text("Sampling")
-
-    outputs = st.empty()
-    precision_scope = autocast
-    with torch.no_grad():
-        with precision_scope("cuda"):
-            with model.ema_scope():
-                num_samples = [num_samples]
-                load_model(model.conditioner)
-                batch, batch_uc = get_batch(
-                    get_unique_embedder_keys_from_conditioner(model.conditioner),
-                    value_dict,
-                    num_samples,
-                )
-                for key in batch:
-                    if isinstance(batch[key], torch.Tensor):
-                        print(key, batch[key].shape)
-                    elif isinstance(batch[key], list):
-                        print(key, [len(l) for l in batch[key]])
-                    else:
-                        print(key, batch[key])
-                c, uc = model.conditioner.get_unconditional_conditioning(
-                    batch,
-                    batch_uc=batch_uc,
-                    force_uc_zero_embeddings=force_uc_zero_embeddings,
-                )
-                unload_model(model.conditioner)
-
-                for k in c:
-                    if not k == "crossattn":
-                        c[k], uc[k] = map(
-                            lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc)
-                        )
-
-                additional_model_inputs = {}
-                for k in batch2model_input:
-                    additional_model_inputs[k] = batch[k]
-
-                shape = (math.prod(num_samples), C, H // F, W // F)
-                randn = torch.randn(shape).to("cuda")
-
-                def denoiser(input, sigma, c):
-                    return model.denoiser(
-                        model.model, input, sigma, c, **additional_model_inputs
-                    )
-
-                load_model(model.denoiser)
-                load_model(model.model)
-                samples_z = sampler(denoiser, randn, cond=c, uc=uc)
-                unload_model(model.model)
-                unload_model(model.denoiser)
-
-                load_model(model.first_stage_model)
-                samples_x = model.decode_first_stage(samples_z)
-                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
-                unload_model(model.first_stage_model)
-
-                if filter is not None:
-                    samples = filter(samples)
-
-                grid = torch.stack([samples])
-                grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
-                outputs.image(grid.cpu().numpy())
-
-                if return_latents:
-                    return samples, samples_z
-                return samples
-
-
-def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
-    # Hardcoded demo setups; might undergo some changes in the future
-
-    batch = {}
-    batch_uc = {}
-
-    for key in keys:
-        if key == "txt":
-            batch["txt"] = (
-                np.repeat([value_dict["prompt"]], repeats=math.prod(N))
-                .reshape(N)
-                .tolist()
-            )
-            batch_uc["txt"] = (
-                np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
-                .reshape(N)
-                .tolist()
-            )
-        elif key == "original_size_as_tuple":
-            batch["original_size_as_tuple"] = (
-                torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
-                .to(device)
-                .repeat(*N, 1)
-            )
-        elif key == "crop_coords_top_left":
-            batch["crop_coords_top_left"] = (
-                torch.tensor(
-                    [value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
-                )
-                .to(device)
-                .repeat(*N, 1)
-            )
-        elif key == "aesthetic_score":
-            batch["aesthetic_score"] = (
-                torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
-            )
-            batch_uc["aesthetic_score"] = (
-                torch.tensor([value_dict["negative_aesthetic_score"]])
-                .to(device)
-                .repeat(*N, 1)
-            )
-
-        elif key == "target_size_as_tuple":
-            batch["target_size_as_tuple"] = (
-                torch.tensor([value_dict["target_height"], value_dict["target_width"]])
-                .to(device)
-                .repeat(*N, 1)
-            )
-        else:
-            batch[key] = value_dict[key]
-
-    for key in batch.keys():
-        if key not in batch_uc and isinstance(batch[key], torch.Tensor):
-            batch_uc[key] = torch.clone(batch[key])
-    return batch, batch_uc
-
-
-@torch.no_grad()
-def do_img2img(
-    img,
-    model,
-    sampler,
-    value_dict,
-    num_samples,
-    force_uc_zero_embeddings=[],
-    additional_kwargs={},
-    offset_noise_level: int = 0.0,
-    return_latents=False,
-    skip_encode=False,
-    filter=None,
-    add_noise=True,
-):
-    st.text("Sampling")
-
-    outputs = st.empty()
-    precision_scope = autocast
-    with torch.no_grad():
-        with precision_scope("cuda"):
-            with model.ema_scope():
-                load_model(model.conditioner)
-                batch, batch_uc = get_batch(
-                    get_unique_embedder_keys_from_conditioner(model.conditioner),
-                    value_dict,
-                    [num_samples],
-                )
-                c, uc = model.conditioner.get_unconditional_conditioning(
-                    batch,
-                    batch_uc=batch_uc,
-                    force_uc_zero_embeddings=force_uc_zero_embeddings,
-                )
-                unload_model(model.conditioner)
-                for k in c:
-                    c[k], uc[k] = map(lambda y: y[k][:num_samples].to("cuda"), (c, uc))
-
-                for k in additional_kwargs:
-                    c[k] = uc[k] = additional_kwargs[k]
-                if skip_encode:
-                    z = img
-                else:
-                    load_model(model.first_stage_model)
-                    z = model.encode_first_stage(img)
-                    unload_model(model.first_stage_model)
-
-                noise = torch.randn_like(z)
-
-                sigmas = sampler.discretization(sampler.num_steps).cuda()
-                sigma = sigmas[0]
-
-                st.info(f"all sigmas: {sigmas}")
-                st.info(f"noising sigma: {sigma}")
-                if offset_noise_level > 0.0:
-                    noise = noise + offset_noise_level * append_dims(
-                        torch.randn(z.shape[0], device=z.device), z.ndim
-                    )
-                if add_noise:
-                    noised_z = z + noise * append_dims(sigma, z.ndim).cuda()
-                    noised_z = noised_z / torch.sqrt(
-                        1.0 + sigmas[0] ** 2.0
-                    )  # Note: hardcoded to DDPM-like scaling. need to generalize later.
-                else:
-                    noised_z = z / torch.sqrt(1.0 + sigmas[0] ** 2.0)
-
-                def denoiser(x, sigma, c):
-                    return model.denoiser(model.model, x, sigma, c)
-
-                load_model(model.denoiser)
-                load_model(model.model)
-                samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
-                unload_model(model.model)
-                unload_model(model.denoiser)
-
-                load_model(model.first_stage_model)
-                samples_x = model.decode_first_stage(samples_z)
-                unload_model(model.first_stage_model)
-                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
-
-                if filter is not None:
-                    samples = filter(samples)
-
-                grid = embed_watemark(torch.stack([samples]))
-                grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
-                outputs.image(grid.cpu().numpy())
-                if return_latents:
-                    return samples, samples_z
-                return samples

diff --git a/sgm/inference/api.py b/sgm/inference/api.py
index 4f7c78bf..4f2ce18d 100644
--- a/sgm/inference/api.py
+++ b/sgm/inference/api.py
@@ -7,6 +7,7 @@
     do_sample,
     do_img2img,
     Img2ImgDiscretizationWrapper,
+    Txt2NoisyDiscretizationWrapper,
 )
 from sgm.modules.diffusionmodules.sampling import (
     EulerEDMSampler,
@@ -180,30 +181,20 @@ def __init__(
             model_path = pathlib.Path(__file__).parent.parent.resolve() / "checkpoints"
             if not os.path.exists(model_path):
                 # This supports development installs where checkpoints is root level of the repo
-                model_path = (
-                    pathlib.Path(__file__).parent.parent.parent.resolve()
-                    / "checkpoints"
-                )
+                model_path = pathlib.Path(__file__).parent.parent.parent.resolve() / "checkpoints"
         if config_path is None:
-            config_path = (
-                pathlib.Path(__file__).parent.parent.resolve() / "configs/inference"
-            )
+            config_path = pathlib.Path(__file__).parent.parent.resolve() / "configs/inference"
             if not os.path.exists(config_path):
                 # This supports development installs where configs is root level of the repo
                 config_path = (
-                    pathlib.Path(__file__).parent.parent.parent.resolve()
-                    / "configs/inference"
+                    pathlib.Path(__file__).parent.parent.parent.resolve() / "configs/inference"
                 )
         self.config = str(config_path / self.specs.config)
         self.ckpt = str(model_path / self.specs.ckpt)
         if not os.path.exists(self.config):
-            raise ValueError(
-                f"Config {self.config} not found, check model spec or config_path"
-            )
+            raise ValueError(f"Config {self.config} not found, check model spec or config_path")
         if not os.path.exists(self.ckpt):
-            raise ValueError(
-                f"Checkpoint {self.ckpt} not found, check model spec or config_path"
-            )
+            raise ValueError(f"Checkpoint {self.ckpt} not found, check model spec or config_path")
         self.device = device
         self.model = self._load_model(device=device, use_fp16=use_fp16)
@@ -225,8 +216,13 @@ def text_to_image(
         negative_prompt: str = "",
         samples: int = 1,
         return_latents: bool = False,
+        stage2strength=None,
     ):
         sampler = get_sampler_config(params)
+        if stage2strength is not None:
+            sampler.discretization = Txt2NoisyDiscretizationWrapper(
+                sampler.discretization, strength=stage2strength, original_steps=params.steps
+            )
         value_dict = asdict(params)
         value_dict["prompt"] = prompt
         value_dict["negative_prompt"] = negative_prompt
     def wrap_discretization(self, discretization, strength=1.0):
-        if (
-            not isinstance(discretization, Img2ImgDiscretizationWrapper)
-            and strength < 1.0
-        ):
+        if not isinstance(discretization, Img2ImgDiscretizationWrapper) and strength < 1.0:
             return Img2ImgDiscretizationWrapper(discretization, strength=strength)
         return discretization
 def get_guider_config(params: SamplingParams):
     if params.guider == Guider.IDENTITY:
-        guider_config = {
-            "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
-        }
+        guider_config = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}
     elif params.guider == Guider.VANILLA:
         scale = params.scale

diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py
index 1c653708..90b06c9f 100644
--- a/sgm/inference/helpers.py
+++ b/sgm/inference/helpers.py
@@ -35,9 +35,9 @@ def __call__(self, image: torch.Tensor):
         if squeeze:
             image = image[None, ...]
         n = image.shape[0]
-        image_np = rearrange(
-            (255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
-        ).numpy()[:, :, :, ::-1]
+        image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[
+            :, :, :, ::-1
+        ]
         # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
         for k in range(image_np.shape[0]):
             image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
@@ -98,6 +98,36 @@ def __call__(self, *args, **kwargs):
         return sigmas


+class Txt2NoisyDiscretizationWrapper:
+    """
+    wraps a discretizer, and prunes the sigmas
+    params:
+        strength: float between 0.0 and 1.0. 0.0 means full sampling (all sigmas are returned)
+    """
+
+    def __init__(self, discretization, strength: float = 0.0, original_steps=None):
+        self.discretization = discretization
+        self.strength = strength
+        self.original_steps = original_steps
+        assert 0.0 <= self.strength <= 1.0
+
+    def __call__(self, *args, **kwargs):
+        # sigmas start large first, and decrease then
+        sigmas = self.discretization(*args, **kwargs)
+        print(f"sigmas after discretization, before pruning img2img: ", sigmas)
+        sigmas = torch.flip(sigmas, (0,))
+        if self.original_steps is None:
+            steps = len(sigmas)
+        else:
+            steps = self.original_steps + 1
+        prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0)
+        sigmas = sigmas[prune_index:]
+        print("prune index:", prune_index)
+        sigmas = torch.flip(sigmas, (0,))
+        print(f"sigmas after pruning: ", sigmas)
+        return sigmas
+
+
 def do_sample(
     model,
     sampler,
             randn = torch.randn(shape).to(device)

             def denoiser(input, sigma, c):
-                return model.denoiser(
-                    model.model, input, sigma, c, **additional_model_inputs
-                )
+                return model.denoiser(model.model, input, sigma, c, **additional_model_inputs)
+
+            with ModelOnDevice(model.denoiser, device):
+                with ModelOnDevice(model.model, device):
+                    samples_z = sampler(denoiser, randn, cond=c, uc=uc)

-            samples_z = sampler(denoiser, randn, cond=c, uc=uc)
-            samples_x = model.decode_first_stage(samples_z)
-            samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
+            with ModelOnDevice(model.first_stage_model, device):
+                samples_x = model.decode_first_stage(samples_z)
+                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)

             if filter is not None:
                 samples = filter(samples)
 def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
     for key in keys:
         if key == "txt":
             batch["txt"] = (
-                np.repeat([value_dict["prompt"]], repeats=math.prod(N))
-                .reshape(N)
-                .tolist()
+                np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
             )
             batch_uc["txt"] = (
-                np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
-                .reshape(N)
-                .tolist()
+                np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
             )
         elif key == "original_size_as_tuple":
             batch["original_size_as_tuple"] = (
                 .repeat(*N, 1)
             )
         elif key == "crop_coords_top_left":
             batch["crop_coords_top_left"] = (
-                torch.tensor(
-                    [value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
-                )
+                torch.tensor([value_dict["crop_coords_top"], value_dict["crop_coords_left"]])
                 .to(device)
                 .repeat(*N, 1)
             )
         elif key == "aesthetic_score":
             batch["aesthetic_score"] = (
                 torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
             )
             batch_uc["aesthetic_score"] = (
-                torch.tensor([value_dict["negative_aesthetic_score"]])
-                .to(device)
-                .repeat(*N, 1)
+                torch.tensor([value_dict["negative_aesthetic_score"]]).to(device).repeat(*N, 1)
             )

         elif key == "target_size_as_tuple":
 def get_input_image_tensor(image: Image.Image, device="cuda"):
     w, h = image.size
     print(f"loaded input image of size ({w}, {h})")
-    width, height = map(
-        lambda x: x - x % 64, (w, h)
-    )  # resize to integer multiple of 64
+    width, height = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 64
     image = image.resize((width, height))
     image_array = np.array(image.convert("RGB"))
     image_array = image_array[None].transpose(0, 3, 1, 2)
 def do_img2img(
     return_latents=False,
     skip_encode=False,
     filter=None,
+    add_noise=True,
     device="cuda",
 ):
     with torch.no_grad():
-        with autocast(device) as precision_scope:
+        with autocast(device):
             with model.ema_scope():
                 if offset_noise_level > 0.0:
                     noise = noise + offset_noise_level * append_dims(
                         torch.randn(z.shape[0], device=z.device), z.ndim
                     )
-                noised_z = z + noise * append_dims(sigma, z.ndim)
-                noised_z = noised_z / torch.sqrt(
-                    1.0 + sigmas[0] ** 2.0
-                )  # Note: hardcoded to DDPM-like scaling. need to generalize later.
+                if add_noise:
+                    noised_z = z + noise * append_dims(sigma, z.ndim).cuda()
+                    noised_z = noised_z / torch.sqrt(
+                        1.0 + sigmas[0] ** 2.0
+                    )  # Note: hardcoded to DDPM-like scaling. need to generalize later.
+                else:
+                    noised_z = z / torch.sqrt(1.0 + sigmas[0] ** 2.0)

                 def denoiser(x, sigma, c):
                     return model.denoiser(model.model, x, sigma, c)

-                samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
-                samples_x = model.decode_first_stage(samples_z)
-                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
+                with ModelOnDevice(model.denoiser, device):
+                    with ModelOnDevice(model.model, device):
+                        samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
+
+                with ModelOnDevice(model.first_stage_model, device):
+                    samples_x = model.decode_first_stage(samples_z)
+                    samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)

                 if filter is not None:
                     samples = filter(samples)
     if return_latents:
         return samples, samples_z
     return samples
+
+
+class ModelOnDevice(object):
+    def __init__(self, model, device):
+        self.model = model
+        self.device = device
+        self.original_device = model.device
+
+    def __enter__(self):
+        if self.device != self.original_device:
+            self.model.to(self.device)
+
+    def __exit__(self, *args):
+        if self.device != self.original_device:
+            self.model.to(self.original_device)
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+
+def load_model(model, device):
+    if model.device != device:
+        old_device = model.device
+        model.to(device)
+        return old_device
+    return False
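With both wrappers now living in sgm.inference.helpers, their pruning rules can be checked in isolation. A toy sketch; the lambda stands in for a real discretization returning descending sigmas, and the numbers follow the flip/prune/flip logic shown above:

    import torch

    from sgm.inference.helpers import (
        Img2ImgDiscretizationWrapper,
        Txt2NoisyDiscretizationWrapper,
    )

    toy = lambda *args, **kwargs: torch.linspace(10.0, 0.1, 40)  # 40 descending sigmas

    # img2img keeps the smallest max(int(strength * len), 1) sigmas: 0.25 * 40 = 10.
    assert len(Img2ImgDiscretizationWrapper(toy, strength=0.25)()) == 10

    # txt2noisy drops the prune_index smallest sigmas so a later stage can finish:
    # prune_index = max(min(int(0.15 * 40) - 1, 39), 0) = 5, leaving 35 of 40.
    assert len(Txt2NoisyDiscretizationWrapper(toy, strength=0.15)()) == 35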
From f06c67c2062c54542dd3038674d0f21d24669ad0 Mon Sep 17 00:00:00 2001
From: Stephan Auerhahn
Date: Sun, 6 Aug 2023 11:30:40 +0000
Subject: [PATCH 12/56] formatting, remove reference

---
 scripts/demo/sampling.py          | 11 ++++++-----
 scripts/demo/streamlit_helpers.py | 29 ++++++++++++++++++++-------
 sgm/inference/api.py              | 33 +++++++++++++++++++++++--------
 sgm/inference/helpers.py          | 30 +++++++++++++++++++----------
 4 files changed, 74 insertions(+), 29 deletions(-)

diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py
index 2d9a62ef..29cb3054 100644
--- a/scripts/demo/sampling.py
+++ b/scripts/demo/sampling.py
@@ -105,7 +105,9 @@ def load_img(display=True, key=None, device="cuda"):
         st.image(image)
     w, h = image.size
     print(f"loaded input image of size ({w}, {h})")
-    width, height = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 64
+    width, height = map(
+        lambda x: x - x % 64, (w, h)
+    )  # resize to integer multiple of 64
     image = image.resize((width, height))
     image = np.array(image.convert("RGB"))
     image = image[None].transpose(0, 3, 1, 2)
         prompt=prompt,
         negative_prompt=negative_prompt,
     )
-    strength = st.number_input("**Img2Img Strength**", value=0.75, min_value=0.0, max_value=1.0)
+    strength = st.number_input(
+        "**Img2Img Strength**", value=0.75, min_value=0.0, max_value=1.0
+    )
     sampler, num_rows, num_cols = init_sampling(
         img2img_strength=strength,
         stage2strength=stage2strength,
     )
 def apply_refiner(
     save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))
     state = init_st(version_dict, load_filter=True)
-    if state["msg"]:
-        st.info(state["msg"])
     model = state["model"]

     is_legacy = version_dict["is_legacy"]
         version_dict2 = VERSION2SPECS[version2]
         state2 = init_st(version_dict2, load_filter=False)
-        st.info(state2["msg"])

         stage2strength = st.number_input(
             "**Refinement strength**", value=0.15, min_value=0.0, max_value=1.0
         )

diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py
index fa104b9c..d3dd7d71 100644
--- a/scripts/demo/streamlit_helpers.py
+++ b/scripts/demo/streamlit_helpers.py
@@ -34,7 +34,9 @@ def init_st(version_dict, load_ckpt=True, load_filter=True):
     ckpt = version_dict["ckpt"]

     config = OmegaConf.load(config)
-    model = load_model_from_config(config, ckpt if load_ckpt else None, freeze=False)
+    model = load_model_from_config(
+        config, ckpt if load_ckpt else None, freeze=False
+    )

     state["model"] = model
     state["ckpt"] = ckpt if load_ckpt else None
 def get_guider(key):
     )

     if guider == "IdentityGuider":
-        guider_config = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}
+        guider_config = {
+            "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
+        }
     elif guider == "VanillaCFG":
-        scale = st.number_input(f"cfg-scale #{key}", value=5.0, min_value=0.0, max_value=100.0)
+        scale = st.number_input(
+            f"cfg-scale #{key}", value=5.0, min_value=0.0, max_value=100.0
+        )

         thresholder = st.sidebar.selectbox(
             f"Thresholder #{key}",
 def init_sampling(
 ):
     num_rows, num_cols = 1, 1
     if specify_num_samples:
-        num_cols = st.number_input(f"num cols #{key}", value=2, min_value=1, max_value=10)
+        num_cols = st.number_input(
+            f"num cols #{key}", value=2, min_value=1, max_value=10
+        )

-    steps = st.sidebar.number_input(f"steps #{key}", value=40, min_value=1, max_value=1000)
+    steps = st.sidebar.number_input(
+        f"steps #{key}", value=40, min_value=1, max_value=1000
+    )
     sampler = st.sidebar.selectbox(
         f"Sampler #{key}",
         [
     sampler = get_sampler(sampler, steps, discretization_config, guider_config, key=key)
     if img2img_strength < 1.0:
-        st.warning(f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper")
+        st.warning(
+            f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper"
+        )
         sampler.discretization = Img2ImgDiscretizationWrapper(
             sampler.discretization, strength=img2img_strength
         )
 def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1
             s_noise=s_noise,
             verbose=True,
         )
-    elif sampler_name == "EulerAncestralSampler" or sampler_name == "DPMPP2SAncestralSampler":
+    elif (
+        sampler_name == "EulerAncestralSampler"
+        or sampler_name == "DPMPP2SAncestralSampler"
+    ):
         s_noise = st.sidebar.number_input("s_noise", value=1.0, min_value=0.0)
         eta = st.sidebar.number_input("eta", value=1.0, min_value=0.0)

diff --git a/sgm/inference/api.py b/sgm/inference/api.py
index 4f2ce18d..77f5667d 100644
--- a/sgm/inference/api.py
+++ b/sgm/inference/api.py
             model_path = pathlib.Path(__file__).parent.parent.resolve() / "checkpoints"
             if not os.path.exists(model_path):
                 # This supports development installs where checkpoints is root level of the repo
-                model_path = pathlib.Path(__file__).parent.parent.parent.resolve() / "checkpoints"
+                model_path = (
+                    pathlib.Path(__file__).parent.parent.parent.resolve()
+                    / "checkpoints"
+                )
         if config_path is None:
-            config_path = pathlib.Path(__file__).parent.parent.resolve() / "configs/inference"
+            config_path = (
+                pathlib.Path(__file__).parent.parent.resolve() / "configs/inference"
+            )
             if not os.path.exists(config_path):
                 # This supports development installs where configs is root level of the repo
                 config_path = (
-                    pathlib.Path(__file__).parent.parent.parent.resolve() / "configs/inference"
+                    pathlib.Path(__file__).parent.parent.parent.resolve()
+                    / "configs/inference"
                 )
         self.config = str(config_path / self.specs.config)
         self.ckpt = str(model_path / self.specs.ckpt)
         if not os.path.exists(self.config):
-            raise ValueError(f"Config {self.config} not found, check model spec or config_path")
+            raise ValueError(
+                f"Config {self.config} not found, check model spec or config_path"
+            )
         if not os.path.exists(self.ckpt):
-            raise ValueError(f"Checkpoint {self.ckpt} not found, check model spec or config_path")
+            raise ValueError(
+                f"Checkpoint {self.ckpt} not found, check model spec or config_path"
+            )
         self.device = device
         self.model = self._load_model(device=device, use_fp16=use_fp16)
         sampler = get_sampler_config(params)
         if stage2strength is not None:
             sampler.discretization = Txt2NoisyDiscretizationWrapper(
-                sampler.discretization, strength=stage2strength, original_steps=params.steps
+                sampler.discretization,
+                strength=stage2strength,
+                original_steps=params.steps,
             )
         value_dict = asdict(params)
         value_dict["prompt"] = prompt
         value_dict["negative_prompt"] = negative_prompt
     def wrap_discretization(self, discretization, strength=1.0):
-        if not isinstance(discretization, Img2ImgDiscretizationWrapper) and strength < 1.0:
+        if (
+            not isinstance(discretization, Img2ImgDiscretizationWrapper)
+            and strength < 1.0
+        ):
             return Img2ImgDiscretizationWrapper(discretization, strength=strength)
         return discretization
 def get_guider_config(params: SamplingParams):
     if params.guider == Guider.IDENTITY:
-        guider_config = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}
+        guider_config = {
+            "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
+        }
     elif params.guider == Guider.VANILLA:
         scale = params.scale

diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py
index 90b06c9f..7b1ba764 100644
--- a/sgm/inference/helpers.py
+++ b/sgm/inference/helpers.py
@@ -35,9 +35,9 @@ def __call__(self, image: torch.Tensor):
         if squeeze:
             image = image[None, ...]
         n = image.shape[0]
-        image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[
-            :, :, :, ::-1
-        ]
+        image_np = rearrange(
+            (255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
+        ).numpy()[:, :, :, ::-1]
         # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
         for k in range(image_np.shape[0]):
             image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
             randn = torch.randn(shape).to(device)

             def denoiser(input, sigma, c):
-                return model.denoiser(model.model, input, sigma, c, **additional_model_inputs)
+                return model.denoiser(
+                    model.model, input, sigma, c, **additional_model_inputs
+                )

             with ModelOnDevice(model.denoiser, device):
                 with ModelOnDevice(model.model, device):
                     samples_z = sampler(denoiser, randn, cond=c, uc=uc)
     for key in keys:
         if key == "txt":
             batch["txt"] = (
-                np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
+                np.repeat([value_dict["prompt"]], repeats=math.prod(N))
+                .reshape(N)
+                .tolist()
             )
             batch_uc["txt"] = (
-                np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
+                np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
+                .reshape(N)
+                .tolist()
             )
         elif key == "original_size_as_tuple":
             batch["original_size_as_tuple"] = (
                 .repeat(*N, 1)
             )
         elif key == "crop_coords_top_left":
             batch["crop_coords_top_left"] = (
-                torch.tensor([value_dict["crop_coords_top"], value_dict["crop_coords_left"]])
+                torch.tensor(
+                    [value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
+                )
                 .to(device)
                 .repeat(*N, 1)
             )
         elif key == "aesthetic_score":
             batch["aesthetic_score"] = (
                 torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
             )
             batch_uc["aesthetic_score"] = (
-                torch.tensor([value_dict["negative_aesthetic_score"]]).to(device).repeat(*N, 1)
+                torch.tensor([value_dict["negative_aesthetic_score"]])
+                .to(device)
+                .repeat(*N, 1)
             )

         elif key == "target_size_as_tuple":
 def get_input_image_tensor(image: Image.Image, device="cuda"):
     w, h = image.size
     print(f"loaded input image of size ({w}, {h})")
-    width, height = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 64
+    width, height = map(
+        lambda x: x - x % 64, (w, h)
+    )  # resize to integer multiple of 64
     image = image.resize((width, height))
     image_array = np.array(image.convert("RGB"))
     image_array = image_array[None].transpose(0, 3, 1, 2)

From ea5f232d5d2175b1ab30ea25b752f18da0430110 Mon Sep 17 00:00:00 2001
From: Stephan Auerhahn
Date: Sun, 6 Aug 2023 11:42:39 +0000
Subject: [PATCH 13/56] move conditioner to device

---
 sgm/inference/helpers.py | 61 ++++++++++++++++++++++------------------
 1 file changed, 33 insertions(+), 28 deletions(-)

diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py
index 7b1ba764..30dc082f 100644
--- a/sgm/inference/helpers.py
+++ b/sgm/inference/helpers.py
 def do_sample(
     with autocast(device) as precision_scope:
         with model.ema_scope():
             num_samples = [num_samples]
-            batch, batch_uc = get_batch(
-                get_unique_embedder_keys_from_conditioner(model.conditioner),
-                value_dict,
-                num_samples,
-            )
-            for key in batch:
-                if isinstance(batch[key], torch.Tensor):
-                    print(key, batch[key].shape)
-                elif isinstance(batch[key], list):
-                    print(key, [len(l) for l in batch[key]])
-                else:
-                    print(key, batch[key])
-            c, uc = model.conditioner.get_unconditional_conditioning(
-                batch,
-                batch_uc=batch_uc,
-                force_uc_zero_embeddings=force_uc_zero_embeddings,
-            )
+            with ModelOnDevice(model.conditioner, device):
+                batch, batch_uc = get_batch(
+                    get_unique_embedder_keys_from_conditioner(model.conditioner),
+                    value_dict,
+                    num_samples,
+                )
+                for key in batch:
+                    if isinstance(batch[key], torch.Tensor):
+                        print(key, batch[key].shape)
+                    elif isinstance(batch[key], list):
+                        print(key, [len(l) for l in batch[key]])
+                    else:
+                        print(key, batch[key])
+                c, uc = model.conditioner.get_unconditional_conditioning(
+                    batch,
+                    batch_uc=batch_uc,
+                    force_uc_zero_embeddings=force_uc_zero_embeddings,
+                )

             for k in c:
                 if not k == "crossattn":
 def do_img2img(
     with torch.no_grad():
         with autocast(device):
             with model.ema_scope():
-                batch, batch_uc = get_batch(
-                    get_unique_embedder_keys_from_conditioner(model.conditioner),
-                    value_dict,
-                    [num_samples],
-                )
-                c, uc = model.conditioner.get_unconditional_conditioning(
-                    batch,
-                    batch_uc=batch_uc,
-                    force_uc_zero_embeddings=force_uc_zero_embeddings,
-                )
+                with ModelOnDevice(model.conditioner, device):
+                    batch, batch_uc = get_batch(
+                        get_unique_embedder_keys_from_conditioner(model.conditioner),
+                        value_dict,
+                        [num_samples],
+                    )
+                    c, uc = model.conditioner.get_unconditional_conditioning(
+                        batch,
+                        batch_uc=batch_uc,
+                        force_uc_zero_embeddings=force_uc_zero_embeddings,
+                    )
                 for k in c:
                     c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc))
                 if skip_encode:
                     z = img
                 else:
-                    z = model.encode_first_stage(img)
+                    with ModelOnDevice(model.first_stage_model, device):
+                        z = model.encode_first_stage(img)
+
                 noise = torch.randn_like(z)
+
                 sigmas = sampler.discretization(sampler.num_steps)
                 sigma = sigmas[0].to(z.device)

From 0c2c5c66a2a3c0aa3945ac33bac36eee50132c1b Mon Sep 17 00:00:00 2001
From: Stephan Auerhahn
Date: Sun, 6 Aug 2023 12:26:01 +0000
Subject: [PATCH 14/56] fix device check

---
 sgm/inference/helpers.py | 58 ++++++++++++++++++++++------------------
 1 file changed, 26 insertions(+), 32 deletions(-)

diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py
index 30dc082f..6f06218c 100644
--- a/sgm/inference/helpers.py
+++ b/sgm/inference/helpers.py
@@ -35,9 +35,9 @@ def __call__(self, image: torch.Tensor):
         if squeeze:
             image = image[None, ...]
         n = image.shape[0]
-        image_np = rearrange(
-            (255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
-        ).numpy()[:, :, :, ::-1]
+        image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[
+            :, :, :, ::-1
+        ]
         # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
         for k in range(image_np.shape[0]):
             image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
             randn = torch.randn(shape).to(device)

             def denoiser(input, sigma, c):
-                return model.denoiser(
-                    model.model, input, sigma, c, **additional_model_inputs
-                )
+                return model.denoiser(model.model, input, sigma, c, **additional_model_inputs)

             with ModelOnDevice(model.denoiser, device):
                 with ModelOnDevice(model.model, device):
                     samples_z = sampler(denoiser, randn, cond=c, uc=uc)
     for key in keys:
         if key == "txt":
             batch["txt"] = (
-                np.repeat([value_dict["prompt"]], repeats=math.prod(N))
-                .reshape(N)
-                .tolist()
+                np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
             )
             batch_uc["txt"] = (
-                np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
-                .reshape(N)
-                .tolist()
+                np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
             )
         elif key == "original_size_as_tuple":
             batch["original_size_as_tuple"] = (
                 .repeat(*N, 1)
             )
         elif key == "crop_coords_top_left":
             batch["crop_coords_top_left"] = (
-                torch.tensor(
-                    [value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
-                )
+                torch.tensor([value_dict["crop_coords_top"], value_dict["crop_coords_left"]])
                 .to(device)
                 .repeat(*N, 1)
             )
         elif key == "aesthetic_score":
             batch["aesthetic_score"] = (
                 torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
             )
             batch_uc["aesthetic_score"] = (
-                torch.tensor([value_dict["negative_aesthetic_score"]])
-                .to(device)
-                .repeat(*N, 1)
+                torch.tensor([value_dict["negative_aesthetic_score"]]).to(device).repeat(*N, 1)
             )

         elif key == "target_size_as_tuple":
 def get_input_image_tensor(image: Image.Image, device="cuda"):
     w, h = image.size
     print(f"loaded input image of size ({w}, {h})")
-    width, height = map(
-        lambda x: x - x % 64, (w, h)
-    )  # resize to integer multiple of 64
+    width, height = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 64
     image = image.resize((width, height))
     image_array = np.array(image.convert("RGB"))
     image_array = image_array[None].transpose(0, 3, 1, 2)
 class ModelOnDevice(object):
-    def __init__(self, model, device):
+    def __init__(
+        self, model: Union[torch.nn.Module, torch.Tensor], device: Union[torch.device, str]
+    ):
         self.model = model
-        self.device = device
-        self.original_device = model.device
+        self.device = torch.device(device)
+        if isinstance(model, torch.Tensor):
+            self.original_device = model.device
+        else:
+            param = next(model.parameters(), None)
+            if param is not None:
+                self.original_device = param.device
+            else:
+                buf = next(model.buffers(), None)
+                if buf is not None:
+                    self.original_device = buf.device
+                else:
+                    # If device could not be found, turn this into a no-op
+                    self.original_device = self.device

     def __enter__(self):
         if self.device != self.original_device:
             self.model.to(self.device)

     def __exit__(self, *args):
         if self.device != self.original_device:
             self.model.to(self.original_device)
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
-
-
-def load_model(model, device):
-    if model.device != device:
-        old_device = model.device
-        model.to(device)
-        return old_device
-    return False
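Since the original device is now inferred from parameters or buffers rather than a `.device` attribute, ModelOnDevice works with plain nn.Modules. A small usage sketch of the lowvram pattern do_sample/do_img2img rely on (the Linear module is a stand-in for e.g. model.conditioner):

    import torch

    from sgm.inference.helpers import ModelOnDevice

    module = torch.nn.Linear(4, 4)  # parked on CPU, like an offloaded submodule

    with ModelOnDevice(module, "cuda" if torch.cuda.is_available() else "cpu"):
        out = module(torch.randn(1, 4, device=next(module.parameters()).device))
    # on exit the module returns to its original device and, when the devices
    # differed and CUDA is available, the CUDA cache is emptied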
torch.cuda.is_available(): torch.cuda.empty_cache() - - -def load_model(model, device): - if model.device != device: - old_device = model.device - model.to(device) - return old_device - return False From 451c76ada1cb64228c477092e42088384f4afbfd Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Sun, 6 Aug 2023 12:26:16 +0000 Subject: [PATCH 15/56] format --- sgm/inference/helpers.py | 34 ++++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py index 6f06218c..b576669f 100644 --- a/sgm/inference/helpers.py +++ b/sgm/inference/helpers.py @@ -35,9 +35,9 @@ def __call__(self, image: torch.Tensor): if squeeze: image = image[None, ...] n = image.shape[0] - image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[ - :, :, :, ::-1 - ] + image_np = rearrange( + (255 * image).detach().cpu(), "n b c h w -> (n b) h w c" + ).numpy()[:, :, :, ::-1] # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255] for k in range(image_np.shape[0]): image_np[k] = self.encoder.encode(image_np[k], "dwtDct") @@ -185,7 +185,9 @@ def do_sample( randn = torch.randn(shape).to(device) def denoiser(input, sigma, c): - return model.denoiser(model.model, input, sigma, c, **additional_model_inputs) + return model.denoiser( + model.model, input, sigma, c, **additional_model_inputs + ) with ModelOnDevice(model.denoiser, device): with ModelOnDevice(model.model, device): @@ -212,10 +214,14 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"): for key in keys: if key == "txt": batch["txt"] = ( - np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist() + np.repeat([value_dict["prompt"]], repeats=math.prod(N)) + .reshape(N) + .tolist() ) batch_uc["txt"] = ( - np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist() + np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)) + .reshape(N) + .tolist() ) elif key == "original_size_as_tuple": batch["original_size_as_tuple"] = ( @@ -225,7 +231,9 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"): ) elif key == "crop_coords_top_left": batch["crop_coords_top_left"] = ( - torch.tensor([value_dict["crop_coords_top"], value_dict["crop_coords_left"]]) + torch.tensor( + [value_dict["crop_coords_top"], value_dict["crop_coords_left"]] + ) .to(device) .repeat(*N, 1) ) @@ -234,7 +242,9 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"): torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1) ) batch_uc["aesthetic_score"] = ( - torch.tensor([value_dict["negative_aesthetic_score"]]).to(device).repeat(*N, 1) + torch.tensor([value_dict["negative_aesthetic_score"]]) + .to(device) + .repeat(*N, 1) ) elif key == "target_size_as_tuple": @@ -255,7 +265,9 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"): def get_input_image_tensor(image: Image.Image, device="cuda"): w, h = image.size print(f"loaded input image of size ({w}, {h})") - width, height = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64 + width, height = map( + lambda x: x - x % 64, (w, h) + ) # resize to integer multiple of 64 image = image.resize((width, height)) image_array = np.array(image.convert("RGB")) image_array = image_array[None].transpose(0, 3, 1, 2) @@ -342,7 +354,9 @@ def denoiser(x, sigma, c): class ModelOnDevice(object): def __init__( - self, model: Union[torch.nn.Module, torch.Tensor], device: 
Union[torch.device, str] + self, + model: Union[torch.nn.Module, torch.Tensor], + device: Union[torch.device, str], ): self.model = model self.device = torch.device(device) From f2fba1dfa22803fdefd86141d448bf85c16c827e Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Sun, 6 Aug 2023 21:08:19 +0000 Subject: [PATCH 16/56] fix noisy latent handling --- sgm/inference/api.py | 52 ++++++++++++++++++++++--------- tests/inference/test_inference.py | 14 ++++----- 2 files changed, 44 insertions(+), 22 deletions(-) diff --git a/sgm/inference/api.py b/sgm/inference/api.py index 77f5667d..1ab892f5 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -226,15 +226,17 @@ def text_to_image( negative_prompt: str = "", samples: int = 1, return_latents: bool = False, - stage2strength=None, + noise_strength=None, ): sampler = get_sampler_config(params) - if stage2strength is not None: - sampler.discretization = Txt2NoisyDiscretizationWrapper( - sampler.discretization, - strength=stage2strength, - original_steps=params.steps, - ) + + sampler.discretization = self.wrap_discretization( + sampler.discretization, + image_strength=None, + noise_strength=noise_strength, + steps=params.steps, + ) + value_dict = asdict(params) value_dict["prompt"] = prompt value_dict["negative_prompt"] = negative_prompt @@ -262,11 +264,15 @@ def image_to_image( negative_prompt: str = "", samples: int = 1, return_latents: bool = False, + noise_strength=None, ): sampler = get_sampler_config(params) sampler.discretization = self.wrap_discretization( - sampler.discretization, strength=params.img2img_strength + sampler.discretization, + image_strength=params.img2img_strength, + noise_strength=noise_strength, + steps=params.steps, ) height, width = image.shape[2], image.shape[3] @@ -286,12 +292,29 @@ def image_to_image( filter=None, ) - def wrap_discretization(self, discretization, strength=1.0): + def wrap_discretization( + self, discretization, image_strength=None, noise_strength=None, steps=None + ): + if isinstance(discretization, Img2ImgDiscretizationWrapper) or isinstance( + discretization, Txt2NoisyDiscretizationWrapper + ): + return discretization # Already wrapped + if image_strength is not None and image_strength < 1.0 and image_strength > 0.0: + discretization = Img2ImgDiscretizationWrapper( + discretization, strength=image_strength + ) + if ( - not isinstance(discretization, Img2ImgDiscretizationWrapper) - and strength < 1.0 + noise_strength is not None + and noise_strength < 1.0 + and noise_strength > 0.0 + and steps is not None ): - return Img2ImgDiscretizationWrapper(discretization, strength=strength) + discretization = Txt2NoisyDiscretizationWrapper( + discretization, + strength=noise_strength, + original_steps=steps, + ) return discretization def refiner( @@ -300,7 +323,7 @@ def refiner( prompt: str, negative_prompt: str = "", params: SamplingParams = SamplingParams( - sampler=Sampler.EULER_EDM, steps=40, img2img_strength=0.2 + sampler=Sampler.EULER_EDM, steps=40, img2img_strength=0.15 ), samples: int = 1, return_latents: bool = False, @@ -320,7 +343,7 @@ def refiner( } sampler.discretization = self.wrap_discretization( - sampler.discretization, strength=params.img2img_strength + sampler.discretization, image_strength=params.img2img_strength ) return do_img2img( @@ -332,6 +355,7 @@ def refiner( skip_encode=True, return_latents=return_latents, filter=None, + add_noise=False, ) diff --git a/tests/inference/test_inference.py b/tests/inference/test_inference.py index ae6f3550..617e4088 100644 --- 
a/tests/inference/test_inference.py +++ b/tests/inference/test_inference.py @@ -68,9 +68,7 @@ def test_img2img(self, pipeline: SamplingPipeline, sampler_enum): assert output is not None @pytest.mark.parametrize("sampler_enum", Sampler) - @pytest.mark.parametrize( - "use_init_image", [True, False], ids=["img2img", "txt2img"] - ) + @pytest.mark.parametrize("use_init_image", [True, False], ids=["img2img", "txt2img"]) def test_sdxl_with_refiner( self, sdxl_pipelines: Tuple[SamplingPipeline, SamplingPipeline], @@ -81,13 +79,12 @@ def test_sdxl_with_refiner( if use_init_image: output = base_pipeline.image_to_image( params=SamplingParams(sampler=sampler_enum.value, steps=10), - image=self.create_init_image( - base_pipeline.specs.height, base_pipeline.specs.width - ), + image=self.create_init_image(base_pipeline.specs.height, base_pipeline.specs.width), prompt="A professional photograph of an astronaut riding a pig", negative_prompt="", samples=1, return_latents=True, + noise_strength=0.15, ) else: output = base_pipeline.text_to_image( @@ -96,16 +93,17 @@ def test_sdxl_with_refiner( negative_prompt="", samples=1, return_latents=True, + noise_strength=0.15, ) assert isinstance(output, (tuple, list)) samples, samples_z = output assert samples is not None assert samples_z is not None - refiner_pipeline.refiner( + refiner_pipeline.refiner( image=samples_z, prompt="A professional photograph of an astronaut riding a pig", - params=SamplingParams(sampler=sampler_enum.value, steps=40, img2img_strength=0.20), + params=SamplingParams(sampler=sampler_enum.value, steps=40, img2img_strength=0.15), negative_prompt="", samples=1, ) From 8f8757b4ff16188434ae23fd1c68e5ebfdd76d72 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Sun, 6 Aug 2023 21:09:09 +0000 Subject: [PATCH 17/56] version bump for changes to inference helpers --- sgm/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sgm/__init__.py b/sgm/__init__.py index 24bc84af..2a058931 100644 --- a/sgm/__init__.py +++ b/sgm/__init__.py @@ -1,4 +1,4 @@ from .models import AutoencodingEngine, DiffusionEngine from .util import get_configs_path, instantiate_from_config -__version__ = "0.1.0" +__version__ = "0.1.1" From 76ca428422e99e744e2308e357c3aff6ed5ae2b2 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Sun, 6 Aug 2023 21:39:18 +0000 Subject: [PATCH 18/56] fix path resolution bug --- sgm/inference/api.py | 32 +++++++++----------------------- 1 file changed, 9 insertions(+), 23 deletions(-) diff --git a/sgm/inference/api.py b/sgm/inference/api.py index 1ab892f5..fd89558b 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -181,30 +181,20 @@ def __init__( model_path = pathlib.Path(__file__).parent.parent.resolve() / "checkpoints" if not os.path.exists(model_path): # This supports development installs where checkpoints is root level of the repo - model_path = ( - pathlib.Path(__file__).parent.parent.parent.resolve() - / "checkpoints" - ) + model_path = pathlib.Path(__file__).parent.parent.parent.resolve() / "checkpoints" if config_path is None: - config_path = ( - pathlib.Path(__file__).parent.parent.resolve() / "configs/inference" - ) + config_path = pathlib.Path(__file__).parent.parent.resolve() / "configs/inference" if not os.path.exists(config_path): # This supports development installs where configs is root level of the repo config_path = ( - pathlib.Path(__file__).parent.parent.parent.resolve() - / "configs/inference" + pathlib.Path(__file__).parent.parent.parent.resolve() / "configs/inference" ) - 
self.config = str(config_path / self.specs.config) - self.ckpt = str(model_path / self.specs.ckpt) + self.config = str(pathlib.Path(config_path) / self.specs.config) + self.ckpt = str(pathlib.Path(model_path) / self.specs.ckpt) if not os.path.exists(self.config): - raise ValueError( - f"Config {self.config} not found, check model spec or config_path" - ) + raise ValueError(f"Config {self.config} not found, check model spec or config_path") if not os.path.exists(self.ckpt): - raise ValueError( - f"Checkpoint {self.ckpt} not found, check model spec or config_path" - ) + raise ValueError(f"Checkpoint {self.ckpt} not found, check model spec or config_path") self.device = device self.model = self._load_model(device=device, use_fp16=use_fp16) @@ -300,9 +290,7 @@ def wrap_discretization( ): return discretization # Already wrapped if image_strength is not None and image_strength < 1.0 and image_strength > 0.0: - discretization = Img2ImgDiscretizationWrapper( - discretization, strength=image_strength - ) + discretization = Img2ImgDiscretizationWrapper(discretization, strength=image_strength) if ( noise_strength is not None @@ -361,9 +349,7 @@ def refiner( def get_guider_config(params: SamplingParams): if params.guider == Guider.IDENTITY: - guider_config = { - "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider" - } + guider_config = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"} elif params.guider == Guider.VANILLA: scale = params.scale From ced97f0e84939a6b8ca751789a150a4d72683a9e Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Sun, 6 Aug 2023 23:24:14 +0000 Subject: [PATCH 19/56] update defaults --- sgm/inference/api.py | 34 ++++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/sgm/inference/api.py b/sgm/inference/api.py index fd89558b..89b73700 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -62,9 +62,9 @@ class SamplingParams: discretization: Discretization = Discretization.LEGACY_DDPM guider: Guider = Guider.VANILLA thresholder: Thresholder = Thresholder.NONE - scale: float = 6.0 - aesthetic_score: float = 5.0 - negative_aesthetic_score: float = 5.0 + scale: float = 5.0 + aesthetic_score: float = 6.0 + negative_aesthetic_score: float = 2.5 img2img_strength: float = 1.0 orig_width: int = 1024 orig_height: int = 1024 @@ -181,20 +181,30 @@ def __init__( model_path = pathlib.Path(__file__).parent.parent.resolve() / "checkpoints" if not os.path.exists(model_path): # This supports development installs where checkpoints is root level of the repo - model_path = pathlib.Path(__file__).parent.parent.parent.resolve() / "checkpoints" + model_path = ( + pathlib.Path(__file__).parent.parent.parent.resolve() + / "checkpoints" + ) if config_path is None: - config_path = pathlib.Path(__file__).parent.parent.resolve() / "configs/inference" + config_path = ( + pathlib.Path(__file__).parent.parent.resolve() / "configs/inference" + ) if not os.path.exists(config_path): # This supports development installs where configs is root level of the repo config_path = ( - pathlib.Path(__file__).parent.parent.parent.resolve() / "configs/inference" + pathlib.Path(__file__).parent.parent.parent.resolve() + / "configs/inference" ) self.config = str(pathlib.Path(config_path) / self.specs.config) self.ckpt = str(pathlib.Path(model_path) / self.specs.ckpt) if not os.path.exists(self.config): - raise ValueError(f"Config {self.config} not found, check model spec or config_path") + raise ValueError( + f"Config {self.config} not found, 
check model spec or config_path" + ) if not os.path.exists(self.ckpt): - raise ValueError(f"Checkpoint {self.ckpt} not found, check model spec or config_path") + raise ValueError( + f"Checkpoint {self.ckpt} not found, check model spec or config_path" + ) self.device = device self.model = self._load_model(device=device, use_fp16=use_fp16) @@ -290,7 +300,9 @@ def wrap_discretization( ): return discretization # Already wrapped if image_strength is not None and image_strength < 1.0 and image_strength > 0.0: - discretization = Img2ImgDiscretizationWrapper(discretization, strength=image_strength) + discretization = Img2ImgDiscretizationWrapper( + discretization, strength=image_strength + ) if ( noise_strength is not None @@ -349,7 +361,9 @@ def refiner( def get_guider_config(params: SamplingParams): if params.guider == Guider.IDENTITY: - guider_config = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"} + guider_config = { + "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider" + } elif params.guider == Guider.VANILLA: scale = params.scale From 6c18c8443a964fe1cbbf6f434088a73fdab24e3b Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Sun, 6 Aug 2023 23:46:20 +0000 Subject: [PATCH 20/56] rename ModelOnDevice to SwapToDevice --- sgm/inference/helpers.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py index b576669f..a0c9e221 100644 --- a/sgm/inference/helpers.py +++ b/sgm/inference/helpers.py @@ -152,7 +152,7 @@ def do_sample( with autocast(device) as precision_scope: with model.ema_scope(): num_samples = [num_samples] - with ModelOnDevice(model.conditioner, device): + with SwapToDevice(model.conditioner, device): batch, batch_uc = get_batch( get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, @@ -189,11 +189,11 @@ def denoiser(input, sigma, c): model.model, input, sigma, c, **additional_model_inputs ) - with ModelOnDevice(model.denoiser, device): - with ModelOnDevice(model.model, device): + with SwapToDevice(model.denoiser, device): + with SwapToDevice(model.model, device): samples_z = sampler(denoiser, randn, cond=c, uc=uc) - with ModelOnDevice(model.first_stage_model, device): + with SwapToDevice(model.first_stage_model, device): samples_x = model.decode_first_stage(samples_z) samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) @@ -293,7 +293,7 @@ def do_img2img( with torch.no_grad(): with autocast(device): with model.ema_scope(): - with ModelOnDevice(model.conditioner, device): + with SwapToDevice(model.conditioner, device): batch, batch_uc = get_batch( get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, @@ -313,7 +313,7 @@ def do_img2img( if skip_encode: z = img else: - with ModelOnDevice(model.first_stage_model, device): + with SwapToDevice(model.first_stage_model, device): z = model.encode_first_stage(img) noise = torch.randn_like(z) @@ -336,11 +336,11 @@ def do_img2img( def denoiser(x, sigma, c): return model.denoiser(model.model, x, sigma, c) - with ModelOnDevice(model.denoiser, device): - with ModelOnDevice(model.model, device): + with SwapToDevice(model.denoiser, device): + with SwapToDevice(model.model, device): samples_z = sampler(denoiser, noised_z, cond=c, uc=uc) - with ModelOnDevice(model.first_stage_model, device): + with SwapToDevice(model.first_stage_model, device): samples_x = model.decode_first_stage(samples_z) samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) @@ -352,7 +352,7 @@ def 
denoiser(x, sigma, c): return samples -class ModelOnDevice(object): +class SwapToDevice(object): def __init__( self, model: Union[torch.nn.Module, torch.Tensor], From 49fe53c1652c83d90eac77d420db12d35e3b773a Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Sun, 6 Aug 2023 19:21:17 -0700 Subject: [PATCH 21/56] use env var for sgm checkpoints path --- .github/workflows/test-inference.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-inference.yml b/.github/workflows/test-inference.yml index 88b879cc..38aed1e4 100644 --- a/.github/workflows/test-inference.yml +++ b/.github/workflows/test-inference.yml @@ -15,7 +15,7 @@ jobs: steps: - uses: actions/checkout@v3 - name: "Symlink checkpoints" - run: ln -s ${{vars.SGM_CHECKPOINTS_PATH}} checkpoints + run: ln -s ${{env.SGM_CHECKPOINTS}} checkpoints - name: "Setup python" uses: actions/setup-python@v4 with: From 7e7fee3f0fb337f549f9955127018c08046f1400 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Sun, 6 Aug 2023 19:22:59 -0700 Subject: [PATCH 22/56] system env var --- .github/workflows/test-inference.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-inference.yml b/.github/workflows/test-inference.yml index 38aed1e4..6cb6a837 100644 --- a/.github/workflows/test-inference.yml +++ b/.github/workflows/test-inference.yml @@ -15,7 +15,7 @@ jobs: steps: - uses: actions/checkout@v3 - name: "Symlink checkpoints" - run: ln -s ${{env.SGM_CHECKPOINTS}} checkpoints + run: ln -s $SGM_CHECKPOINTS checkpoints - name: "Setup python" uses: actions/setup-python@v4 with: From c4b7baf8963ffb01e0e8ef708175b13517175250 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Sun, 6 Aug 2023 19:58:52 -0700 Subject: [PATCH 23/56] Streamlit refactor (#105) * initial streamlit refactoring pass * cleanup and fixes * fix refiner strength * Modify params correctly * fix exception --- scripts/demo/sampling.py | 295 ++++++++++++---------------- scripts/demo/streamlit_helpers.py | 315 +++++++++++------------------- sgm/inference/api.py | 22 ++- 3 files changed, 242 insertions(+), 390 deletions(-) diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py index 29cb3054..d6387177 100644 --- a/scripts/demo/sampling.py +++ b/scripts/demo/sampling.py @@ -1,8 +1,13 @@ +from dataclasses import asdict from pytorch_lightning import seed_everything +from sgm.inference.api import ( + SamplingParams, + ModelArchitecture, + SamplingPipeline, + model_specs, +) from sgm.inference.helpers import ( - do_img2img, - do_sample, get_unique_embedder_keys_from_conditioner, perform_save_locally, ) @@ -39,63 +44,6 @@ "3.0": (1728, 576), } -VERSION2SPECS = { - "SDXL-base-1.0": { - "H": 1024, - "W": 1024, - "C": 4, - "f": 8, - "is_legacy": False, - "config": "configs/inference/sd_xl_base.yaml", - "ckpt": "checkpoints/sd_xl_base_1.0.safetensors", - }, - "SDXL-base-0.9": { - "H": 1024, - "W": 1024, - "C": 4, - "f": 8, - "is_legacy": False, - "config": "configs/inference/sd_xl_base.yaml", - "ckpt": "checkpoints/sd_xl_base_0.9.safetensors", - }, - "SD-2.1": { - "H": 512, - "W": 512, - "C": 4, - "f": 8, - "is_legacy": True, - "config": "configs/inference/sd_2_1.yaml", - "ckpt": "checkpoints/v2-1_512-ema-pruned.safetensors", - }, - "SD-2.1-768": { - "H": 768, - "W": 768, - "C": 4, - "f": 8, - "is_legacy": True, - "config": "configs/inference/sd_2_1_768.yaml", - "ckpt": "checkpoints/v2-1_768-ema-pruned.safetensors", - }, - "SDXL-refiner-0.9": { - "H": 1024, - "W": 1024, - "C": 4, - "f": 8, - "is_legacy": True, - 
"config": "configs/inference/sd_xl_refiner.yaml", - "ckpt": "checkpoints/sd_xl_refiner_0.9.safetensors", - }, - "SDXL-refiner-1.0": { - "H": 1024, - "W": 1024, - "C": 4, - "f": 8, - "is_legacy": True, - "config": "configs/inference/sd_xl_refiner.yaml", - "ckpt": "checkpoints/sd_xl_refiner_1.0.safetensors", - }, -} - def load_img(display=True, key=None, device="cuda"): image = get_interactive_image(key=key) @@ -117,52 +65,48 @@ def load_img(display=True, key=None, device="cuda"): def run_txt2img( state, - version, - version_dict, - is_legacy=False, + version: str, + prompt: str, + negative_prompt: str, return_latents=False, - filter=None, stage2strength=None, ): - if version.startswith("SDXL-base"): - W, H = st.selectbox("Resolution:", list(SD_XL_BASE_RATIOS.values()), 10) + spec: SamplingSpec = state.get("spec") + model: SamplingPipeline = state.get("model") + params: SamplingParams = state.get("params") + if version.startswith("stable-diffusion-xl") and version.endswith("-base"): + params.width, params.height = st.selectbox( + "Resolution:", list(SD_XL_BASE_RATIOS.values()), 10 + ) else: - H = st.number_input("H", value=version_dict["H"], min_value=64, max_value=2048) - W = st.number_input("W", value=version_dict["W"], min_value=64, max_value=2048) - C = version_dict["C"] - F = version_dict["f"] - - init_dict = { - "orig_width": W, - "orig_height": H, - "target_width": W, - "target_height": H, - } - value_dict = init_embedder_options( - get_unique_embedder_keys_from_conditioner(state["model"].conditioner), - init_dict, + params.height = int( + st.number_input("H", value=spec.height, min_value=64, max_value=2048) + ) + params.width = int( + st.number_input("W", value=spec.width, min_value=64, max_value=2048) + ) + + params = init_embedder_options( + get_unique_embedder_keys_from_conditioner(model.model.conditioner), + params=params, prompt=prompt, negative_prompt=negative_prompt, ) - sampler, num_rows, num_cols = init_sampling(stage2strength=stage2strength) + params, num_rows, num_cols = init_sampling(params=params) num_samples = num_rows * num_cols if st.button("Sample"): st.write(f"**Model I:** {version}") outputs = st.empty() st.text("Sampling") - out = do_sample( - state["model"], - sampler, - value_dict, - num_samples, - H, - W, - C, - F, - force_uc_zero_embeddings=["txt"] if not is_legacy else [], + out = model.text_to_image( + params=params, + prompt=prompt, + negative_prompt=negative_prompt, + samples=int(num_samples), return_latents=return_latents, - filter=filter, + noise_strength=stage2strength, + filter=state.get("filter"), ) show_samples(out, outputs) @@ -172,51 +116,45 @@ def run_txt2img( def run_img2img( state, - version_dict, - is_legacy=False, + prompt: str, + negative_prompt: str, return_latents=False, - filter=None, stage2strength=None, ): + model: SamplingPipeline = state.get("model") + params: SamplingParams = state.get("params") + img = load_img() if img is None: return None - H, W = img.shape[2], img.shape[3] - - init_dict = { - "orig_width": W, - "orig_height": H, - "target_width": W, - "target_height": H, - } - value_dict = init_embedder_options( - get_unique_embedder_keys_from_conditioner(state["model"].conditioner), - init_dict, + params.height, params.width = img.shape[2], img.shape[3] + + params = init_embedder_options( + get_unique_embedder_keys_from_conditioner(model.model.conditioner), + params=params, prompt=prompt, negative_prompt=negative_prompt, ) - strength = st.number_input( + params.img2img_strength = st.number_input( "**Img2Img Strength**", 
value=0.75, min_value=0.0, max_value=1.0 ) - sampler, num_rows, num_cols = init_sampling( - img2img_strength=strength, - stage2strength=stage2strength, - ) + params, num_rows, num_cols = init_sampling(params=params) num_samples = num_rows * num_cols if st.button("Sample"): outputs = st.empty() st.text("Sampling") - out = do_img2img( - repeat(img, "1 ... -> n ...", n=num_samples), - state["model"], - sampler, - value_dict, - num_samples, - force_uc_zero_embeddings=["txt"] if not is_legacy else [], + out = model.image_to_image( + image=repeat(img, "1 ... -> n ...", n=num_samples), + params=params, + prompt=prompt, + negative_prompt=negative_prompt, + samples=int(num_samples), return_latents=return_latents, - filter=filter, + noise_strength=stage2strength, + filter=state.get("filter"), ) + show_samples(out, outputs) return out @@ -224,39 +162,29 @@ def run_img2img( def apply_refiner( input, state, - sampler, - num_samples, - prompt, - negative_prompt, - filter=None, + num_samples: int, + prompt: str, + negative_prompt: str, finish_denoising=False, ): - init_dict = { - "orig_width": input.shape[3] * 8, - "orig_height": input.shape[2] * 8, - "target_width": input.shape[3] * 8, - "target_height": input.shape[2] * 8, - } + model: SamplingPipeline = state.get("model") + params: SamplingParams = state.get("params") - value_dict = init_dict - value_dict["prompt"] = prompt - value_dict["negative_prompt"] = negative_prompt - - value_dict["crop_coords_top"] = 0 - value_dict["crop_coords_left"] = 0 - - value_dict["aesthetic_score"] = 6.0 - value_dict["negative_aesthetic_score"] = 2.5 + params.orig_width = input.shape[3] * 8 + params.orig_height = input.shape[2] * 8 + params.width = input.shape[3] * 8 + params.height = input.shape[2] * 8 st.warning(f"refiner input shape: {input.shape}") - samples = do_img2img( - input, - state["model"], - sampler, - value_dict, - num_samples, - skip_encode=True, - filter=filter, + + samples = model.refiner( + image=input, + params=params, + prompt=prompt, + negative_prompt=negative_prompt, + samples=num_samples, + return_latents=False, + filter=state.get("filter"), add_noise=not finish_denoising, ) @@ -265,28 +193,34 @@ def apply_refiner( if __name__ == "__main__": st.title("Stable Diffusion") - version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0) - version_dict = VERSION2SPECS[version] + version = st.selectbox( + "Model Version", + [member.value for member in ModelArchitecture], + 0, + ) + version_enum = ModelArchitecture(version) + specs = model_specs[version_enum] mode = st.radio("Mode", ("txt2img", "img2img"), 0) st.write("__________________________") set_lowvram_mode(st.checkbox("Low vram mode", True)) - if version.startswith("SDXL-base"): + if str(version).startswith("stable-diffusion-xl"): add_pipeline = st.checkbox("Load SDXL-refiner?", False) st.write("__________________________") else: add_pipeline = False - seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9)) + seed = int( + st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9)) + ) seed_everything(seed) - save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version)) - - state = init_st(version_dict, load_filter=True) + save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, str(version))) + state = init_st(model_specs[version_enum], load_filter=True) model = state["model"] - is_legacy = version_dict["is_legacy"] + is_legacy = specs.is_legacy prompt = st.text_input( "prompt", @@ -302,46 +236,59 @@ def apply_refiner( 
if add_pipeline: st.write("__________________________") - version2 = st.selectbox("Refiner:", ["SDXL-refiner-1.0", "SDXL-refiner-0.9"]) + version2 = ModelArchitecture( + st.selectbox( + "Refiner:", + [ + ModelArchitecture.SDXL_V1_REFINER.value, + ModelArchitecture.SDXL_V0_9_REFINER.value, + ], + ) + ) st.warning( f"Running with {version2} as the second stage model. Make sure to provide (V)RAM :) " ) st.write("**Refiner Options:**") - version_dict2 = VERSION2SPECS[version2] - state2 = init_st(version_dict2, load_filter=False) + specs2 = model_specs[version2] + state2 = init_st(specs2, load_filter=False) + params2 = state2["params"] - stage2strength = st.number_input( + params2.img2img_strength = st.number_input( "**Refinement strength**", value=0.15, min_value=0.0, max_value=1.0 ) - sampler2, *_ = init_sampling( + params2, *_ = init_sampling( key=2, - img2img_strength=stage2strength, + params=state2["params"], specify_num_samples=False, ) st.write("__________________________") finish_denoising = st.checkbox("Finish denoising with refiner.", True) - if not finish_denoising: + if finish_denoising: + stage2strength = params2.img2img_strength + else: stage2strength = None + else: + state2 = None + params2 = None + stage2strength = None if mode == "txt2img": out = run_txt2img( - state, - version, - version_dict, - is_legacy=is_legacy, + state=state, + version=str(version), + prompt=prompt, + negative_prompt=negative_prompt, return_latents=add_pipeline, - filter=state.get("filter"), stage2strength=stage2strength, ) elif mode == "img2img": out = run_img2img( - state, - version_dict, - is_legacy=is_legacy, + state=state, + prompt=prompt, + negative_prompt=negative_prompt, return_latents=add_pipeline, - filter=state.get("filter"), stage2strength=stage2strength, ) else: @@ -356,13 +303,11 @@ def apply_refiner( outputs = st.empty() st.write("**Running Refinement Stage**") samples = apply_refiner( - samples_z, - state2, - sampler2, - samples_z.shape[0], + input=samples_z, + state=state2, + num_samples=samples_z.shape[0], prompt=prompt, negative_prompt=negative_prompt if is_legacy else "", - filter=state.get("filter"), finish_denoising=finish_denoising, ) show_samples(samples, outputs) diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index d3dd7d71..9a90f3be 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -4,43 +4,46 @@ import streamlit as st import torch from einops import rearrange, repeat -from omegaconf import OmegaConf from PIL import Image from torchvision import transforms +from typing import Optional, Tuple from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering -from sgm.modules.diffusionmodules.sampling import ( - DPMPP2MSampler, - DPMPP2SAncestralSampler, - EulerAncestralSampler, - EulerEDMSampler, - HeunEDMSampler, - LinearMultistepSampler, + +from sgm.inference.api import ( + Discretization, + Guider, + Sampler, + SamplingParams, + SamplingSpec, + SamplingPipeline, + Thresholder, ) from sgm.inference.helpers import ( - Img2ImgDiscretizationWrapper, - Txt2NoisyDiscretizationWrapper, embed_watermark, ) -from sgm.util import load_model_from_config @st.cache_resource() -def init_st(version_dict, load_ckpt=True, load_filter=True): +def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True): + global lowvram_mode state = dict() if not "model" in state: - config = version_dict["config"] - ckpt = version_dict["ckpt"] + config = spec.config + ckpt = spec.ckpt - config = 
OmegaConf.load(config) - model = load_model_from_config( - config, ckpt if load_ckpt else None, freeze=False + pipeline = SamplingPipeline( + model_spec=spec, + use_fp16=lowvram_mode, + device="cpu" if lowvram_mode else "cuda", ) - state["model"] = model + state["spec"] = spec + state["model"] = pipeline state["ckpt"] = ckpt if load_ckpt else None state["config"] = config + state["params"] = SamplingParams() if load_filter: state["filter"] = DeepFloydDataFiltering(verbose=False) return state @@ -54,23 +57,13 @@ def set_lowvram_mode(mode): lowvram_mode = mode -def initial_model_load(model): - global lowvram_mode - if lowvram_mode: - model.model.half() - else: - model.cuda() - return model - - def get_unique_embedder_keys_from_conditioner(conditioner): return list(set([x.input_key for x in conditioner.embedders])) -def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None): - # Hardcoded demo settings; might undergo some changes in the future - - value_dict = {} +def init_embedder_options( + keys, params: SamplingParams, prompt=None, negative_prompt=None +) -> SamplingParams: for key in keys: if key == "txt": if prompt is None: @@ -80,40 +73,32 @@ def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None): if negative_prompt is None: negative_prompt = st.text_input("Negative prompt", "") - value_dict["prompt"] = prompt - value_dict["negative_prompt"] = negative_prompt - if key == "original_size_as_tuple": orig_width = st.number_input( "orig_width", - value=init_dict["orig_width"], + value=params.orig_width, min_value=16, ) orig_height = st.number_input( "orig_height", - value=init_dict["orig_height"], + value=params.orig_height, min_value=16, ) - value_dict["orig_width"] = orig_width - value_dict["orig_height"] = orig_height + params.orig_width = int(orig_width) + params.orig_height = int(orig_height) if key == "crop_coords_top_left": - crop_coord_top = st.number_input("crop_coords_top", value=0, min_value=0) - crop_coord_left = st.number_input("crop_coords_left", value=0, min_value=0) - - value_dict["crop_coords_top"] = crop_coord_top - value_dict["crop_coords_left"] = crop_coord_left - - if key == "aesthetic_score": - value_dict["aesthetic_score"] = 6.0 - value_dict["negative_aesthetic_score"] = 2.5 - - if key == "target_size_as_tuple": - value_dict["target_width"] = init_dict["target_width"] - value_dict["target_height"] = init_dict["target_height"] + crop_coord_top = st.number_input( + "crop_coords_top", value=params.crop_coords_top, min_value=0 + ) + crop_coord_left = st.number_input( + "crop_coords_left", value=params.crop_coords_left, min_value=0 + ) - return value_dict + params.crop_coords_top = int(crop_coord_top) + params.crop_coords_left = int(crop_coord_left) + return params def perform_save_locally(save_path, samples): @@ -146,24 +131,18 @@ def show_samples(samples, outputs): outputs.image(grid.cpu().numpy()) -def get_guider(key): - guider = st.sidebar.selectbox( - f"Discretization #{key}", - [ - "VanillaCFG", - "IdentityGuider", - ], +def get_guider(key, params: SamplingParams) -> SamplingParams: + params.guider = Guider( + st.sidebar.selectbox( + f"Discretization #{key}", [member.value for member in Guider] + ) ) - if guider == "IdentityGuider": - guider_config = { - "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider" - } - elif guider == "VanillaCFG": + if params.guider == Guider.VANILLA: scale = st.number_input( - f"cfg-scale #{key}", value=5.0, min_value=0.0, max_value=100.0 + f"cfg-scale #{key}", value=params.scale, 
min_value=0.0, max_value=100.0 ) - + params.scale = scale thresholder = st.sidebar.selectbox( f"Thresholder #{key}", [ @@ -172,173 +151,97 @@ def get_guider(key): ) if thresholder == "None": - dyn_thresh_config = { - "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding" - } + params.thresholder = Thresholder.NONE else: raise NotImplementedError - - guider_config = { - "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG", - "params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config}, - } - else: - raise NotImplementedError - return guider_config + return params def init_sampling( key=1, - img2img_strength=1.0, + params: SamplingParams = SamplingParams(), specify_num_samples=True, - stage2strength=None, -): +) -> Tuple[SamplingParams, int, int]: + params = SamplingParams(img2img_strength=params.img2img_strength) + num_rows, num_cols = 1, 1 if specify_num_samples: num_cols = st.number_input( f"num cols #{key}", value=2, min_value=1, max_value=10 ) - steps = st.sidebar.number_input( - f"steps #{key}", value=40, min_value=1, max_value=1000 + params.steps = int( + st.sidebar.number_input( + f"steps #{key}", value=params.steps, min_value=1, max_value=1000 + ) ) - sampler = st.sidebar.selectbox( - f"Sampler #{key}", - [ - "EulerEDMSampler", - "HeunEDMSampler", - "EulerAncestralSampler", - "DPMPP2SAncestralSampler", - "DPMPP2MSampler", - "LinearMultistepSampler", - ], - 0, + + params.sampler = Sampler( + st.sidebar.selectbox( + f"Sampler #{key}", + [member.value for member in Sampler], + 0, + ) ) - discretization = st.sidebar.selectbox( - f"Discretization #{key}", - [ - "LegacyDDPMDiscretization", - "EDMDiscretization", - ], + params.discretization = Discretization( + st.sidebar.selectbox( + f"Discretization #{key}", + [member.value for member in Discretization], + ) ) - discretization_config = get_discretization(discretization, key=key) + params = get_discretization(params, key=key) + + params = get_guider(key=key, params=params) + + params = get_sampler(params, key=key) + return params, num_rows, num_cols + + +def get_discretization(params: SamplingParams, key=1) -> SamplingParams: + if params.discretization == Discretization.EDM: + params.sigma_min = st.number_input( + f"sigma_min #{key}", value=params.sigma_min + ) # 0.0292 + params.sigma_max = st.number_input( + f"sigma_max #{key}", value=params.sigma_max + ) # 14.6146 + params.rho = st.number_input(f"rho #{key}", value=params.rho) + return params - guider_config = get_guider(key=key) - sampler = get_sampler(sampler, steps, discretization_config, guider_config, key=key) - if img2img_strength < 1.0: - st.warning( - f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper" +def get_sampler(params: SamplingParams, key=1) -> SamplingParams: + if params.sampler == Sampler.EULER_EDM or params.sampler == Sampler.HEUN_EDM: + params.s_churn = st.sidebar.number_input( + f"s_churn #{key}", value=params.s_churn, min_value=0.0 ) - sampler.discretization = Img2ImgDiscretizationWrapper( - sampler.discretization, strength=img2img_strength + params.s_tmin = st.sidebar.number_input( + f"s_tmin #{key}", value=params.s_tmin, min_value=0.0 ) - if stage2strength is not None: - sampler.discretization = Txt2NoisyDiscretizationWrapper( - sampler.discretization, strength=stage2strength, original_steps=steps + params.s_tmax = st.sidebar.number_input( + f"s_tmax #{key}", value=params.s_tmax, min_value=0.0 ) - return sampler, num_rows, num_cols - - -def get_discretization(discretization, key=1): - if discretization == 
"LegacyDDPMDiscretization": - discretization_config = { - "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization", - } - elif discretization == "EDMDiscretization": - sigma_min = st.number_input(f"sigma_min #{key}", value=0.03) # 0.0292 - sigma_max = st.number_input(f"sigma_max #{key}", value=14.61) # 14.6146 - rho = st.number_input(f"rho #{key}", value=3.0) - discretization_config = { - "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization", - "params": { - "sigma_min": sigma_min, - "sigma_max": sigma_max, - "rho": rho, - }, - } - - return discretization_config - - -def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1): - if sampler_name == "EulerEDMSampler" or sampler_name == "HeunEDMSampler": - s_churn = st.sidebar.number_input(f"s_churn #{key}", value=0.0, min_value=0.0) - s_tmin = st.sidebar.number_input(f"s_tmin #{key}", value=0.0, min_value=0.0) - s_tmax = st.sidebar.number_input(f"s_tmax #{key}", value=999.0, min_value=0.0) - s_noise = st.sidebar.number_input(f"s_noise #{key}", value=1.0, min_value=0.0) - - if sampler_name == "EulerEDMSampler": - sampler = EulerEDMSampler( - num_steps=steps, - discretization_config=discretization_config, - guider_config=guider_config, - s_churn=s_churn, - s_tmin=s_tmin, - s_tmax=s_tmax, - s_noise=s_noise, - verbose=True, - ) - elif sampler_name == "HeunEDMSampler": - sampler = HeunEDMSampler( - num_steps=steps, - discretization_config=discretization_config, - guider_config=guider_config, - s_churn=s_churn, - s_tmin=s_tmin, - s_tmax=s_tmax, - s_noise=s_noise, - verbose=True, - ) + params.s_noise = st.sidebar.number_input( + f"s_noise #{key}", value=params.s_noise, min_value=0.0 + ) + elif ( - sampler_name == "EulerAncestralSampler" - or sampler_name == "DPMPP2SAncestralSampler" + params.sampler == Sampler.EULER_ANCESTRAL + or params.sampler == Sampler.DPMPP2S_ANCESTRAL ): - s_noise = st.sidebar.number_input("s_noise", value=1.0, min_value=0.0) - eta = st.sidebar.number_input("eta", value=1.0, min_value=0.0) - - if sampler_name == "EulerAncestralSampler": - sampler = EulerAncestralSampler( - num_steps=steps, - discretization_config=discretization_config, - guider_config=guider_config, - eta=eta, - s_noise=s_noise, - verbose=True, - ) - elif sampler_name == "DPMPP2SAncestralSampler": - sampler = DPMPP2SAncestralSampler( - num_steps=steps, - discretization_config=discretization_config, - guider_config=guider_config, - eta=eta, - s_noise=s_noise, - verbose=True, - ) - elif sampler_name == "DPMPP2MSampler": - sampler = DPMPP2MSampler( - num_steps=steps, - discretization_config=discretization_config, - guider_config=guider_config, - verbose=True, - ) - elif sampler_name == "LinearMultistepSampler": - order = st.sidebar.number_input("order", value=4, min_value=1) - sampler = LinearMultistepSampler( - num_steps=steps, - discretization_config=discretization_config, - guider_config=guider_config, - order=order, - verbose=True, + params.s_noise = st.sidebar.number_input( + "s_noise", value=params.s_noise, min_value=0.0 ) - else: - raise ValueError(f"unknown sampler {sampler_name}!") + params.eta = st.sidebar.number_input("eta", value=params.eta, min_value=0.0) - return sampler + elif params.sampler == Sampler.LINEAR_MULTISTEP: + params.order = int( + st.sidebar.number_input("order", value=params.order, min_value=1) + ) + return params -def get_interactive_image(key=None) -> Image.Image: +def get_interactive_image(key=None) -> Optional[Image.Image]: image = st.file_uploader("Input", type=["jpg", 
"JPEG", "png"], key=key) if image is not None: image = Image.open(image) @@ -347,7 +250,7 @@ def get_interactive_image(key=None) -> Image.Image: return image -def load_img(display=True, key=None): +def load_img(display=True, key=None) -> torch.Tensor: image = get_interactive_image(key=key) if image is None: return None diff --git a/sgm/inference/api.py b/sgm/inference/api.py index 89b73700..ad6aecc9 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -22,12 +22,12 @@ class ModelArchitecture(str, Enum): - SD_2_1 = "stable-diffusion-v2-1" - SD_2_1_768 = "stable-diffusion-v2-1-768" - SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base" - SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner" SDXL_V1_BASE = "stable-diffusion-xl-v1-base" SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner" + SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base" + SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner" + SD_2_1 = "stable-diffusion-v2-1" + SD_2_1_768 = "stable-diffusion-v2-1-768" class Sampler(str, Enum): @@ -58,7 +58,7 @@ class SamplingParams: width: int = 1024 height: int = 1024 steps: int = 40 - sampler: Sampler = Sampler.DPMPP2M + sampler: Sampler = Sampler.EULER_EDM discretization: Discretization = Discretization.LEGACY_DDPM guider: Guider = Guider.VANILLA thresholder: Thresholder = Thresholder.NONE @@ -227,6 +227,7 @@ def text_to_image( samples: int = 1, return_latents: bool = False, noise_strength=None, + filter=None, ): sampler = get_sampler_config(params) @@ -253,7 +254,7 @@ def text_to_image( self.specs.factor, force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [], return_latents=return_latents, - filter=None, + filter=filter, ) def image_to_image( @@ -265,6 +266,7 @@ def image_to_image( samples: int = 1, return_latents: bool = False, noise_strength=None, + filter=None, ): sampler = get_sampler_config(params) @@ -289,7 +291,7 @@ def image_to_image( samples, force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [], return_latents=return_latents, - filter=None, + filter=filter, ) def wrap_discretization( @@ -327,6 +329,8 @@ def refiner( ), samples: int = 1, return_latents: bool = False, + filter=None, + add_noise=False, ): sampler = get_sampler_config(params) value_dict = { @@ -354,8 +358,8 @@ def refiner( samples, skip_encode=True, return_latents=return_latents, - filter=None, - add_noise=False, + filter=filter, + add_noise=add_noise, ) From a726ce3eb71a8a97b7e8c071b156c11e64dfaa13 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Wed, 9 Aug 2023 12:30:43 -0700 Subject: [PATCH 24/56] replace usage of get --- scripts/demo/sampling.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py index d6387177..f5355bc8 100644 --- a/scripts/demo/sampling.py +++ b/scripts/demo/sampling.py @@ -71,9 +71,9 @@ def run_txt2img( return_latents=False, stage2strength=None, ): - spec: SamplingSpec = state.get("spec") - model: SamplingPipeline = state.get("model") - params: SamplingParams = state.get("params") + spec: SamplingSpec = state["spec"] + model: SamplingPipeline = state["model"] + params: SamplingParams = state["params"] if version.startswith("stable-diffusion-xl") and version.endswith("-base"): params.width, params.height = st.selectbox( "Resolution:", list(SD_XL_BASE_RATIOS.values()), 10 @@ -106,7 +106,7 @@ def run_txt2img( samples=int(num_samples), return_latents=return_latents, noise_strength=stage2strength, - filter=state.get("filter"), + filter=state["filter"], ) 
show_samples(out, outputs) @@ -121,8 +121,8 @@ def run_img2img( return_latents=False, stage2strength=None, ): - model: SamplingPipeline = state.get("model") - params: SamplingParams = state.get("params") + model: SamplingPipeline = state["model"] + params: SamplingParams = state["params"] img = load_img() if img is None: @@ -152,7 +152,7 @@ def run_img2img( samples=int(num_samples), return_latents=return_latents, noise_strength=stage2strength, - filter=state.get("filter"), + filter=state["filter"], ) show_samples(out, outputs) @@ -167,8 +167,8 @@ def apply_refiner( negative_prompt: str, finish_denoising=False, ): - model: SamplingPipeline = state.get("model") - params: SamplingParams = state.get("params") + model: SamplingPipeline = state["model"] + params: SamplingParams = state["params"] params.orig_width = input.shape[3] * 8 params.orig_height = input.shape[2] * 8 @@ -184,7 +184,7 @@ def apply_refiner( negative_prompt=negative_prompt, samples=num_samples, return_latents=False, - filter=state.get("filter"), + filter=state["filter"], add_noise=not finish_denoising, ) From f86ffac274f0574f95f19d01c70c57870d75fcfb Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Wed, 9 Aug 2023 12:38:44 -0700 Subject: [PATCH 25/56] context manager --- scripts/demo/streamlit_helpers.py | 8 ++--- sgm/inference/helpers.py | 60 ++++++++++++++++--------------- 2 files changed, 33 insertions(+), 35 deletions(-) diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index 9a90f3be..70b7d065 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -200,12 +200,8 @@ def init_sampling( def get_discretization(params: SamplingParams, key=1) -> SamplingParams: if params.discretization == Discretization.EDM: - params.sigma_min = st.number_input( - f"sigma_min #{key}", value=params.sigma_min - ) # 0.0292 - params.sigma_max = st.number_input( - f"sigma_max #{key}", value=params.sigma_max - ) # 14.6146 + params.sigma_min = st.number_input(f"sigma_min #{key}", value=params.sigma_min) + params.sigma_max = st.number_input(f"sigma_max #{key}", value=params.sigma_max) params.rho = st.number_input(f"rho #{key}", value=params.rho) return params diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py index a0c9e221..aa9e8cda 100644 --- a/sgm/inference/helpers.py +++ b/sgm/inference/helpers.py @@ -1,3 +1,4 @@ +import contextlib import os from typing import Union, List, Optional @@ -352,34 +353,35 @@ def denoiser(x, sigma, c): return samples -class SwapToDevice(object): - def __init__( - self, - model: Union[torch.nn.Module, torch.Tensor], - device: Union[torch.device, str], - ): - self.model = model - self.device = torch.device(device) - if isinstance(model, torch.Tensor): - self.original_device = model.device +@contextlib.contextmanager +def SwapToDevice( + model: Union[torch.nn.Module, torch.Tensor], device: Union[torch.device, str] +): + """ + Context manager that swaps a model or tensor to a device, and then swaps it back to its original device + when the context is exited. 
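+
+    A minimal usage sketch (illustrative only; assumes model is a loaded
+    DiffusionEngine, img is an input tensor, and a CUDA device is available):
+
+        with SwapToDevice(model.first_stage_model, "cuda"):
+            z = model.encode_first_stage(img)
+        # first_stage_model is returned to its original device on exit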
+ """ + if isinstance(model, torch.Tensor): + original_device = model.device + else: + param = next(model.parameters(), None) + if param is not None: + original_device = param.device else: - param = next(model.parameters(), None) - if param is not None: - self.original_device = param.device + buf = next(model.buffers(), None) + if buf is not None: + original_device = buf.device else: - buf = next(model.buffers(), None) - if buf is not None: - self.original_device = buf.device - else: - # If device could not be found, turn this into a no-op - self.original_device = self.device - - def __enter__(self): - if self.device != self.original_device: - self.model.to(self.device) - - def __exit__(self, *args): - if self.device != self.original_device: - self.model.to(self.original_device) - if torch.cuda.is_available(): - torch.cuda.empty_cache() + # If device could not be found, do nothing + return + device = torch.device(device) + + if device != original_device: + model.to(device) + + yield + + if device != original_device: + model.to(original_device) + if torch.cuda.is_available(): + torch.cuda.empty_cache() From a009aa8a9f58918beff28cbd62bcdb1615986c1a Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Wed, 9 Aug 2023 13:27:30 -0700 Subject: [PATCH 26/56] adding some typing --- scripts/demo/streamlit_helpers.py | 9 +++++---- sgm/inference/api.py | 8 +++++--- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index 70b7d065..c25284f5 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -6,7 +6,7 @@ from einops import rearrange, repeat from PIL import Image from torchvision import transforms -from typing import Optional, Tuple +from typing import Optional, Tuple, Dict, Any from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering @@ -26,9 +26,9 @@ @st.cache_resource() -def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True): +def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True) -> Dict[str, Any]: global lowvram_mode - state = dict() + state: Dict[str, Any] = dict() if not "model" in state: config = spec.config ckpt = spec.ckpt @@ -244,9 +244,10 @@ def get_interactive_image(key=None) -> Optional[Image.Image]: if not image.mode == "RGB": image = image.convert("RGB") return image + return None -def load_img(display=True, key=None) -> torch.Tensor: +def load_img(display=True, key=None) -> Optional[torch.Tensor]: image = get_interactive_image(key=key) if image is None: return None diff --git a/sgm/inference/api.py b/sgm/inference/api.py index ad6aecc9..668cc65d 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -18,7 +18,7 @@ LinearMultistepSampler, ) from sgm.util import load_model_from_config -from typing import Optional +from typing import Optional, Dict, Any class ModelArchitecture(str, Enum): @@ -363,7 +363,8 @@ def refiner( ) -def get_guider_config(params: SamplingParams): +def get_guider_config(params: SamplingParams) -> Dict[str, Any]: + guider_config: Dict[str, Any] if params.guider == Guider.IDENTITY: guider_config = { "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider" @@ -389,7 +390,8 @@ def get_guider_config(params: SamplingParams): return guider_config -def get_discretization_config(params: SamplingParams): +def get_discretization_config(params: SamplingParams) -> Dict[str, Any]: + discretization_config: Dict[str, Any] if params.discretization == Discretization.LEGACY_DDPM: 
discretization_config = { "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization", From 725bea9f75caa383b4e933e3b1aebb92ced6711a Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Wed, 9 Aug 2023 13:29:16 -0700 Subject: [PATCH 27/56] pull in import fix --- scripts/demo/sampling.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py index f5355bc8..ef146f3e 100644 --- a/scripts/demo/sampling.py +++ b/scripts/demo/sampling.py @@ -1,7 +1,13 @@ -from dataclasses import asdict +import os + +import numpy as np +import streamlit as st +import torch +from einops import repeat from pytorch_lightning import seed_everything from sgm.inference.api import ( + SamplingSpec, SamplingParams, ModelArchitecture, SamplingPipeline, @@ -11,7 +17,17 @@ get_unique_embedder_keys_from_conditioner, perform_save_locally, ) -from scripts.demo.streamlit_helpers import * +from scripts.demo.streamlit_helpers import ( + get_interactive_image, + get_unique_embedder_keys_from_conditioner, + init_embedder_options, + init_sampling, + init_save_locally, + init_st, + perform_save_locally, + set_lowvram_mode, + show_samples, +) SAVE_PATH = "outputs/demo/txt2img/" From d245e2002fa6b2b0eb6826a954d738a6481c9505 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Wed, 9 Aug 2023 13:46:06 -0700 Subject: [PATCH 28/56] more types --- sgm/inference/api.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/sgm/inference/api.py b/sgm/inference/api.py index 668cc65d..e3f3d17d 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -18,6 +18,7 @@ LinearMultistepSampler, ) from sgm.util import load_model_from_config +import torch from typing import Optional, Dict, Any @@ -226,8 +227,8 @@ def text_to_image( negative_prompt: str = "", samples: int = 1, return_latents: bool = False, - noise_strength=None, - filter=None, + noise_strength: Optional[float] = None, + filter: Any = None, ): sampler = get_sampler_config(params) @@ -260,13 +261,13 @@ def text_to_image( def image_to_image( self, params: SamplingParams, - image, + image: torch.Tensor, prompt: str, negative_prompt: str = "", samples: int = 1, return_latents: bool = False, - noise_strength=None, - filter=None, + noise_strength: Optional[float] = None, + filter: Any = None, ): sampler = get_sampler_config(params) @@ -321,7 +322,7 @@ def wrap_discretization( def refiner( self, - image, + image: torch.Tensor, prompt: str, negative_prompt: str = "", params: SamplingParams = SamplingParams( @@ -329,8 +330,8 @@ def refiner( ), samples: int = 1, return_latents: bool = False, - filter=None, - add_noise=False, + filter: Any = None, + add_noise: bool = False, ): sampler = get_sampler_config(params) value_dict = { From b51c36b0dffefaa4a552e75427868dcb19ad508e Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Wed, 9 Aug 2023 19:31:59 -0700 Subject: [PATCH 29/56] extract path resolution method, fix/improve device swapping support --- scripts/demo/streamlit_helpers.py | 11 +++--- sgm/inference/api.py | 61 ++++++++++++++++++++----------- 2 files changed, 45 insertions(+), 27 deletions(-) diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index c25284f5..2d6972e8 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -33,11 +33,12 @@ def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True) -> Dict[str, A config = spec.config ckpt = spec.ckpt - pipeline = SamplingPipeline( 
- model_spec=spec, - use_fp16=lowvram_mode, - device="cpu" if lowvram_mode else "cuda", - ) + if lowvram_mode: + pipeline = SamplingPipeline( + model_spec=spec, use_fp16=True, device="cuda", swap_device="cpu" + ) + else: + pipeline = SamplingPipeline(model_spec=spec, use_fp16=True, device="cuda") state["spec"] = spec state["model"] = pipeline diff --git a/sgm/inference/api.py b/sgm/inference/api.py index e3f3d17d..ecdf7066 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -19,7 +19,7 @@ ) from sgm.util import load_model_from_config import torch -from typing import Optional, Dict, Any +from typing import Optional, Dict, Any, Union class ModelArchitecture(str, Enum): @@ -163,11 +163,24 @@ def __init__( self, model_id: Optional[ModelArchitecture] = None, model_spec: Optional[SamplingSpec] = None, - model_path=None, - config_path=None, - device="cuda", - use_fp16=True, + model_path: Optional[Union[str, pathlib.Path]] = None, + config_path: Optional[Union[str, pathlib.Path]] = None, + device: Union[str, torch.Device] = "cuda", + swap_device: Optional[Union[str, torch.Device]] = None, + use_fp16: bool = True, ) -> None: + """ + Sampling pipeline for generating images from a model. + + @param model_id: Model architecture to use. If not specified, model_spec must be specified. + @param model_spec: Model specification to use. If not specified, model_id must be specified. + @param model_path: Path to model checkpoints folder. + @param config_path: Path to model config folder. + @param device: Device to use for sampling. + @param swap_device: Device to swap models to when not in use. + @param use_fp16: Whether to use fp16 for sampling. + """ + self.model_id = model_id if model_spec is not None: self.specs = model_spec @@ -179,23 +192,9 @@ def __init__( raise ValueError("Either model_id or model_spec should be provided") if model_path is None: - model_path = pathlib.Path(__file__).parent.parent.resolve() / "checkpoints" - if not os.path.exists(model_path): - # This supports development installs where checkpoints is root level of the repo - model_path = ( - pathlib.Path(__file__).parent.parent.parent.resolve() - / "checkpoints" - ) + model_path = self._resolve_default_path("checkpoints") if config_path is None: - config_path = ( - pathlib.Path(__file__).parent.parent.resolve() / "configs/inference" - ) - if not os.path.exists(config_path): - # This supports development installs where configs is root level of the repo - config_path = ( - pathlib.Path(__file__).parent.parent.parent.resolve() - / "configs/inference" - ) + config_path = self._resolve_default_path("configs/inference") self.config = str(pathlib.Path(config_path) / self.specs.config) self.ckpt = str(pathlib.Path(model_path) / self.specs.ckpt) if not os.path.exists(self.config): @@ -207,7 +206,22 @@ def __init__( f"Checkpoint {self.ckpt} not found, check model spec or config_path" ) self.device = device - self.model = self._load_model(device=device, use_fp16=use_fp16) + self.swap_device = swap_device + load_device = device if swap_device is None else swap_device + self.model = self._load_model(device=load_device, use_fp16=use_fp16) + + def _resolve_default_path(self, suffix: str) -> pathlib.Path: + # Resolves a path relative to the root of the module or repo + repo_path = pathlib.Path(__file__).parent.parent.parent.resolve() / suffix + module_path = pathlib.Path(__file__).parent.parent.resolve() / suffix + path = module_path / suffix + if not os.path.exists(path): + path = repo_path / suffix + if not os.path.exists(path): + raise 
ValueError( + f"Default locations for {suffix} not found, please specify path" + ) + return pathlib.Path(path) def _load_model(self, device="cuda", use_fp16=True): config = OmegaConf.load(self.config) @@ -256,6 +270,7 @@ def text_to_image( force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [], return_latents=return_latents, filter=filter, + device=self.device, ) def image_to_image( @@ -293,6 +308,7 @@ def image_to_image( force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [], return_latents=return_latents, filter=filter, + device=self.device, ) def wrap_discretization( @@ -361,6 +377,7 @@ def refiner( return_latents=return_latents, filter=filter, add_noise=add_noise, + device=self.device, ) From 8011d54ca1bcc9b0858e07024c68ff9afe69b3fa Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Thu, 10 Aug 2023 03:19:37 -0700 Subject: [PATCH 30/56] some PR fixes --- scripts/demo/streamlit_helpers.py | 7 +++---- sgm/inference/api.py | 4 ++-- sgm/inference/helpers.py | 20 ++++++++++---------- 3 files changed, 15 insertions(+), 16 deletions(-) diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index 2d6972e8..b6814ec2 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -191,11 +191,10 @@ def init_sampling( ) ) - params = get_discretization(params, key=key) + params = get_discretization(params=params, key=key) + params = get_guider(params=params, key=key) + params = get_sampler(params=params, key=key) - params = get_guider(key=key, params=params) - - params = get_sampler(params, key=key) return params, num_rows, num_cols diff --git a/sgm/inference/api.py b/sgm/inference/api.py index ecdf7066..082ca187 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -165,8 +165,8 @@ def __init__( model_spec: Optional[SamplingSpec] = None, model_path: Optional[Union[str, pathlib.Path]] = None, config_path: Optional[Union[str, pathlib.Path]] = None, - device: Union[str, torch.Device] = "cuda", - swap_device: Optional[Union[str, torch.Device]] = None, + device: Union[str, torch.device] = "cuda", + swap_device: Optional[Union[str, torch.device]] = None, use_fp16: bool = True, ) -> None: """ diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py index aa9e8cda..68409a25 100644 --- a/sgm/inference/helpers.py +++ b/sgm/inference/helpers.py @@ -153,7 +153,7 @@ def do_sample( with autocast(device) as precision_scope: with model.ema_scope(): num_samples = [num_samples] - with SwapToDevice(model.conditioner, device): + with swap_to_device(model.conditioner, device): batch, batch_uc = get_batch( get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, @@ -190,11 +190,11 @@ def denoiser(input, sigma, c): model.model, input, sigma, c, **additional_model_inputs ) - with SwapToDevice(model.denoiser, device): - with SwapToDevice(model.model, device): + with swap_to_device(model.denoiser, device): + with swap_to_device(model.model, device): samples_z = sampler(denoiser, randn, cond=c, uc=uc) - with SwapToDevice(model.first_stage_model, device): + with swap_to_device(model.first_stage_model, device): samples_x = model.decode_first_stage(samples_z) samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) @@ -294,7 +294,7 @@ def do_img2img( with torch.no_grad(): with autocast(device): with model.ema_scope(): - with SwapToDevice(model.conditioner, device): + with swap_to_device(model.conditioner, device): batch, batch_uc = get_batch( 
get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, @@ -314,7 +314,7 @@ def do_img2img( if skip_encode: z = img else: - with SwapToDevice(model.first_stage_model, device): + with swap_to_device(model.first_stage_model, device): z = model.encode_first_stage(img) noise = torch.randn_like(z) @@ -337,11 +337,11 @@ def do_img2img( def denoiser(x, sigma, c): return model.denoiser(model.model, x, sigma, c) - with SwapToDevice(model.denoiser, device): - with SwapToDevice(model.model, device): + with swap_to_device(model.denoiser, device): + with swap_to_device(model.model, device): samples_z = sampler(denoiser, noised_z, cond=c, uc=uc) - with SwapToDevice(model.first_stage_model, device): + with swap_to_device(model.first_stage_model, device): samples_x = model.decode_first_stage(samples_z) samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) @@ -354,7 +354,7 @@ def denoiser(x, sigma, c): @contextlib.contextmanager -def SwapToDevice( +def swap_to_device( model: Union[torch.nn.Module, torch.Tensor], device: Union[torch.device, str] ): """ From fc498bfaef8145233bdd48a594910ce35a377044 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Thu, 10 Aug 2023 03:20:56 -0700 Subject: [PATCH 31/56] remove duplicate imports --- scripts/demo/sampling.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py index ef146f3e..5a709e01 100644 --- a/scripts/demo/sampling.py +++ b/scripts/demo/sampling.py @@ -19,12 +19,10 @@ ) from scripts.demo.streamlit_helpers import ( get_interactive_image, - get_unique_embedder_keys_from_conditioner, init_embedder_options, init_sampling, init_save_locally, init_st, - perform_save_locally, set_lowvram_mode, show_samples, ) From e190ecc60bd04a8b4595995e6634d137f2f72876 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Thu, 10 Aug 2023 04:35:59 -0700 Subject: [PATCH 32/56] path helper & model swapping rewrite --- scripts/demo/streamlit_helpers.py | 10 ++-- sgm/inference/api.py | 36 ++++-------- sgm/inference/helpers.py | 93 +++++++++++++++++++++---------- sgm/util.py | 18 ++++++ 4 files changed, 97 insertions(+), 60 deletions(-) diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index b6814ec2..58387415 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -20,9 +20,7 @@ SamplingPipeline, Thresholder, ) -from sgm.inference.helpers import ( - embed_watermark, -) +from sgm.inference.helpers import embed_watermark, CudaModelLoader @st.cache_resource() @@ -35,10 +33,12 @@ def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True) -> Dict[str, A if lowvram_mode: pipeline = SamplingPipeline( - model_spec=spec, use_fp16=True, device="cuda", swap_device="cpu" + model_spec=spec, + use_fp16=True, + model_loader=CudaModelLoader(device="cuda", swap_device="cpu"), ) else: - pipeline = SamplingPipeline(model_spec=spec, use_fp16=True, device="cuda") + pipeline = SamplingPipeline(model_spec=spec, use_fp16=False) state["spec"] = spec state["model"] = pipeline diff --git a/sgm/inference/api.py b/sgm/inference/api.py index 082ca187..0588a26e 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -2,10 +2,11 @@ from enum import Enum from omegaconf import OmegaConf import os -import pathlib from sgm.inference.helpers import ( do_sample, do_img2img, + BaseDeviceModelLoader, + CudaModelLoader, Img2ImgDiscretizationWrapper, Txt2NoisyDiscretizationWrapper, ) @@ -17,7 +18,7 @@ DPMPP2MSampler, LinearMultistepSampler, ) -from sgm.util 
import load_model_from_config +from sgm.util import load_model_from_config, get_configs_path, get_checkpoints_path import torch from typing import Optional, Dict, Any, Union @@ -163,11 +164,10 @@ def __init__( self, model_id: Optional[ModelArchitecture] = None, model_spec: Optional[SamplingSpec] = None, - model_path: Optional[Union[str, pathlib.Path]] = None, - config_path: Optional[Union[str, pathlib.Path]] = None, - device: Union[str, torch.device] = "cuda", - swap_device: Optional[Union[str, torch.device]] = None, + model_path: Optional[str] = None, + config_path: Optional[str] = None, use_fp16: bool = True, + model_loader: BaseDeviceModelLoader = CudaModelLoader(device="cuda"), ) -> None: """ Sampling pipeline for generating images from a model. @@ -176,9 +176,8 @@ def __init__( @param model_spec: Model specification to use. If not specified, model_id must be specified. @param model_path: Path to model checkpoints folder. @param config_path: Path to model config folder. - @param device: Device to use for sampling. - @param swap_device: Device to swap models to when not in use. @param use_fp16: Whether to use fp16 for sampling. + @param model_loader: Model loader class to use. Defaults to CudaModelLoader. """ self.model_id = model_id @@ -192,11 +191,11 @@ def __init__( raise ValueError("Either model_id or model_spec should be provided") if model_path is None: - model_path = self._resolve_default_path("checkpoints") + model_path = get_checkpoints_path() if config_path is None: - config_path = self._resolve_default_path("configs/inference") - self.config = str(pathlib.Path(config_path) / self.specs.config) - self.ckpt = str(pathlib.Path(model_path) / self.specs.ckpt) + config_path = get_configs_path() + self.config = os.path.join(config_path, "inference", self.specs.config) + self.ckpt = os.path.join(model_path, self.specs.ckpt) if not os.path.exists(self.config): raise ValueError( f"Config {self.config} not found, check model spec or config_path" @@ -210,19 +209,6 @@ def __init__( load_device = device if swap_device is None else swap_device self.model = self._load_model(device=load_device, use_fp16=use_fp16) - def _resolve_default_path(self, suffix: str) -> pathlib.Path: - # Resolves a path relative to the root of the module or repo - repo_path = pathlib.Path(__file__).parent.parent.parent.resolve() / suffix - module_path = pathlib.Path(__file__).parent.parent.resolve() / suffix - path = module_path / suffix - if not os.path.exists(path): - path = repo_path / suffix - if not os.path.exists(path): - raise ValueError( - f"Default locations for {suffix} not found, please specify path" - ) - return pathlib.Path(path) - def _load_model(self, device="cuda", use_fp16=True): config = OmegaConf.load(self.config) model = load_model_from_config(config, self.ckpt) diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py index 68409a25..bef2fb34 100644 --- a/sgm/inference/helpers.py +++ b/sgm/inference/helpers.py @@ -10,6 +10,7 @@ from imwatermark import WatermarkEncoder from omegaconf import ListConfig from torch import autocast +from abc import ABC, abstractmethod from sgm.util import append_dims @@ -353,35 +354,67 @@ def denoiser(x, sigma, c): return samples -@contextlib.contextmanager -def swap_to_device( - model: Union[torch.nn.Module, torch.Tensor], device: Union[torch.device, str] -): +class BaseDeviceModelLoader(ABC): """ - Context manager that swaps a model or tensor to a device, and then swaps it back to its original device - when the context is exited. 
+    Base class for device managers. Device managers are used to manage the device used for a model.
     """
+
+    @abstractmethod
+    def __init__(self, device: Union[torch.device, str]):
+        """
+        Args:
+            device (Union[torch.device, str]): The device to use for the model.
+        """
+        pass
+
+    def load(self, model: torch.nn.Module):
+        """
+        Loads a model to the device.
+        """
+        pass
+
+    @contextlib.contextmanager
+    def use(self, model: torch.nn.Module):
+        """
+        Context manager that ensures a model is on the correct device during use.
+        """
+        yield
+
+
+class CudaModelLoader(BaseDeviceModelLoader):
+    """
+    Device manager that loads a model to a CUDA device, optionally swapping to CPU when not in use.
+    """
+
+    def __init__(
+        self,
+        device: Union[torch.device, str] = "cuda",
+        swap_device: Union[torch.device, str] = None,
+    ):
+        """
+        Args:
+            device (Union[torch.device, str]): The device to use for the model.
+        """
+        self.device = torch.device(device)
+        self.swap_device = (
+            torch.device(swap_device) if swap_device is not None else self.device
+        )
+
+    def load(self, model: Union[torch.nn.Module, torch.Tensor]):
+        """
+        Loads a model to the device.
+        """
+        model.to(self.swap_device)
+
+    @contextlib.contextmanager
+    def use(self, model: Union[torch.nn.Module, torch.Tensor]):
+        """
+        Context manager that ensures a model is on the correct device during use.
+        """
+        if self.device != self.swap_device:
+            model.to(self.device)
+        yield
+        if self.device != self.swap_device:
+            model.to(self.swap_device)
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
diff --git a/sgm/util.py b/sgm/util.py
index c5e68f4b..1f96aeb3 100644
--- a/sgm/util.py
+++ b/sgm/util.py
@@ -230,6 +230,24 @@ def load_model_from_config(config, ckpt, verbose=True, freeze=True):
     return model
 
 
+def get_checkpoints_path() -> str:
+    """
+    Get the `checkpoints` directory.
+    This could be in the root of the repository for a working copy,
+    or in the cwd for other use cases.
+    """
+    this_dir = os.path.dirname(__file__)
+    candidates = (
+        os.path.join(this_dir, "checkpoints"),
+        os.path.join(os.getcwd(), "checkpoints"),
+    )
+    for candidate in candidates:
+        candidate = os.path.abspath(candidate)
+        if os.path.isdir(candidate):
+            return candidate
+    raise FileNotFoundError(f"Could not find SGM checkpoints in {candidates}")
+
+
 def get_configs_path() -> str:
     """
     Get the `configs` directory.
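A minimal usage sketch of the loader API introduced in the patch above (not part of the series): it assumes a CUDA machine with the SDXL 1.0 base checkpoint available under the resolved checkpoints path, and mirrors the low-VRAM call in streamlit_helpers.py.

    from sgm.inference.api import ModelArchitecture, SamplingParams, SamplingPipeline
    from sgm.inference.helpers import CudaModelLoader

    # Load weights to the CPU and swap submodules onto the GPU around each use,
    # per the CudaModelLoader contract above.
    pipeline = SamplingPipeline(
        model_id=ModelArchitecture.SDXL_V1_BASE,
        use_fp16=True,
        model_loader=CudaModelLoader(device="cuda", swap_device="cpu"),
    )
    samples = pipeline.text_to_image(
        params=SamplingParams(steps=40),
        prompt="A professional photograph of an astronaut riding a pig",
        negative_prompt="",
        samples=1,
    )

Passing swap_device="cpu" trades GPU memory for host RAM and transfer time; omitting it keeps every submodule resident on the GPU.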
From 47805f233cae54019860ff4adb07e3d19e54cbb9 Mon Sep 17 00:00:00 2001
From: Stephan Auerhahn
Date: Thu, 10 Aug 2023 04:55:43 -0700
Subject: [PATCH 33/56] finish device manager refactor

---
 scripts/demo/streamlit_helpers.py |   4 +-
 sgm/inference/helpers.py          | 178 ++++++++++++++++--------------
 2 files changed, 97 insertions(+), 85 deletions(-)

diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py
index 58387415..5b0214ae 100644
--- a/scripts/demo/streamlit_helpers.py
+++ b/scripts/demo/streamlit_helpers.py
@@ -20,7 +20,7 @@
     SamplingPipeline,
     Thresholder,
 )
-from sgm.inference.helpers import embed_watermark, CudaModelLoader
+from sgm.inference.helpers import embed_watermark, CudaModelManager
 
 
 @st.cache_resource()
@@ -35,7 +35,7 @@ def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True) -> Dict[str, A
         pipeline = SamplingPipeline(
             model_spec=spec,
             use_fp16=True,
-            model_loader=CudaModelLoader(device="cuda", swap_device="cpu"),
+            model_loader=CudaModelManager(device="cuda", swap_device="cpu"),
         )
     else:
         pipeline = SamplingPipeline(model_spec=spec, use_fp16=False)
diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py
index bef2fb34..095cf087 100644
--- a/sgm/inference/helpers.py
+++ b/sgm/inference/helpers.py
@@ -10,7 +10,6 @@
 from imwatermark import WatermarkEncoder
 from omegaconf import ListConfig
 from torch import autocast
-from abc import ABC, abstractmethod
 
 from sgm.util import append_dims
 
@@ -60,6 +59,80 @@ def __call__(self, image: torch.Tensor):
 embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
 
 
+class DeviceModelManager(object):
+    """
+    Default model loading class, should work for all device classes.
+    """
+
+    def __init__(
+        self,
+        device: Union[torch.device, str],
+        swap_device: Optional[Union[torch.device, str]] = None,
+    ):
+        """
+        Args:
+            device (Union[torch.device, str]): The device to use for the model.
+        """
+        self.device = torch.device(device)
+        self.swap_device = (
+            torch.device(swap_device) if swap_device is not None else self.device
+        )
+
+    def load(self, model: torch.nn.Module):
+        """
+        Loads a model to the device.
+        """
+        return model.to(self.device)
+
+    @contextlib.contextmanager
+    def use(self, model: torch.nn.Module):
+        """
+        Context manager that ensures a model is on the correct device during use.
+        The default model loader does not perform any swapping, so the model will
+        stay on device.
+        """
+        model.to(self.device)
+        yield
+        if self.device != self.swap_device:
+            model.to(self.swap_device)
+
+
+class CudaModelManager(DeviceModelManager):
+    """
+    Device manager that loads a model to a CUDA device, optionally swapping to CPU when not in use.
+    """
+
+    def __init__(
+        self,
+        device: Union[torch.device, str] = "cuda",
+        swap_device: Union[torch.device, str] = None,
+    ):
+        """
+        Args:
+            device (Union[torch.device, str]): The device to use for the model.
+        """
+        super().__init__(device, swap_device)
+
+    def load(self, model: Union[torch.nn.Module, torch.Tensor]):
+        """
+        Loads a model to the device.
+        """
+        return model.to(self.device)
+
+    @contextlib.contextmanager
+    def use(self, model: Union[torch.nn.Module, torch.Tensor]):
+        """
+        Context manager that ensures a model is on the correct device during use.
+        If a swap device was provided, this will move the model to it after use and clear cache.
+ """ + model.to(self.device) + yield + if self.device != self.swap_device: + model.to(self.swap_device) + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + def get_unique_embedder_keys_from_conditioner(conditioner): return list({x.input_key for x in conditioner.embedders}) @@ -143,7 +216,7 @@ def do_sample( batch2model_input: Optional[List] = None, return_latents=False, filter=None, - device="cuda", + device_manager: DeviceModelManager = DeviceModelManager("cuda"), ): if force_uc_zero_embeddings is None: force_uc_zero_embeddings = [] @@ -151,10 +224,10 @@ def do_sample( batch2model_input = [] with torch.no_grad(): - with autocast(device) as precision_scope: + with autocast(device_manager.device): with model.ema_scope(): num_samples = [num_samples] - with swap_to_device(model.conditioner, device): + with device_manager.use(model.conditioner): batch, batch_uc = get_batch( get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, @@ -176,7 +249,10 @@ def do_sample( for k in c: if not k == "crossattn": c[k], uc[k] = map( - lambda y: y[k][: math.prod(num_samples)].to(device), (c, uc) + lambda y: y[k][: math.prod(num_samples)].to( + device_manager.device + ), + (c, uc), ) additional_model_inputs = {} @@ -184,18 +260,18 @@ def do_sample( additional_model_inputs[k] = batch[k] shape = (math.prod(num_samples), C, H // F, W // F) - randn = torch.randn(shape).to(device) + randn = torch.randn(shape).to(device_manager.device) def denoiser(input, sigma, c): return model.denoiser( model.model, input, sigma, c, **additional_model_inputs ) - with swap_to_device(model.denoiser, device): - with swap_to_device(model.model, device): + with device_manager.use(model.denoiser): + with device_manager.use(model.model): samples_z = sampler(denoiser, randn, cond=c, uc=uc) - with swap_to_device(model.first_stage_model, device): + with device_manager.use(model.first_stage_model): samples_x = model.decode_first_stage(samples_z) samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) @@ -290,12 +366,12 @@ def do_img2img( skip_encode=False, filter=None, add_noise=True, - device="cuda", + device_manager=DeviceModelManager("cuda"), ): with torch.no_grad(): - with autocast(device): + with autocast(device_manager.device): with model.ema_scope(): - with swap_to_device(model.conditioner, device): + with device_manager.use(model.conditioner): batch, batch_uc = get_batch( get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, @@ -308,14 +384,16 @@ def do_img2img( ) for k in c: - c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc)) + c[k], uc[k] = map( + lambda y: y[k][:num_samples].to(device_manager.device), (c, uc) + ) for k in additional_kwargs: c[k] = uc[k] = additional_kwargs[k] if skip_encode: z = img else: - with swap_to_device(model.first_stage_model, device): + with device_manager.use(model.first_stage_model): z = model.encode_first_stage(img) noise = torch.randn_like(z) @@ -338,11 +416,11 @@ def do_img2img( def denoiser(x, sigma, c): return model.denoiser(model.model, x, sigma, c) - with swap_to_device(model.denoiser, device): - with swap_to_device(model.model, device): + with device_manager.use(model.denoiser): + with device_manager.use(model.model): samples_z = sampler(denoiser, noised_z, cond=c, uc=uc) - with swap_to_device(model.first_stage_model, device): + with device_manager.use(model.first_stage_model): samples_x = model.decode_first_stage(samples_z) samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) @@ -352,69 +430,3 @@ def 
denoiser(x, sigma, c): if return_latents: return samples, samples_z return samples - - -class BaseDeviceModelLoader(ABC): - """ - Base class for device managers. Device managers are used to manage the device used for a model. - """ - - @abstractmethod - def __init__(self, device: Union[torch.device, str]): - """ - Args: - device (Union[torch.device, str]): The device to use for the model. - """ - pass - - def load(self, model: torch.nn.Module): - """ - Loads a model to the device. - """ - pass - - @contextlib.contextmanager - def use(self, model: torch.nn.Module): - """ - Context manager that ensures a model is on the correct device during use. - """ - yield - - -class CudaModelLoader(BaseDeviceModelLoader): - """ - Device manager that loads a model to a CUDA device, optionally swapping to CPU when not in use. - """ - - def __init__( - self, - device: Union[torch.device, str] = "cuda", - swap_device: Union[torch.device, str] = None, - ): - """ - Args: - device (Union[torch.device, str]): The device to use for the model. - """ - self.device = torch.device(device) - self.swap_device = ( - torch.device(swap_device) if swap_device is not None else self.device - ) - - def load(self, model: Union[torch.nn.Module, torch.Tensor]): - """ - Loads a model to the device. - """ - model.to(self.swap_device) - - @contextlib.contextmanager - def use(self, model: Union[torch.nn.Module, torch.Tensor]): - """ - Context manager that ensures a model is on the correct device during use. - """ - if self.device != self.swap_device: - model.to(self.device) - yield - if self.device != self.swap_device: - model.to(self.swap_device) - if torch.cuda.is_available(): - torch.cuda.empty_cache() From 9b18e6fa19c31abf2bc7b7816c7d287c1a43c23c Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Thu, 10 Aug 2023 05:07:22 -0700 Subject: [PATCH 34/56] update api module --- sgm/inference/api.py | 21 +++++++++++---------- sgm/inference/helpers.py | 10 ++++++---- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/sgm/inference/api.py b/sgm/inference/api.py index 0588a26e..8516e733 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -5,8 +5,8 @@ from sgm.inference.helpers import ( do_sample, do_img2img, - BaseDeviceModelLoader, - CudaModelLoader, + DeviceModelManager, + CudaModelManager, Img2ImgDiscretizationWrapper, Txt2NoisyDiscretizationWrapper, ) @@ -167,7 +167,7 @@ def __init__( model_path: Optional[str] = None, config_path: Optional[str] = None, use_fp16: bool = True, - model_loader: BaseDeviceModelLoader = CudaModelLoader(device="cuda"), + device_manager: DeviceModelManager = CudaModelManager(device="cuda"), ) -> None: """ Sampling pipeline for generating images from a model. 
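An aside between these hunks: the contract that the new device_manager argument relies on behaves roughly as sketched below. This is illustration only; `module` is a hypothetical stand-in for a pipeline component, and a CUDA device is assumed.

    import torch

    from sgm.inference.helpers import CudaModelManager

    manager = CudaModelManager(device="cuda", swap_device="cpu")
    module = torch.nn.Linear(4, 4)   # stand-in for a pipeline component
    module = manager.load(module)    # initial placement on the compute device
    with manager.use(module):        # guaranteed on "cuda" inside the block
        y = module(torch.randn(1, 4, device="cuda"))
    # With a distinct swap_device the module now rests on "cpu",
    # and the CUDA cache has been emptied.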
@@ -204,17 +204,18 @@ def __init__( raise ValueError( f"Checkpoint {self.ckpt} not found, check model spec or config_path" ) - self.device = device - self.swap_device = swap_device - load_device = device if swap_device is None else swap_device - self.model = self._load_model(device=load_device, use_fp16=use_fp16) - def _load_model(self, device="cuda", use_fp16=True): + self.model_manager = device_manager + self.model = self._load_model( + device_manager=self.model_manager, use_fp16=use_fp16 + ) + + def _load_model(self, device_manager: DeviceModelManager, use_fp16=True): config = OmegaConf.load(self.config) model = load_model_from_config(config, self.ckpt) if model is None: raise ValueError(f"Model {self.model_id} could not be loaded") - model.to(device) + device_manager.load(model) if use_fp16: model.conditioner.half() model.model.half() @@ -256,7 +257,7 @@ def text_to_image( force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [], return_latents=return_latents, filter=filter, - device=self.device, + model_manager=self.model_manager, ) def image_to_image( diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py index 095cf087..314fe192 100644 --- a/sgm/inference/helpers.py +++ b/sgm/inference/helpers.py @@ -91,10 +91,12 @@ def use(self, model: torch.nn.Module): The default model loader does not perform any swapping, so the model will stay on device. """ - model.to(self.device) - yield - if self.device != self.swap_device: - model.to(self.swap_device) + try: + model.to(self.device) + yield + finally: + if self.device != self.swap_device: + model.to(self.swap_device) class CudaModelManager(DeviceModelManager): From de7a6279787221116e6be8ac94bc768b4f860dee Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Thu, 10 Aug 2023 05:11:34 -0700 Subject: [PATCH 35/56] more fixes and cleanup --- scripts/demo/streamlit_helpers.py | 9 +++------ sgm/inference/api.py | 14 +++++++------- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index 5b0214ae..a9ff5e8e 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -35,7 +35,7 @@ def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True) -> Dict[str, A pipeline = SamplingPipeline( model_spec=spec, use_fp16=True, - model_loader=CudaModelManager(device="cuda", swap_device="cpu"), + device_manager=CudaModelManager(device="cuda", swap_device="cpu"), ) else: pipeline = SamplingPipeline(model_spec=spec, use_fp16=False) @@ -207,7 +207,7 @@ def get_discretization(params: SamplingParams, key=1) -> SamplingParams: def get_sampler(params: SamplingParams, key=1) -> SamplingParams: - if params.sampler == Sampler.EULER_EDM or params.sampler == Sampler.HEUN_EDM: + if params.sampler in (Sampler.EULER_EDM, Sampler.HEUN_EDM): params.s_churn = st.sidebar.number_input( f"s_churn #{key}", value=params.s_churn, min_value=0.0 ) @@ -221,10 +221,7 @@ def get_sampler(params: SamplingParams, key=1) -> SamplingParams: f"s_noise #{key}", value=params.s_noise, min_value=0.0 ) - elif ( - params.sampler == Sampler.EULER_ANCESTRAL - or params.sampler == Sampler.DPMPP2S_ANCESTRAL - ): + elif params.sampler in (Sampler.EULER_ANCESTRAL, Sampler.DPMPP2S_ANCESTRAL): params.s_noise = st.sidebar.number_input( "s_noise", value=params.s_noise, min_value=0.0 ) diff --git a/sgm/inference/api.py b/sgm/inference/api.py index 8516e733..96aead65 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -205,9 +205,9 @@ def __init__( f"Checkpoint 
{self.ckpt} not found, check model spec or config_path" ) - self.model_manager = device_manager + self.device_manager = device_manager self.model = self._load_model( - device_manager=self.model_manager, use_fp16=use_fp16 + device_manager=self.device_manager, use_fp16=use_fp16 ) def _load_model(self, device_manager: DeviceModelManager, use_fp16=True): @@ -229,7 +229,7 @@ def text_to_image( samples: int = 1, return_latents: bool = False, noise_strength: Optional[float] = None, - filter: Any = None, + filter=None, ): sampler = get_sampler_config(params) @@ -257,7 +257,7 @@ def text_to_image( force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [], return_latents=return_latents, filter=filter, - model_manager=self.model_manager, + device_manager=self.device_manager, ) def image_to_image( @@ -269,7 +269,7 @@ def image_to_image( samples: int = 1, return_latents: bool = False, noise_strength: Optional[float] = None, - filter: Any = None, + filter=None, ): sampler = get_sampler_config(params) @@ -295,7 +295,7 @@ def image_to_image( force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [], return_latents=return_latents, filter=filter, - device=self.device, + device_manager=self.device_manager, ) def wrap_discretization( @@ -364,7 +364,7 @@ def refiner( return_latents=return_latents, filter=filter, add_noise=add_noise, - device=self.device, + device_manager=self.device_manager, ) From 3e7ada70c503622b474db9e55c281b63f36b3047 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Thu, 10 Aug 2023 05:42:31 -0700 Subject: [PATCH 36/56] fix autocast --- sgm/inference/helpers.py | 30 ++++++++++-------------------- 1 file changed, 10 insertions(+), 20 deletions(-) diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py index 314fe192..f86eda6e 100644 --- a/sgm/inference/helpers.py +++ b/sgm/inference/helpers.py @@ -9,7 +9,6 @@ from einops import rearrange from imwatermark import WatermarkEncoder from omegaconf import ListConfig -from torch import autocast from sgm.util import append_dims @@ -84,6 +83,14 @@ def load(self, model: torch.nn.Module): """ return model.to(self.device) + def autocast(self): + """ + Context manager that enables autocast for the device if supported. + """ + if self.device.type not in ("cuda", "cpu"): + return contextlib.nullcontext() + return torch.autocast(self.device.type) + @contextlib.contextmanager def use(self, model: torch.nn.Module): """ @@ -104,23 +111,6 @@ class CudaModelManager(DeviceModelManager): Device manager that loads a model to a CUDA device, optionally swapping to CPU when not in use. """ - def __init__( - self, - device: Union[torch.device, str] = "cuda", - swap_device: Union[torch.device, str] = None, - ): - """ - Args: - device (Union[torch.device, str]): The device to use for the model. - """ - super().__init__(device, swap_device) - - def load(self, model: Union[torch.nn.Module, torch.Tensor]): - """ - Loads a model to the device. 
- """ - return model.to(self.device) - @contextlib.contextmanager def use(self, model: Union[torch.nn.Module, torch.Tensor]): """ @@ -226,7 +216,7 @@ def do_sample( batch2model_input = [] with torch.no_grad(): - with autocast(device_manager.device): + with device_manager.autocast(): with model.ema_scope(): num_samples = [num_samples] with device_manager.use(model.conditioner): @@ -371,7 +361,7 @@ def do_img2img( device_manager=DeviceModelManager("cuda"), ): with torch.no_grad(): - with autocast(device_manager.device): + with device_manager.autocast(): with model.ema_scope(): with device_manager.use(model.conditioner): batch, batch_uc = get_batch( From 26b10f56f33a2dfc50f705247aca2d81e4527cbb Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Thu, 10 Aug 2023 12:24:12 -0700 Subject: [PATCH 37/56] fix missing index --- scripts/demo/streamlit_helpers.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index a9ff5e8e..119ffd75 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -47,6 +47,8 @@ def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True) -> Dict[str, A state["params"] = SamplingParams() if load_filter: state["filter"] = DeepFloydDataFiltering(verbose=False) + else: + state["filter"] = None return state From a25662e969fc9a4e8df74e4917d2adca37621591 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Thu, 10 Aug 2023 12:40:32 -0700 Subject: [PATCH 38/56] low vram checkbox fix, remove magic strings --- scripts/demo/sampling.py | 20 ++++++++++++-------- scripts/demo/streamlit_helpers.py | 11 ++--------- 2 files changed, 14 insertions(+), 17 deletions(-) diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py index 5a709e01..4dca18d7 100644 --- a/scripts/demo/sampling.py +++ b/scripts/demo/sampling.py @@ -23,7 +23,6 @@ init_sampling, init_save_locally, init_st, - set_lowvram_mode, show_samples, ) @@ -205,6 +204,16 @@ def apply_refiner( return samples +sdxl_base_model_list = [ + ModelArchitecture.SDXL_V1_BASE, + ModelArchitecture.SDXL_V0_9_BASE, +] + +sdxl_refiner_model_list = [ + ModelArchitecture.SDXL_V1_REFINER, + ModelArchitecture.SDXL_V0_9_REFINER, +] + if __name__ == "__main__": st.title("Stable Diffusion") version = st.selectbox( @@ -217,9 +226,7 @@ def apply_refiner( mode = st.radio("Mode", ("txt2img", "img2img"), 0) st.write("__________________________") - set_lowvram_mode(st.checkbox("Low vram mode", True)) - - if str(version).startswith("stable-diffusion-xl"): + if version_enum in sdxl_base_model_list: add_pipeline = st.checkbox("Load SDXL-refiner?", False) st.write("__________________________") else: @@ -253,10 +260,7 @@ def apply_refiner( version2 = ModelArchitecture( st.selectbox( "Refiner:", - [ - ModelArchitecture.SDXL_V1_REFINER.value, - ModelArchitecture.SDXL_V0_9_REFINER.value, - ], + [member.value for member in sdxl_refiner_model_list], ) ) st.warning( diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index 119ffd75..59cd27ba 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -25,12 +25,13 @@ @st.cache_resource() def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True) -> Dict[str, Any]: - global lowvram_mode state: Dict[str, Any] = dict() if not "model" in state: config = spec.config ckpt = spec.ckpt + lowvram_mode = st.checkbox("Low VRAM mode", value=False) + if lowvram_mode: pipeline = SamplingPipeline( model_spec=spec, @@ -52,14 +53,6 @@ def 
init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True) -> Dict[str, A return state -lowvram_mode = False - - -def set_lowvram_mode(mode): - global lowvram_mode - lowvram_mode = mode - - def get_unique_embedder_keys_from_conditioner(conditioner): return list(set([x.input_key for x in conditioner.embedders])) From b3866d121890034ff26a1c815223af5847fdd7fd Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Thu, 10 Aug 2023 12:44:48 -0700 Subject: [PATCH 39/56] move checkbox out of cached resource --- scripts/demo/sampling.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py index 4dca18d7..9bf0dfff 100644 --- a/scripts/demo/sampling.py +++ b/scripts/demo/sampling.py @@ -237,8 +237,12 @@ def apply_refiner( ) seed_everything(seed) + lowvram_mode = st.checkbox("Low vram mode", True) + save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, str(version))) - state = init_st(model_specs[version_enum], load_filter=True) + state = init_st( + model_specs[version_enum], load_filter=True, lowvram_mode=lowvram_mode + ) model = state["model"] is_legacy = specs.is_legacy From 88395261d8fa7fb0b18764cf346b985f586f9b2f Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Thu, 10 Aug 2023 12:45:37 -0700 Subject: [PATCH 40/56] update helpers --- scripts/demo/streamlit_helpers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index 59cd27ba..b37e2ebd 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -24,14 +24,14 @@ @st.cache_resource() -def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True) -> Dict[str, Any]: +def init_st( + spec: SamplingSpec, load_ckpt=True, load_filter=True, lowvram_mode=True +) -> Dict[str, Any]: state: Dict[str, Any] = dict() if not "model" in state: config = spec.config ckpt = spec.ckpt - lowvram_mode = st.checkbox("Low VRAM mode", value=False) - if lowvram_mode: pipeline = SamplingPipeline( model_spec=spec, From 3816aaa639612e99fe6a4c191cee41d76e0ad33c Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Thu, 10 Aug 2023 13:05:30 -0700 Subject: [PATCH 41/56] simplify device_manager usage --- scripts/demo/streamlit_helpers.py | 2 +- sgm/inference/api.py | 21 ++++++++++++++++++--- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index b37e2ebd..fec7d33e 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -36,7 +36,7 @@ def init_st( pipeline = SamplingPipeline( model_spec=spec, use_fp16=True, - device_manager=CudaModelManager(device="cuda", swap_device="cpu"), + device=CudaModelManager(device="cuda", swap_device="cpu"), ) else: pipeline = SamplingPipeline(model_spec=spec, use_fp16=False) diff --git a/sgm/inference/api.py b/sgm/inference/api.py index 96aead65..e680dc55 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -57,6 +57,13 @@ class Thresholder(str, Enum): @dataclass class SamplingParams: + """ + Parameters for sampling. + The defaults here are derived from user preference testing. + They will be subject to change in the future, likely pulled + from model specs instead of global defaults. 
+ """ + width: int = 1024 height: int = 1024 steps: int = 40 @@ -167,7 +174,9 @@ def __init__( model_path: Optional[str] = None, config_path: Optional[str] = None, use_fp16: bool = True, - device_manager: DeviceModelManager = CudaModelManager(device="cuda"), + device: Union[DeviceModelManager, str, torch.device] = CudaModelManager( + device="cuda" + ), ) -> None: """ Sampling pipeline for generating images from a model. @@ -177,7 +186,7 @@ def __init__( @param model_path: Path to model checkpoints folder. @param config_path: Path to model config folder. @param use_fp16: Whether to use fp16 for sampling. - @param model_loader: Model loader class to use. Defaults to CudaModelLoader. + @param device: Device manager to use with this pipeline. If a string or torch.device is passed, a device manager will be created based on device type if possible. """ self.model_id = model_id @@ -205,7 +214,13 @@ def __init__( f"Checkpoint {self.ckpt} not found, check model spec or config_path" ) - self.device_manager = device_manager + if isinstance(device, torch.device) or isinstance(device, str): + if torch.device(device).type == "cuda": + self.device_manager = CudaModelManager(device=device) + else: + self.device_manager = DeviceModelManager(device=device) + else: + self.device_manager = device self.model = self._load_model( device_manager=self.device_manager, use_fp16=use_fp16 ) From 2aebc8882d864d7263988c019d223de737d5a3ac Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Thu, 10 Aug 2023 13:14:38 -0700 Subject: [PATCH 42/56] split fp16 and swapping functionality --- scripts/demo/sampling.py | 12 +++++++++--- scripts/demo/streamlit_helpers.py | 12 ++++++++---- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py index 9bf0dfff..13d8db3b 100644 --- a/scripts/demo/sampling.py +++ b/scripts/demo/sampling.py @@ -226,6 +226,11 @@ def apply_refiner( mode = st.radio("Mode", ("txt2img", "img2img"), 0) st.write("__________________________") + st.write("### Performance Options") + use_fp16 = st.checkbox("Use fp16", True) + enable_swap = st.checkbox("Enable model swapping to CPU", False) + st.write("__________________________") + if version_enum in sdxl_base_model_list: add_pipeline = st.checkbox("Load SDXL-refiner?", False) st.write("__________________________") @@ -237,11 +242,12 @@ def apply_refiner( ) seed_everything(seed) - lowvram_mode = st.checkbox("Low vram mode", True) - save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, str(version))) state = init_st( - model_specs[version_enum], load_filter=True, lowvram_mode=lowvram_mode + model_specs[version_enum], + load_filter=True, + use_fp16=use_fp16, + enable_swap=enable_swap, ) model = state["model"] diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index fec7d33e..84c4e628 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -25,21 +25,25 @@ @st.cache_resource() def init_st( - spec: SamplingSpec, load_ckpt=True, load_filter=True, lowvram_mode=True + spec: SamplingSpec, + load_ckpt=True, + load_filter=True, + use_fp16=True, + enable_swap=True, ) -> Dict[str, Any]: state: Dict[str, Any] = dict() if not "model" in state: config = spec.config ckpt = spec.ckpt - if lowvram_mode: + if enable_swap: pipeline = SamplingPipeline( model_spec=spec, - use_fp16=True, + use_fp16=use_fp16, device=CudaModelManager(device="cuda", swap_device="cpu"), ) else: - pipeline = SamplingPipeline(model_spec=spec, use_fp16=False) + pipeline 
= SamplingPipeline(model_spec=spec, use_fp16=use_fp16) state["spec"] = spec state["model"] = pipeline From 5c170434342ee8338a92cff132fb9a3b71286953 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Thu, 10 Aug 2023 13:15:23 -0700 Subject: [PATCH 43/56] change default --- scripts/demo/sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py index 13d8db3b..3856b2f1 100644 --- a/scripts/demo/sampling.py +++ b/scripts/demo/sampling.py @@ -228,7 +228,7 @@ def apply_refiner( st.write("### Performance Options") use_fp16 = st.checkbox("Use fp16", True) - enable_swap = st.checkbox("Enable model swapping to CPU", False) + enable_swap = st.checkbox("Enable model swapping to CPU", True) st.write("__________________________") if version_enum in sdxl_base_model_list: From cd81956241cceba1ad8b8983320637a7665fe8fc Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Thu, 10 Aug 2023 13:31:03 -0700 Subject: [PATCH 44/56] text updates --- scripts/demo/sampling.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py index 3856b2f1..87d6feb0 100644 --- a/scripts/demo/sampling.py +++ b/scripts/demo/sampling.py @@ -226,9 +226,9 @@ def apply_refiner( mode = st.radio("Mode", ("txt2img", "img2img"), 0) st.write("__________________________") - st.write("### Performance Options") - use_fp16 = st.checkbox("Use fp16", True) - enable_swap = st.checkbox("Enable model swapping to CPU", True) + st.write("**Performance Options:**") + use_fp16 = st.checkbox("Use fp16 (Saves VRAM)", True) + enable_swap = st.checkbox("Swap models to CPU (Saves VRAM, uses RAM)", True) st.write("__________________________") if version_enum in sdxl_base_model_list: From d6f2b7899429496aae347229a2f0f522cbb5c38b Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Thu, 10 Aug 2023 15:06:55 -0700 Subject: [PATCH 45/56] pass options into state2 init --- scripts/demo/sampling.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py index 87d6feb0..a18b21aa 100644 --- a/scripts/demo/sampling.py +++ b/scripts/demo/sampling.py @@ -279,7 +279,9 @@ def apply_refiner( st.write("**Refiner Options:**") specs2 = model_specs[version2] - state2 = init_st(specs2, load_filter=False) + state2 = init_st( + specs2, load_filter=False, use_fp16=use_fp16, enable_swap=enable_swap + ) params2 = state2["params"] params2.img2img_strength = st.number_input( From fe4632034b1ffcd46e9347b03b76e8c51104b14a Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Fri, 11 Aug 2023 16:31:53 -0700 Subject: [PATCH 46/56] fix for orig dimensions --- scripts/demo/streamlit_helpers.py | 43 +++++++++++++++---------------- sgm/inference/api.py | 6 +++-- 2 files changed, 25 insertions(+), 24 deletions(-) diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index 84c4e628..93293dd7 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -32,28 +32,27 @@ def init_st( enable_swap=True, ) -> Dict[str, Any]: state: Dict[str, Any] = dict() - if not "model" in state: - config = spec.config - ckpt = spec.ckpt - - if enable_swap: - pipeline = SamplingPipeline( - model_spec=spec, - use_fp16=use_fp16, - device=CudaModelManager(device="cuda", swap_device="cpu"), - ) - else: - pipeline = SamplingPipeline(model_spec=spec, use_fp16=use_fp16) - - state["spec"] = spec - state["model"] = pipeline - state["ckpt"] = ckpt if load_ckpt 
else None - state["config"] = config - state["params"] = SamplingParams() - if load_filter: - state["filter"] = DeepFloydDataFiltering(verbose=False) - else: - state["filter"] = None + config = spec.config + ckpt = spec.ckpt + + if enable_swap: + pipeline = SamplingPipeline( + model_spec=spec, + use_fp16=use_fp16, + device=CudaModelManager(device="cuda", swap_device="cpu"), + ) + else: + pipeline = SamplingPipeline(model_spec=spec, use_fp16=use_fp16) + + state["spec"] = spec + state["model"] = pipeline + state["ckpt"] = ckpt if load_ckpt else None + state["config"] = config + state["params"] = SamplingParams() + if load_filter: + state["filter"] = DeepFloydDataFiltering(verbose=False) + else: + state["filter"] = None return state diff --git a/sgm/inference/api.py b/sgm/inference/api.py index e680dc55..eccf129b 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -75,8 +75,8 @@ class SamplingParams: aesthetic_score: float = 6.0 negative_aesthetic_score: float = 2.5 img2img_strength: float = 1.0 - orig_width: int = 1024 - orig_height: int = 1024 + orig_width: int = width + orig_height: int = height crop_coords_top: int = 0 crop_coords_left: int = 0 sigma_min: float = 0.0292 @@ -301,6 +301,8 @@ def image_to_image( value_dict["negative_prompt"] = negative_prompt value_dict["target_width"] = width value_dict["target_height"] = height + value_dict["orig_width"] = width + value_dict["orig_height"] = height return do_img2img( image, self.model, From d4307bef5d75435b67b6907fbb21e49a5efdce6b Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Sat, 12 Aug 2023 07:15:36 +0000 Subject: [PATCH 47/56] Test model device manager and fix bugs --- pyproject.toml | 2 +- sgm/inference/api.py | 16 ++++----- sgm/inference/helpers.py | 19 ++++++++--- tests/inference/test_modelmanager.py | 51 ++++++++++++++++++++++++++++ 4 files changed, 72 insertions(+), 16 deletions(-) create mode 100644 tests/inference/test_modelmanager.py diff --git a/pyproject.toml b/pyproject.toml index 2cc50216..94ba68df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,5 +44,5 @@ dependencies = [ test-inference = [ "pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 --index-url https://download.pytorch.org/whl/cu118", "pip install -r requirements/pt2.txt", - "pytest -v tests/inference/test_inference.py {args}", + "pytest -v tests/inference {args}", ] diff --git a/sgm/inference/api.py b/sgm/inference/api.py index eccf129b..afb1f72f 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -6,7 +6,7 @@ do_sample, do_img2img, DeviceModelManager, - CudaModelManager, + get_model_manager, Img2ImgDiscretizationWrapper, Txt2NoisyDiscretizationWrapper, ) @@ -174,9 +174,7 @@ def __init__( model_path: Optional[str] = None, config_path: Optional[str] = None, use_fp16: bool = True, - device: Union[DeviceModelManager, str, torch.device] = CudaModelManager( - device="cuda" - ), + device: Optional[Union[DeviceModelManager, str, torch.device]] = None ) -> None: """ Sampling pipeline for generating images from a model. 
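For context between these hunks: the coercion delegated to get_model_manager behaves roughly as below (a sketch, assuming the SD 2.1 checkpoint used by the tests is available locally).

    import torch

    from sgm.inference.api import ModelArchitecture, SamplingPipeline
    from sgm.inference.helpers import CudaModelManager, DeviceModelManager

    # A bare string or torch.device is wrapped on the way in: "cuda" yields a
    # CudaModelManager, anything else a plain DeviceModelManager; an existing
    # manager instance is used as-is.
    pipeline = SamplingPipeline(model_id=ModelArchitecture.SD_2_1, device="cuda")
    assert isinstance(pipeline.device_manager, CudaModelManager)

    cpu_pipeline = SamplingPipeline(
        model_id=ModelArchitecture.SD_2_1,
        device=DeviceModelManager(device=torch.device("cpu")),
    )
    assert isinstance(cpu_pipeline.device_manager, DeviceModelManager)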
@@ -213,18 +211,16 @@ def __init__( raise ValueError( f"Checkpoint {self.ckpt} not found, check model spec or config_path" ) - - if isinstance(device, torch.device) or isinstance(device, str): - if torch.device(device).type == "cuda": - self.device_manager = CudaModelManager(device=device) - else: - self.device_manager = DeviceModelManager(device=device) + if not isinstance(device, DeviceModelManager): + self.device_manager = get_model_manager(device=device) else: self.device_manager = device + self.model = self._load_model( device_manager=self.device_manager, use_fp16=use_fp16 ) + def _load_model(self, device_manager: DeviceModelManager, use_fp16=True): config = OmegaConf.load(self.config) model = load_model_from_config(config, self.ckpt) diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py index f86eda6e..addefe31 100644 --- a/sgm/inference/helpers.py +++ b/sgm/inference/helpers.py @@ -67,7 +67,7 @@ def __init__( self, device: Union[torch.device, str], swap_device: Optional[Union[torch.device, str]] = None, - ): + ) -> None: """ Args: device (Union[torch.device, str]): The device to use for the model. @@ -77,11 +77,11 @@ def __init__( torch.device(swap_device) if swap_device is not None else self.device ) - def load(self, model: torch.nn.Module): + def load(self, model: torch.nn.Module) -> None: """ - Loads a model to the device. + Loads a model to the (swap) device. """ - return model.to(self.device) + model.to(self.swap_device) def autocast(self): """ @@ -109,7 +109,7 @@ def use(self, model: torch.nn.Module): class CudaModelManager(DeviceModelManager): """ Device manager that loads a model to a CUDA device, optionally swapping to CPU when not in use. - """ + """ @contextlib.contextmanager def use(self, model: Union[torch.nn.Module, torch.Tensor]): @@ -141,6 +141,15 @@ def perform_save_locally(save_path, samples): base_count += 1 +def get_model_manager(device: Union[str,torch.device]) -> DeviceModelManager: + if isinstance(device, torch.device) or isinstance(device, str): + if torch.device(device).type == "cuda": + return CudaModelManager(device=device) + else: + return DeviceModelManager(device=device) + else: + return device + class Img2ImgDiscretizationWrapper: """ wraps a discretizer, and prunes the sigmas diff --git a/tests/inference/test_modelmanager.py b/tests/inference/test_modelmanager.py new file mode 100644 index 00000000..fa96d704 --- /dev/null +++ b/tests/inference/test_modelmanager.py @@ -0,0 +1,51 @@ +import numpy +from PIL import Image +import pytest +from pytest import fixture +import torch +from typing import Tuple, Optional + +from sgm.inference.api import ( + model_specs, + SamplingParams, + SamplingPipeline, + Sampler, + ModelArchitecture, +) +import sgm.inference.helpers as helpers + +def get_torch_device(model: torch.nn.Module) -> torch.device: + param = next(model.parameters(), None) + if param is not None: + return param.device + else: + buf = next(model.buffers(), None) + if buf is not None: + return buf.device + else: + raise TypeError("Could not determine device of input model") + + +@pytest.mark.inference +def test_default_loading(): + pipeline = SamplingPipeline(model_id=ModelArchitecture.SD_2_1) + assert get_torch_device(pipeline.model.model).type == "cuda" + assert get_torch_device(pipeline.model.conditioner).type == "cuda" + with pipeline.device_manager.use(pipeline.model.model): + assert get_torch_device(pipeline.model.model).type == "cuda" + assert get_torch_device(pipeline.model.model).type == "cuda" + with 
pipeline.device_manager.use(pipeline.model.conditioner): + assert get_torch_device(pipeline.model.conditioner).type == "cuda" + assert get_torch_device(pipeline.model.conditioner).type == "cuda" + +@pytest.mark.inference +def test_model_swapping(): + pipeline = SamplingPipeline(model_id=ModelArchitecture.SD_2_1, device=helpers.CudaModelManager(device="cuda", swap_device="cpu")) + assert get_torch_device(pipeline.model.model).type == "cpu" + assert get_torch_device(pipeline.model.conditioner).type == "cpu" + with pipeline.device_manager.use(pipeline.model.model): + assert get_torch_device(pipeline.model.model).type == "cuda" + assert get_torch_device(pipeline.model.model).type == "cpu" + with pipeline.device_manager.use(pipeline.model.conditioner): + assert get_torch_device(pipeline.model.conditioner).type == "cuda" + assert get_torch_device(pipeline.model.conditioner).type == "cpu" \ No newline at end of file From 98c4b7753b2cf9b467bdda0c94569b5b910262fa Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Sat, 12 Aug 2023 07:16:02 +0000 Subject: [PATCH 48/56] cleanup imports in test --- tests/inference/test_modelmanager.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/tests/inference/test_modelmanager.py b/tests/inference/test_modelmanager.py index fa96d704..bb1ab0e4 100644 --- a/tests/inference/test_modelmanager.py +++ b/tests/inference/test_modelmanager.py @@ -1,15 +1,8 @@ -import numpy -from PIL import Image import pytest -from pytest import fixture import torch -from typing import Tuple, Optional -from sgm.inference.api import ( - model_specs, - SamplingParams, - SamplingPipeline, - Sampler, +from sgm.inference.api import ( + SamplingPipeline, ModelArchitecture, ) import sgm.inference.helpers as helpers From f6704532a0c50eaf7961600d1e517a18e2060740 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Sat, 12 Aug 2023 07:27:25 +0000 Subject: [PATCH 49/56] abstract device defaults --- sgm/inference/api.py | 15 ++++++--------- sgm/inference/helpers.py | 28 ++++++++++++++++++---------- 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/sgm/inference/api.py b/sgm/inference/api.py index afb1f72f..9ca1111b 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -174,7 +174,7 @@ def __init__( model_path: Optional[str] = None, config_path: Optional[str] = None, use_fp16: bool = True, - device: Optional[Union[DeviceModelManager, str, torch.device]] = None + device: Optional[Union[DeviceModelManager, str, torch.device]] = None, ) -> None: """ Sampling pipeline for generating images from a model. 
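The practical effect of the defaulting this patch introduces: get_model_manager(None) now chooses a device on its own, so CPU-only hosts no longer need an explicit argument. A minimal sketch, assuming the helpers change later in this patch is applied:

    from sgm.inference.helpers import get_model_manager

    # None resolves to CUDA when available and falls back to CPU.
    manager = get_model_manager(None)
    print(manager.device)  # device(type='cuda') or device(type='cpu')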
@@ -211,16 +211,13 @@ def __init__( raise ValueError( f"Checkpoint {self.ckpt} not found, check model spec or config_path" ) - if not isinstance(device, DeviceModelManager): - self.device_manager = get_model_manager(device=device) - else: - self.device_manager = device + + self.device_manager = get_model_manager(device) self.model = self._load_model( device_manager=self.device_manager, use_fp16=use_fp16 ) - def _load_model(self, device_manager: DeviceModelManager, use_fp16=True): config = OmegaConf.load(self.config) model = load_model_from_config(config, self.ckpt) @@ -268,7 +265,7 @@ def text_to_image( force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [], return_latents=return_latents, filter=filter, - device_manager=self.device_manager, + device=self.device_manager, ) def image_to_image( @@ -308,7 +305,7 @@ def image_to_image( force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [], return_latents=return_latents, filter=filter, - device_manager=self.device_manager, + device=self.device_manager, ) def wrap_discretization( @@ -377,7 +374,7 @@ def refiner( return_latents=return_latents, filter=filter, add_noise=add_noise, - device_manager=self.device_manager, + device=self.device_manager, ) diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py index addefe31..e84b8a27 100644 --- a/sgm/inference/helpers.py +++ b/sgm/inference/helpers.py @@ -109,7 +109,7 @@ def use(self, model: torch.nn.Module): class CudaModelManager(DeviceModelManager): """ Device manager that loads a model to a CUDA device, optionally swapping to CPU when not in use. - """ + """ @contextlib.contextmanager def use(self, model: Union[torch.nn.Module, torch.Tensor]): @@ -141,14 +141,19 @@ def perform_save_locally(save_path, samples): base_count += 1 -def get_model_manager(device: Union[str,torch.device]) -> DeviceModelManager: - if isinstance(device, torch.device) or isinstance(device, str): - if torch.device(device).type == "cuda": - return CudaModelManager(device=device) - else: - return DeviceModelManager(device=device) - else: +def get_model_manager( + device: Optional[Union[DeviceModelManager, str, torch.device]] +) -> DeviceModelManager: + if isinstance(device, DeviceModelManager): return device + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device(device) + if device.type == "cuda": + return CudaModelManager(device=device) + else: + return DeviceModelManager(device=device) + class Img2ImgDiscretizationWrapper: """ @@ -217,13 +222,15 @@ def do_sample( batch2model_input: Optional[List] = None, return_latents=False, filter=None, - device_manager: DeviceModelManager = DeviceModelManager("cuda"), + device: Optional[Union[DeviceModelManager, str, torch.device]] = None, ): if force_uc_zero_embeddings is None: force_uc_zero_embeddings = [] if batch2model_input is None: batch2model_input = [] + device_manager = get_model_manager(device=device) + with torch.no_grad(): with device_manager.autocast(): with model.ema_scope(): @@ -367,8 +374,9 @@ def do_img2img( skip_encode=False, filter=None, add_noise=True, - device_manager=DeviceModelManager("cuda"), + device: Optional[Union[DeviceModelManager, str, torch.device]] = None, ): + device_manager = get_model_manager(device) with torch.no_grad(): with device_manager.autocast(): with model.ema_scope(): From c0655731d5e637169f3019b349533528c02b6992 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Sat, 12 Aug 2023 04:25:56 -0700 Subject: [PATCH 50/56] fix streamlit inputs --- 
 scripts/demo/sampling.py          | 10 ++++++----
 scripts/demo/streamlit_helpers.py |  4 ++--
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py
index a18b21aa..54155519 100644
--- a/scripts/demo/sampling.py
+++ b/scripts/demo/sampling.py
@@ -88,14 +88,14 @@ def run_txt2img(
     model: SamplingPipeline = state["model"]
     params: SamplingParams = state["params"]
     if version.startswith("stable-diffusion-xl") and version.endswith("-base"):
-        params.width, params.height = st.selectbox(
+        width, height = st.selectbox(
             "Resolution:", list(SD_XL_BASE_RATIOS.values()), 10
         )
     else:
-        params.height = int(
+        height = int(
             st.number_input("H", value=spec.height, min_value=64, max_value=2048)
         )
-        params.width = int(
+        width = int(
             st.number_input("W", value=spec.width, min_value=64, max_value=2048)
         )
 
@@ -107,6 +107,8 @@ def run_txt2img(
     )
     params, num_rows, num_cols = init_sampling(params=params)
     num_samples = num_rows * num_cols
+    params.height = height
+    params.width = width
 
     if st.button("Sample"):
         st.write(f"**Model I:** {version}")
@@ -289,8 +291,8 @@ def apply_refiner(
     )
 
     params2, *_ = init_sampling(
-        key=2,
         params=state2["params"],
+        key=2,
         specify_num_samples=False,
     )
     st.write("__________________________")

diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py
index 93293dd7..5c350694 100644
--- a/scripts/demo/streamlit_helpers.py
+++ b/scripts/demo/streamlit_helpers.py
@@ -130,7 +130,7 @@ def show_samples(samples, outputs):
         outputs.image(grid.cpu().numpy())
 
 
-def get_guider(key, params: SamplingParams) -> SamplingParams:
+def get_guider(params: SamplingParams, key=1) -> SamplingParams:
     params.guider = Guider(
         st.sidebar.selectbox(
             f"Discretization #{key}", [member.value for member in Guider]
         )
     )
@@ -157,8 +157,8 @@ def init_sampling(
+    params: SamplingParams,
     key=1,
-    params: SamplingParams = SamplingParams(),
     specify_num_samples=True,
 ) -> Tuple[SamplingParams, int, int]:
     params = SamplingParams(img2img_strength=params.img2img_strength)
 
     num_rows, num_cols = 1, 1
     if specify_num_samples:

From fbe93fc53b3407acab4cf3394b8c0645ced98a7c Mon Sep 17 00:00:00 2001
From: Stephan Auerhahn
Date: Sat, 12 Aug 2023 05:33:16 -0700
Subject: [PATCH 51/56] PR fixes, model specific defaults

---
 scripts/demo/sampling.py          |  35 +++------
 scripts/demo/streamlit_helpers.py |  22 ++----
 sgm/inference/api.py              | 115 ++++++++++++++++++------------
 tests/inference/test_inference.py |   2 +-
 4 files changed, 88 insertions(+), 86 deletions(-)

diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py
index 54155519..20a8f03e 100644
--- a/scripts/demo/sampling.py
+++ b/scripts/demo/sampling.py
@@ -66,9 +66,7 @@ def load_img(display=True, key=None, device="cuda"):
         st.image(image)
     w, h = image.size
     print(f"loaded input image of size ({w}, {h})")
-    width, height = map(
-        lambda x: x - x % 64, (w, h)
-    )  # resize to integer multiple of 64
+    width, height = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 64
     image = image.resize((width, height))
     image = np.array(image.convert("RGB"))
     image = image[None].transpose(0, 3, 1, 2)
@@ -78,26 +76,19 @@ def load_img(display=True, key=None, device="cuda"):
 
 def run_txt2img(
     state,
-    version: str,
+    model_id: ModelArchitecture,
     prompt: str,
     negative_prompt: str,
     return_latents=False,
     stage2strength=None,
 ):
-    spec: SamplingSpec = state["spec"]
     model: SamplingPipeline = state["model"]
     params: SamplingParams = state["params"]
-    if version.startswith("stable-diffusion-xl") and version.endswith("-base"):
-        width, height = st.selectbox(
-            "Resolution:", list(SD_XL_BASE_RATIOS.values()), 10
-        )
+    if model_id in sdxl_base_model_list:
+        width, height = st.selectbox("Resolution:", list(SD_XL_BASE_RATIOS.values()), 10)
     else:
-        height = int(
-            st.number_input("H", value=spec.height, min_value=64, max_value=2048)
-        )
-        width = int(
-            st.number_input("W", value=spec.width, min_value=64, max_value=2048)
-        )
+        height = int(st.number_input("H", value=params.height, min_value=64, max_value=2048))
+        width = int(st.number_input("W", value=params.width, min_value=64, max_value=2048))
 
     params = init_embedder_options(
         get_unique_embedder_keys_from_conditioner(model.model.conditioner),
@@ -207,12 +198,12 @@ def apply_refiner(
 
 
 sdxl_base_model_list = [
-    ModelArchitecture.SDXL_V1_BASE,
+    ModelArchitecture.SDXL_V1_0_BASE,
     ModelArchitecture.SDXL_V0_9_BASE,
 ]
 
 sdxl_refiner_model_list = [
-    ModelArchitecture.SDXL_V1_REFINER,
+    ModelArchitecture.SDXL_V1_0_REFINER,
     ModelArchitecture.SDXL_V0_9_REFINER,
 ]
@@ -239,9 +230,7 @@ def apply_refiner(
     else:
         add_pipeline = False
 
-    seed = int(
-        st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
-    )
+    seed = int(st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9)))
     seed_everything(seed)
 
     save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, str(version)))
@@ -281,9 +270,7 @@ def apply_refiner(
     st.write("**Refiner Options:**")
 
     specs2 = model_specs[version2]
-    state2 = init_st(
-        specs2, load_filter=False, use_fp16=use_fp16, enable_swap=enable_swap
-    )
+    state2 = init_st(specs2, load_filter=False, use_fp16=use_fp16, enable_swap=enable_swap)
 
     params2 = state2["params"]
     params2.img2img_strength = st.number_input(
@@ -309,7 +296,7 @@ def apply_refiner(
     if mode == "txt2img":
         out = run_txt2img(
             state=state,
-            version=str(version),
+            model_id=version_enum,
             prompt=prompt,
             negative_prompt=negative_prompt,
             return_latents=add_pipeline,

diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py
index 5c350694..a1770a86 100644
--- a/scripts/demo/streamlit_helpers.py
+++ b/scripts/demo/streamlit_helpers.py
@@ -48,7 +48,7 @@ def init_st(
     state["model"] = pipeline
     state["ckpt"] = ckpt if load_ckpt else None
     state["config"] = config
-    state["params"] = SamplingParams()
+    state["params"] = spec.default_params
     if load_filter:
         state["filter"] = DeepFloydDataFiltering(verbose=False)
     else:
@@ -132,9 +132,7 @@ def show_samples(samples, outputs):
 
 def get_guider(params: SamplingParams, key=1) -> SamplingParams:
     params.guider = Guider(
-        st.sidebar.selectbox(
-            f"Discretization #{key}", [member.value for member in Guider]
-        )
+        st.sidebar.selectbox(f"Discretization #{key}", [member.value for member in Guider])
     )
 
     if params.guider == Guider.VANILLA:
@@ -165,14 +163,10 @@ def init_sampling(
     num_rows, num_cols = 1, 1
     if specify_num_samples:
-        num_cols = st.number_input(
-            f"num cols #{key}", value=2, min_value=1, max_value=10
-        )
+        num_cols = st.number_input(f"num cols #{key}", value=2, min_value=1, max_value=10)
 
     params.steps = int(
-        st.sidebar.number_input(
-            f"steps #{key}", value=params.steps, min_value=1, max_value=1000
-        )
+        st.sidebar.number_input(f"steps #{key}", value=params.steps, min_value=1, max_value=1000)
     )
 
     params.sampler = Sampler(
@@ -220,15 +214,11 @@ def get_sampler(params: SamplingParams, key=1) -> SamplingParams:
         )
 
     elif params.sampler in (Sampler.EULER_ANCESTRAL, Sampler.DPMPP2S_ANCESTRAL):
-        params.s_noise = st.sidebar.number_input(
-            "s_noise", value=params.s_noise, min_value=0.0
-        )
+        params.s_noise = st.sidebar.number_input("s_noise", value=params.s_noise, min_value=0.0)
         params.eta = st.sidebar.number_input("eta", value=params.eta, min_value=0.0)
 
     elif params.sampler == Sampler.LINEAR_MULTISTEP:
-        params.order = int(
-            st.sidebar.number_input("order", value=params.order, min_value=1)
-        )
+        params.order = int(st.sidebar.number_input("order", value=params.order, min_value=1))
 
     return params

diff --git a/sgm/inference/api.py b/sgm/inference/api.py
index 9ca1111b..d2d5a7d2 100644
--- a/sgm/inference/api.py
+++ b/sgm/inference/api.py
@@ -24,8 +24,8 @@
 
 
 class ModelArchitecture(str, Enum):
-    SDXL_V1_BASE = "stable-diffusion-xl-v1-base"
-    SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner"
+    SDXL_V1_0_BASE = "stable-diffusion-xl-v1-base"
+    SDXL_V1_0_REFINER = "stable-diffusion-xl-v1-refiner"
     SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base"
     SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner"
     SD_2_1 = "stable-diffusion-v2-1"
@@ -59,24 +59,21 @@ class Thresholder(str, Enum):
 class SamplingParams:
     """
     Parameters for sampling.
-    The defaults here are derived from user preference testing.
-    They will be subject to change in the future, likely pulled
-    from model specs instead of global defaults.
     """
 
-    width: int = 1024
-    height: int = 1024
-    steps: int = 40
+    width: int
+    height: int
+    steps: int
     sampler: Sampler = Sampler.EULER_EDM
     discretization: Discretization = Discretization.LEGACY_DDPM
     guider: Guider = Guider.VANILLA
     thresholder: Thresholder = Thresholder.NONE
-    scale: float = 5.0
+    scale: float
     aesthetic_score: float = 6.0
     negative_aesthetic_score: float = 2.5
     img2img_strength: float = 1.0
-    orig_width: int = width
-    orig_height: int = height
+    orig_width: int = 1024
+    orig_height: int = 1024
     crop_coords_top: int = 0
     crop_coords_left: int = 0
     sigma_min: float = 0.0292
@@ -100,8 +97,10 @@ class SamplingSpec:
     config: str
     ckpt: str
     is_guided: bool
+    default_params: SamplingParams
 
 
+# The defaults here are derived from user preference testing.
 model_specs = {
     ModelArchitecture.SD_2_1: SamplingSpec(
         height=512,
         width=512,
@@ -112,6 +111,12 @@ class SamplingSpec:
         config="sd_2_1.yaml",
         ckpt="v2-1_512-ema-pruned.safetensors",
         is_guided=True,
+        default_params=SamplingParams(
+            width=512,
+            height=512,
+            steps=40,
+            scale=7.0,
+        ),
     ),
     ModelArchitecture.SD_2_1_768: SamplingSpec(
         height=768,
@@ -122,6 +127,12 @@ class SamplingSpec:
         config="sd_2_1_768.yaml",
         ckpt="v2-1_768-ema-pruned.safetensors",
         is_guided=True,
+        default_params=SamplingParams(
+            width=768,
+            height=768,
+            steps=40,
+            scale=7.0,
+        ),
     ),
     ModelArchitecture.SDXL_V0_9_BASE: SamplingSpec(
         height=1024,
@@ -132,6 +143,7 @@ class SamplingSpec:
         config="sd_xl_base.yaml",
         ckpt="sd_xl_base_0.9.safetensors",
         is_guided=True,
+        default_params=SamplingParams(width=1024, height=1024, steps=40, scale=5.0),
     ),
     ModelArchitecture.SDXL_V0_9_REFINER: SamplingSpec(
         height=1024,
@@ -142,8 +154,11 @@ class SamplingSpec:
         config="sd_xl_refiner.yaml",
         ckpt="sd_xl_refiner_0.9.safetensors",
         is_guided=True,
+        default_params=SamplingParams(
+            width=1024, height=1024, steps=40, scale=5.0, img2img_strength=0.15
+        ),
     ),
-    ModelArchitecture.SDXL_V1_BASE: SamplingSpec(
+    ModelArchitecture.SDXL_V1_0_BASE: SamplingSpec(
         height=1024,
         width=1024,
         channels=4,
@@ -152,8 +167,9 @@ class SamplingSpec:
         config="sd_xl_base.yaml",
         ckpt="sd_xl_base_1.0.safetensors",
         is_guided=True,
+        default_params=SamplingParams(width=1024, height=1024, steps=40, scale=5.0),
     ),
-    ModelArchitecture.SDXL_V1_REFINER: SamplingSpec(
+    ModelArchitecture.SDXL_V1_0_REFINER: SamplingSpec(
         height=1024,
         width=1024,
         channels=4,
@@ -162,10 +178,39 @@ class SamplingSpec:
         config="sd_xl_refiner.yaml",
         ckpt="sd_xl_refiner_1.0.safetensors",
         is_guided=True,
+        default_params=SamplingParams(
+            width=1024, height=1024, steps=40, scale=5.0, img2img_strength=0.15
+        ),
     ),
 }
 
 
+def wrap_discretization(
+    discretization, image_strength=None, noise_strength=None, steps=None
+):
+    if isinstance(discretization, Img2ImgDiscretizationWrapper) or isinstance(
+        discretization, Txt2NoisyDiscretizationWrapper
+    ):
+        return discretization  # Already wrapped
+    if image_strength is not None and image_strength < 1.0 and image_strength > 0.0:
+        discretization = Img2ImgDiscretizationWrapper(
+            discretization, strength=image_strength
+        )
+
+    if (
+        noise_strength is not None
+        and noise_strength < 1.0
+        and noise_strength > 0.0
+        and steps is not None
+    ):
+        discretization = Txt2NoisyDiscretizationWrapper(
+            discretization,
+            strength=noise_strength,
+            original_steps=steps,
+        )
+    return discretization
+
+
 class SamplingPipeline:
     def __init__(
         self,
@@ -231,17 +276,19 @@ def _load_model(self, device_manager: DeviceModelManager, use_fp16=True):
 
     def text_to_image(
         self,
-        params: SamplingParams,
         prompt: str,
+        params: Optional[SamplingParams] = None,
         negative_prompt: str = "",
         samples: int = 1,
         return_latents: bool = False,
         noise_strength: Optional[float] = None,
         filter=None,
     ):
+        if params is None:
+            params = self.specs.default_params
         sampler = get_sampler_config(params)
 
-        sampler.discretization = self.wrap_discretization(
+        sampler.discretization = wrap_discretization(
             sampler.discretization,
             image_strength=None,
             noise_strength=noise_strength,
@@ -270,18 +317,20 @@ def text_to_image(
 
     def image_to_image(
         self,
-        params: SamplingParams,
         image: torch.Tensor,
         prompt: str,
+        params: Optional[SamplingParams] = None,
         negative_prompt: str = "",
         samples: int = 1,
         return_latents: bool = False,
         noise_strength: Optional[float] = None,
         filter=None,
     ):
+        if params is None:
+            params = self.specs.default_params
         sampler = get_sampler_config(params)
 
-        sampler.discretization = self.wrap_discretization(
+        sampler.discretization = wrap_discretization(
             sampler.discretization,
             image_strength=params.img2img_strength,
             noise_strength=noise_strength,
@@ -308,44 +357,20 @@ def image_to_image(
             device=self.device_manager,
         )
 
-    def wrap_discretization(
-        self, discretization, image_strength=None, noise_strength=None, steps=None
-    ):
-        if isinstance(discretization, Img2ImgDiscretizationWrapper) or isinstance(
-            discretization, Txt2NoisyDiscretizationWrapper
-        ):
-            return discretization  # Already wrapped
-        if image_strength is not None and image_strength < 1.0 and image_strength > 0.0:
-            discretization = Img2ImgDiscretizationWrapper(
-                discretization, strength=image_strength
-            )
-
-        if (
-            noise_strength is not None
-            and noise_strength < 1.0
-            and noise_strength > 0.0
-            and steps is not None
-        ):
-            discretization = Txt2NoisyDiscretizationWrapper(
-                discretization,
-                strength=noise_strength,
-                original_steps=steps,
-            )
-        return discretization
-
     def refiner(
         self,
         image: torch.Tensor,
         prompt: str,
         negative_prompt: str = "",
-        params: SamplingParams = SamplingParams(
-            sampler=Sampler.EULER_EDM, steps=40, img2img_strength=0.15
-        ),
+        params: Optional[SamplingParams] = None,
         samples: int = 1,
         return_latents: bool = False,
         filter: Any = None,
         add_noise: bool = False,
     ):
+        if params is None:
+            params = self.specs.default_params
+
         sampler = get_sampler_config(params)
         value_dict = {
             "orig_width": image.shape[3] * 8,

diff --git a/tests/inference/test_inference.py b/tests/inference/test_inference.py
index 617e4088..04eceb7a 100644
--- a/tests/inference/test_inference.py
+++ b/tests/inference/test_inference.py
@@ -27,7 +27,7 @@ def pipeline(self, request) -> SamplingPipeline:
     @fixture(
         scope="class",
         params=[
-            [ModelArchitecture.SDXL_V1_BASE, ModelArchitecture.SDXL_V1_REFINER],
+            [ModelArchitecture.SDXL_V1_0_BASE, ModelArchitecture.SDXL_V1_0_REFINER],
             [ModelArchitecture.SDXL_V0_9_BASE, ModelArchitecture.SDXL_V0_9_REFINER],
         ],
         ids=["SDXL_V1", "SDXL_V0_9"],

From 5fde7e73b80ff781bb753d1fa2bf6c0b7d1c8c45 Mon Sep 17 00:00:00 2001
From: Stephan Auerhahn
Date: Sat, 12 Aug 2023 05:35:36 -0700
Subject: [PATCH 52/56] set a default scale

---
 sgm/inference/api.py | 26 +++++++-------------------
 1 file changed, 7 insertions(+), 19 deletions(-)

diff --git a/sgm/inference/api.py b/sgm/inference/api.py
index d2d5a7d2..51269b69 100644
--- a/sgm/inference/api.py
+++ b/sgm/inference/api.py
@@ -68,7 +68,7 @@ class SamplingParams:
     discretization: Discretization = Discretization.LEGACY_DDPM
     guider: Guider = Guider.VANILLA
     thresholder: Thresholder = Thresholder.NONE
-    scale: float
+    scale: float = 5.0
     aesthetic_score: float = 6.0
     negative_aesthetic_score: float = 2.5
     img2img_strength: float = 1.0
@@ -185,17 +185,13 @@ class SamplingSpec:
 
 
-def wrap_discretization(
-    discretization, image_strength=None, noise_strength=None, steps=None
-):
+def wrap_discretization(discretization, image_strength=None, noise_strength=None, steps=None):
     if isinstance(discretization, Img2ImgDiscretizationWrapper) or isinstance(
         discretization, Txt2NoisyDiscretizationWrapper
     ):
         return discretization  # Already wrapped
     if image_strength is not None and image_strength < 1.0 and image_strength > 0.0:
-        discretization = Img2ImgDiscretizationWrapper(
-            discretization, strength=image_strength
-        )
+        discretization = Img2ImgDiscretizationWrapper(discretization, strength=image_strength)
 
     if (
         noise_strength is not None
@@ -249,19 +245,13 @@ def __init__(
         self.config = os.path.join(config_path, "inference", self.specs.config)
         self.ckpt = os.path.join(model_path, self.specs.ckpt)
         if not os.path.exists(self.config):
-            raise ValueError(
-                f"Config {self.config} not found, check model spec or config_path"
-            )
+            raise ValueError(f"Config {self.config} not found, check model spec or config_path")
         if not os.path.exists(self.ckpt):
-            raise ValueError(
-                f"Checkpoint {self.ckpt} not found, check model spec or config_path"
-            )
+            raise ValueError(f"Checkpoint {self.ckpt} not found, check model spec or config_path")
 
         self.device_manager = get_model_manager(device)
 
-        self.model = self._load_model(
-            device_manager=self.device_manager, use_fp16=use_fp16
-        )
+        self.model = self._load_model(device_manager=self.device_manager, use_fp16=use_fp16)
 
     def _load_model(self, device_manager: DeviceModelManager, use_fp16=True):
         config = OmegaConf.load(self.config)
@@ -406,9 +396,7 @@ def refiner(
 def get_guider_config(params: SamplingParams) -> Dict[str, Any]:
     guider_config: Dict[str, Any]
     if params.guider == Guider.IDENTITY:
-        guider_config = {
-            "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
-        }
+        guider_config = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}
 
     elif params.guider == Guider.VANILLA:
         scale = params.scale

From 65c6ec1cecd3ce6fb47ce714ec071fc4af3c94f2 Mon Sep 17 00:00:00 2001
From: Stephan Auerhahn
Date: Sat, 12 Aug 2023 05:40:25 -0700
Subject: [PATCH 53/56] run black

---
 scripts/demo/sampling.py          | 24 ++++++++++++++++++------
 scripts/demo/streamlit_helpers.py | 20 +++++++++++++++-----
 sgm/inference/api.py              | 24 ++++++++++++++++++------
 3 files changed, 51 insertions(+), 17 deletions(-)

diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py
index 20a8f03e..017db211 100644
--- a/scripts/demo/sampling.py
+++ b/scripts/demo/sampling.py
@@ -66,7 +66,9 @@ def load_img(display=True, key=None, device="cuda"):
         st.image(image)
     w, h = image.size
     print(f"loaded input image of size ({w}, {h})")
-    width, height = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 64
+    width, height = map(
+        lambda x: x - x % 64, (w, h)
+    )  # resize to integer multiple of 64
     image = image.resize((width, height))
     image = np.array(image.convert("RGB"))
     image = image[None].transpose(0, 3, 1, 2)
@@ -87,10 +89,16 @@ def run_txt2img(
     model: SamplingPipeline = state["model"]
     params: SamplingParams = state["params"]
     if model_id in sdxl_base_model_list:
-        width, height = st.selectbox("Resolution:", list(SD_XL_BASE_RATIOS.values()), 10)
+        width, height = st.selectbox(
+            "Resolution:", list(SD_XL_BASE_RATIOS.values()), 10
+        )
     else:
-        height = int(st.number_input("H", value=params.height, min_value=64, max_value=2048))
-        width = int(st.number_input("W", value=params.width, min_value=64, max_value=2048))
+        height = int(
+            st.number_input("H", value=params.height, min_value=64, max_value=2048)
+        )
+        width = int(
+            st.number_input("W", value=params.width, min_value=64, max_value=2048)
+        )
 
     params = init_embedder_options(
@@ -230,7 +238,9 @@ def apply_refiner(
     else:
         add_pipeline = False
 
-    seed = int(st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9)))
+    seed = int(
+        st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
+    )
     seed_everything(seed)
 
     save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, str(version)))
@@ -270,7 +280,9 @@ def apply_refiner(
     st.write("**Refiner Options:**")
 
     specs2 = model_specs[version2]
-    state2 = init_st(specs2, load_filter=False, use_fp16=use_fp16, enable_swap=enable_swap)
+    state2 = init_st(
+        specs2, load_filter=False, use_fp16=use_fp16, enable_swap=enable_swap
+    )
 
     params2 = state2["params"]
     params2.img2img_strength = st.number_input(

diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py
index a1770a86..eeb0f203 100644
--- a/scripts/demo/streamlit_helpers.py
+++ b/scripts/demo/streamlit_helpers.py
@@ -132,7 +132,9 @@ def show_samples(samples, outputs):
 
 def get_guider(params: SamplingParams, key=1) -> SamplingParams:
     params.guider = Guider(
-        st.sidebar.selectbox(f"Discretization #{key}", [member.value for member in Guider])
+        st.sidebar.selectbox(
+            f"Discretization #{key}", [member.value for member in Guider]
+        )
     )
 
     if params.guider == Guider.VANILLA:
@@ -163,10 +165,14 @@ def init_sampling(
     num_rows, num_cols = 1, 1
     if specify_num_samples:
-        num_cols = st.number_input(f"num cols #{key}", value=2, min_value=1, max_value=10)
+        num_cols = st.number_input(
+            f"num cols #{key}", value=2, min_value=1, max_value=10
+        )
 
     params.steps = int(
-        st.sidebar.number_input(f"steps #{key}", value=params.steps, min_value=1, max_value=1000)
+        st.sidebar.number_input(
+            f"steps #{key}", value=params.steps, min_value=1, max_value=1000
+        )
     )
 
     params.sampler = Sampler(
@@ -214,11 +220,15 @@ def get_sampler(params: SamplingParams, key=1) -> SamplingParams:
         )
 
     elif params.sampler in (Sampler.EULER_ANCESTRAL, Sampler.DPMPP2S_ANCESTRAL):
-        params.s_noise = st.sidebar.number_input("s_noise", value=params.s_noise, min_value=0.0)
+        params.s_noise = st.sidebar.number_input(
+            "s_noise", value=params.s_noise, min_value=0.0
+        )
         params.eta = st.sidebar.number_input("eta", value=params.eta, min_value=0.0)
 
     elif params.sampler == Sampler.LINEAR_MULTISTEP:
-        params.order = int(st.sidebar.number_input("order", value=params.order, min_value=1))
+        params.order = int(
+            st.sidebar.number_input("order", value=params.order, min_value=1)
+        )
 
     return params

diff --git a/sgm/inference/api.py b/sgm/inference/api.py
index 51269b69..d863f5ee 100644
--- a/sgm/inference/api.py
+++ b/sgm/inference/api.py
@@ -185,13 +185,17 @@ class SamplingSpec:
 
 
-def wrap_discretization(discretization, image_strength=None, noise_strength=None, steps=None):
+def wrap_discretization(
+    discretization, image_strength=None, noise_strength=None, steps=None
+):
     if isinstance(discretization, Img2ImgDiscretizationWrapper) or isinstance(
         discretization, Txt2NoisyDiscretizationWrapper
     ):
         return discretization  # Already wrapped
     if image_strength is not None and image_strength < 1.0 and image_strength > 0.0:
-        discretization = Img2ImgDiscretizationWrapper(discretization, strength=image_strength)
+        discretization = Img2ImgDiscretizationWrapper(
+            discretization, strength=image_strength
+        )
 
     if (
         noise_strength is not None
@@ -245,13 +249,19 @@ def __init__(
         self.config = os.path.join(config_path, "inference", self.specs.config)
         self.ckpt = os.path.join(model_path, self.specs.ckpt)
         if not os.path.exists(self.config):
-            raise ValueError(f"Config {self.config} not found, check model spec or config_path")
+            raise ValueError(
+                f"Config {self.config} not found, check model spec or config_path"
+            )
         if not os.path.exists(self.ckpt):
-            raise ValueError(f"Checkpoint {self.ckpt} not found, check model spec or config_path")
+            raise ValueError(
+                f"Checkpoint {self.ckpt} not found, check model spec or config_path"
+            )
 
         self.device_manager = get_model_manager(device)
 
-        self.model = self._load_model(device_manager=self.device_manager, use_fp16=use_fp16)
+        self.model = self._load_model(
+            device_manager=self.device_manager, use_fp16=use_fp16
+        )
 
     def _load_model(self, device_manager: DeviceModelManager, use_fp16=True):
         config = OmegaConf.load(self.config)
@@ -396,7 +406,9 @@ def refiner(
 def get_guider_config(params: SamplingParams) -> Dict[str, Any]:
     guider_config: Dict[str, Any]
     if params.guider == Guider.IDENTITY:
-        guider_config = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}
+        guider_config = {
+            "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
+        }
 
     elif params.guider == Guider.VANILLA:
         scale = params.scale

From e32972b85bc6648ec44b3c7cb6a538a9d8d9838e Mon Sep 17 00:00:00 2001
From: Stephan Auerhahn
Date: Sat, 12 Aug 2023 05:42:22 -0700
Subject: [PATCH 54/56] remove extra init

---
 scripts/demo/streamlit_helpers.py | 22 +++++-----------------
 1 file changed, 5 insertions(+), 17 deletions(-)

diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py
index eeb0f203..a5f1b03a 100644
--- a/scripts/demo/streamlit_helpers.py
+++ b/scripts/demo/streamlit_helpers.py
@@ -132,9 +132,7 @@ def show_samples(samples, outputs):
 
 def get_guider(params: SamplingParams, key=1) -> SamplingParams:
     params.guider = Guider(
-        st.sidebar.selectbox(
-            f"Discretization #{key}", [member.value for member in Guider]
-        )
+        st.sidebar.selectbox(f"Discretization #{key}", [member.value for member in Guider])
     )
 
     if params.guider == Guider.VANILLA:
@@ -161,18 +159,12 @@ def init_sampling(
     key=1,
     specify_num_samples=True,
 ) -> Tuple[SamplingParams, int, int]:
-    params = SamplingParams(img2img_strength=params.img2img_strength)
-
     num_rows, num_cols = 1, 1
     if specify_num_samples:
-        num_cols = st.number_input(
-            f"num cols #{key}", value=2, min_value=1, max_value=10
-        )
+        num_cols = st.number_input(f"num cols #{key}", value=2, min_value=1, max_value=10)
 
     params.steps = int(
-        st.sidebar.number_input(
-            f"steps #{key}", value=params.steps, min_value=1, max_value=1000
-        )
+        st.sidebar.number_input(f"steps #{key}", value=params.steps, min_value=1, max_value=1000)
     )
 
     params.sampler = Sampler(
@@ -220,15 +212,11 @@ def get_sampler(params: SamplingParams, key=1) -> SamplingParams:
         )
 
     elif params.sampler in (Sampler.EULER_ANCESTRAL, Sampler.DPMPP2S_ANCESTRAL):
-        params.s_noise = st.sidebar.number_input(
-            "s_noise", value=params.s_noise, min_value=0.0
-        )
+        params.s_noise = st.sidebar.number_input("s_noise", value=params.s_noise, min_value=0.0)
         params.eta = st.sidebar.number_input("eta", value=params.eta, min_value=0.0)
 
     elif params.sampler == Sampler.LINEAR_MULTISTEP:
-        params.order = int(
-            st.sidebar.number_input("order", value=params.order, min_value=1)
-        )
+        params.order = int(st.sidebar.number_input("order", value=params.order, min_value=1))
 
     return params

From 2fc4680bf9fc82ae2a94357ed62574916de1dad2 Mon Sep 17 00:00:00 2001
From: Stephan Auerhahn
Date: Sat, 12 Aug 2023 13:22:04 -0700
Subject: [PATCH 55/56] Easier default params

---
 scripts/demo/streamlit_helpers.py | 20 +++++++++++++++-----
 sgm/inference/api.py              | 15 ++++++++++++---
 2 files changed, 27 insertions(+), 8 deletions(-)

diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py
index a5f1b03a..a0f3848a 100644
--- a/scripts/demo/streamlit_helpers.py
+++ b/scripts/demo/streamlit_helpers.py
@@ -132,7 +132,9 @@ def show_samples(samples, outputs):
 
 def get_guider(params: SamplingParams, key=1) -> SamplingParams:
     params.guider = Guider(
-        st.sidebar.selectbox(f"Discretization #{key}", [member.value for member in Guider])
+        st.sidebar.selectbox(
+            f"Discretization #{key}", [member.value for member in Guider]
+        )
     )
 
     if params.guider == Guider.VANILLA:
@@ -161,10 +163,14 @@ def init_sampling(
 ) -> Tuple[SamplingParams, int, int]:
     num_rows, num_cols = 1, 1
     if specify_num_samples:
-        num_cols = st.number_input(f"num cols #{key}", value=2, min_value=1, max_value=10)
+        num_cols = st.number_input(
+            f"num cols #{key}", value=2, min_value=1, max_value=10
+        )
 
     params.steps = int(
-        st.sidebar.number_input(f"steps #{key}", value=params.steps, min_value=1, max_value=1000)
+        st.sidebar.number_input(
+            f"steps #{key}", value=params.steps, min_value=1, max_value=1000
+        )
    )
 
     params.sampler = Sampler(
@@ -212,11 +218,15 @@ def get_sampler(params: SamplingParams, key=1) -> SamplingParams:
         )
 
     elif params.sampler in (Sampler.EULER_ANCESTRAL, Sampler.DPMPP2S_ANCESTRAL):
-        params.s_noise = st.sidebar.number_input("s_noise", value=params.s_noise, min_value=0.0)
+        params.s_noise = st.sidebar.number_input(
+            "s_noise", value=params.s_noise, min_value=0.0
+        )
         params.eta = st.sidebar.number_input("eta", value=params.eta, min_value=0.0)
 
     elif params.sampler == Sampler.LINEAR_MULTISTEP:
-        params.order = int(st.sidebar.number_input("order", value=params.order, min_value=1))
+        params.order = int(
+            st.sidebar.number_input("order", value=params.order, min_value=1)
+        )
 
     return params

diff --git a/sgm/inference/api.py b/sgm/inference/api.py
index d863f5ee..87592dc8 100644
--- a/sgm/inference/api.py
+++ b/sgm/inference/api.py
@@ -61,9 +61,9 @@ class SamplingParams:
     Parameters for sampling.
     """
 
-    width: int
-    height: int
-    steps: int
+    width: Optional[int] = None
+    height: Optional[int] = None
+    steps: Optional[int] = None
     sampler: Sampler = Sampler.EULER_EDM
     discretization: Discretization = Discretization.LEGACY_DDPM
     guider: Guider = Guider.VANILLA
@@ -286,6 +286,15 @@ def text_to_image(
     ):
         if params is None:
             params = self.specs.default_params
+        else:
+            # Set defaults if optional params are not specified
+            if params.width is None:
+                params.width = self.specs.default_params.width
+            if params.height is None:
+                params.height = self.specs.default_params.height
+            if params.steps is None:
+                params.steps = self.specs.default_params.steps
+
         sampler = get_sampler_config(params)
 
         sampler.discretization = wrap_discretization(

From e28962199273d72d691a3764cb0377969123dd42 Mon Sep 17 00:00:00 2001
From: Stephan Auerhahn
Date: Sat, 12 Aug 2023 13:52:46 -0700
Subject: [PATCH 56/56] fix reference

---
 sgm/inference/api.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sgm/inference/api.py b/sgm/inference/api.py
index 87592dc8..f4f2faa7 100644
--- a/sgm/inference/api.py
+++ b/sgm/inference/api.py
@@ -394,7 +394,7 @@ def refiner(
             "negative_aesthetic_score": 2.5,
         }
 
-        sampler.discretization = self.wrap_discretization(
+        sampler.discretization = wrap_discretization(
             sampler.discretization, image_strength=params.img2img_strength
         )