From b51c36b0dffefaa4a552e75427868dcb19ad508e Mon Sep 17 00:00:00 2001
From: Stephan Auerhahn
Date: Wed, 9 Aug 2023 19:31:59 -0700
Subject: [PATCH] extract path resolution method, fix/improve device swapping support

- SamplingPipeline.__init__ gains a swap_device argument. When it is set,
  the model is loaded onto swap_device instead of device; device is still
  stored on the pipeline as self.device.
- The default checkpoint/config directory lookup is factored out of __init__
  into _resolve_default_path, which tries the module directory first, then
  the repository root, and raises ValueError when neither exists.
- text_to_image, image_to_image and refiner now forward device=self.device
  to the sampling call they make.
- The Streamlit demo always requests fp16 on CUDA and passes
  swap_device="cpu" in low-VRAM mode.

---
Review notes (everything between the "---" line and the first "diff --git"
line is ignored by `git am` / `git apply`):

- The added annotations used `torch.Device`, which does not exist; the
  annotation is evaluated when __init__ is defined, so importing the module
  raised AttributeError. Changed to `torch.device`.
- _resolve_default_path appended `suffix` twice (repo_path/module_path
  already end in `suffix`, then `/ suffix` was applied again), so the
  default lookup always failed. It now tests module_path and repo_path
  directly.
- Both fixes edit added lines in place, so hunk line counts and the diffstat
  below are unchanged. The post-image blob id on the api.py "index" line no
  longer matches the patched content; it is informational only.
- Indentation and blank context lines were reconstructed from a
  whitespace-collapsed copy of this patch so that every hunk matches its
  "@@" line counts. If context fails to match, apply with
  `git apply --ignore-whitespace`.

 scripts/demo/streamlit_helpers.py | 11 +++---
 sgm/inference/api.py              | 61 ++++++++++++++++++++-----------
 2 files changed, 45 insertions(+), 27 deletions(-)

diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py
index c25284f5..2d6972e8 100644
--- a/scripts/demo/streamlit_helpers.py
+++ b/scripts/demo/streamlit_helpers.py
@@ -33,11 +33,12 @@ def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True) -> Dict[str, A
         config = spec.config
         ckpt = spec.ckpt
 
-        pipeline = SamplingPipeline(
-            model_spec=spec,
-            use_fp16=lowvram_mode,
-            device="cpu" if lowvram_mode else "cuda",
-        )
+        if lowvram_mode:
+            pipeline = SamplingPipeline(
+                model_spec=spec, use_fp16=True, device="cuda", swap_device="cpu"
+            )
+        else:
+            pipeline = SamplingPipeline(model_spec=spec, use_fp16=True, device="cuda")
 
         state["spec"] = spec
         state["model"] = pipeline
diff --git a/sgm/inference/api.py b/sgm/inference/api.py
index e3f3d17d..ecdf7066 100644
--- a/sgm/inference/api.py
+++ b/sgm/inference/api.py
@@ -19,7 +19,7 @@
 )
 from sgm.util import load_model_from_config
 import torch
-from typing import Optional, Dict, Any
+from typing import Optional, Dict, Any, Union
 
 
 class ModelArchitecture(str, Enum):
@@ -163,11 +163,24 @@ def __init__(
         self,
         model_id: Optional[ModelArchitecture] = None,
         model_spec: Optional[SamplingSpec] = None,
-        model_path=None,
-        config_path=None,
-        device="cuda",
-        use_fp16=True,
+        model_path: Optional[Union[str, pathlib.Path]] = None,
+        config_path: Optional[Union[str, pathlib.Path]] = None,
+        device: Union[str, torch.device] = "cuda",
+        swap_device: Optional[Union[str, torch.device]] = None,
+        use_fp16: bool = True,
     ) -> None:
+        """
+        Sampling pipeline for generating images from a model.
+
+        @param model_id: Model architecture to use. If not specified, model_spec must be specified.
+        @param model_spec: Model specification to use. If not specified, model_id must be specified.
+        @param model_path: Path to model checkpoints folder.
+        @param config_path: Path to model config folder.
+        @param device: Device to use for sampling.
+        @param swap_device: Device to swap models to when not in use.
+        @param use_fp16: Whether to use fp16 for sampling.
+        """
+
         self.model_id = model_id
         if model_spec is not None:
             self.specs = model_spec
@@ -179,23 +192,9 @@ def __init__(
             raise ValueError("Either model_id or model_spec should be provided")
 
         if model_path is None:
-            model_path = pathlib.Path(__file__).parent.parent.resolve() / "checkpoints"
-            if not os.path.exists(model_path):
-                # This supports development installs where checkpoints is root level of the repo
-                model_path = (
-                    pathlib.Path(__file__).parent.parent.parent.resolve()
-                    / "checkpoints"
-                )
+            model_path = self._resolve_default_path("checkpoints")
         if config_path is None:
-            config_path = (
-                pathlib.Path(__file__).parent.parent.resolve() / "configs/inference"
-            )
-            if not os.path.exists(config_path):
-                # This supports development installs where configs is root level of the repo
-                config_path = (
-                    pathlib.Path(__file__).parent.parent.parent.resolve()
-                    / "configs/inference"
-                )
+            config_path = self._resolve_default_path("configs/inference")
         self.config = str(pathlib.Path(config_path) / self.specs.config)
         self.ckpt = str(pathlib.Path(model_path) / self.specs.ckpt)
         if not os.path.exists(self.config):
@@ -207,7 +206,22 @@ def __init__(
                 f"Checkpoint {self.ckpt} not found, check model spec or config_path"
             )
         self.device = device
-        self.model = self._load_model(device=device, use_fp16=use_fp16)
+        self.swap_device = swap_device
+        load_device = device if swap_device is None else swap_device
+        self.model = self._load_model(device=load_device, use_fp16=use_fp16)
+
+    def _resolve_default_path(self, suffix: str) -> pathlib.Path:
+        # Resolve a default path: try the module root first, then the repo root
+        repo_path = pathlib.Path(__file__).parent.parent.parent.resolve() / suffix
+        module_path = pathlib.Path(__file__).parent.parent.resolve() / suffix
+        path = module_path
+        if not os.path.exists(path):
+            path = repo_path
+            if not os.path.exists(path):
+                raise ValueError(
+                    f"Default locations for {suffix} not found, please specify path"
+                )
+        return pathlib.Path(path)
 
     def _load_model(self, device="cuda", use_fp16=True):
         config = OmegaConf.load(self.config)
@@ -256,6 +270,7 @@ def text_to_image(
             force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
             return_latents=return_latents,
             filter=filter,
+            device=self.device,
         )
 
     def image_to_image(
@@ -293,6 +308,7 @@ def image_to_image(
             force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
             return_latents=return_latents,
             filter=filter,
+            device=self.device,
         )
 
     def wrap_discretization(
@@ -361,6 +377,7 @@ def refiner(
             return_latents=return_latents,
             filter=filter,
             add_noise=add_noise,
+            device=self.device,
         )
 
 