Skip to content

Commit

Permalink
Allow loading custom models and improve path logic
Browse files Browse the repository at this point in the history
  • Loading branch information
palp committed Aug 3, 2023
1 parent 73287ec commit 44943df
Showing 1 changed file with 25 additions and 9 deletions.
34 changes: 25 additions & 9 deletions sgm/inference/api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass, asdict
from enum import Enum
from omegaconf import OmegaConf
import os
import pathlib
from sgm.inference.helpers import (
do_sample,
Expand Down Expand Up @@ -158,18 +159,33 @@ class SamplingSpec:
class SamplingPipeline:
def __init__(
self,
model_id: ModelArchitecture,
model_path="checkpoints",
config_path="configs/inference",
model_id: Optional[ModelArchitecture] = None,
model_spec: Optional[SamplingSpec] = None,
model_path=None,
config_path=None,
device="cuda",
use_fp16=True,
) -> None:
if model_id not in model_specs:
raise ValueError(f"Model {model_id} not supported")
) -> None:
self.model_id = model_id
self.specs = model_specs[self.model_id]
self.config = str(pathlib.Path(config_path, self.specs.config))
self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt))
if model_spec is not None:
self.specs = model_spec
elif model_id is not None:
if model_id not in model_specs:
raise ValueError(f"Model {model_id} not supported")
self.specs = model_specs[model_id]
else:
raise ValueError("Either model_id or model_spec should be provided")

if model_path is None:
model_path = pathlib.Path(__file__).parent.parent.resolve() / "checkpoints"
if config_path is None:
config_path = pathlib.Path(__file__).parent.parent.resolve() / "configs/inference"
self.config = str(config_path / self.specs.config)
self.ckpt = str(model_path / self.specs.ckpt)
if not os.path.exists(self.config):
raise ValueError(f"Config {self.config} not found, check model spec or config_path")
if not os.path.exists(self.ckpt):
raise ValueError(f"Checkpoint {self.ckpt} not found, check model spec or config_path")
self.device = device
self.model = self._load_model(device=device, use_fp16=use_fp16)

Expand Down

0 comments on commit 44943df

Please sign in to comment.