From 97516f1385af34899737c1e8b8246eb063fd5e95 Mon Sep 17 00:00:00 2001 From: Stefan Baumann Date: Tue, 2 Apr 2024 18:29:22 +0200 Subject: [PATCH] Fix #2: Device and Shape Mismatch on new Diffusers Versions --- attribute_control/model/model.py | 9 +++++---- attribute_control/utils.py | 8 ++++++++ configs/learn_delta.yaml | 4 ++-- learn_delta.py | 2 +- requirements.txt | 2 +- 5 files changed, 17 insertions(+), 8 deletions(-) diff --git a/attribute_control/model/model.py b/attribute_control/model/model.py index 4c3b3fe..737f679 100644 --- a/attribute_control/model/model.py +++ b/attribute_control/model/model.py @@ -13,7 +13,7 @@ from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps -from ..utils import reduce_tensors_recursively +from ..utils import reduce_tensors_recursively, broadcast_trailing_dims from .. import PromptEmbedding, EmbeddingDelta @@ -138,7 +138,7 @@ def set_timesteps_custom(num_inference_steps: int = None, timesteps: torch.Tenso return set_timesteps_orig(num_inference_steps=num_inference_steps, device=device) self.pipe.scheduler.set_timesteps = set_timesteps_custom - self.num_inference_steps = num_inference_steps # self.pipe.scheduler.num_inference_steps + self.num_inference_steps = num_inference_steps @abstractmethod def _get_pipe_kwargs(self, embs: List[PromptEmbedding], start_sample: Optional[Float[torch.Tensor, 'n c h w']], **kwargs): @@ -158,7 +158,7 @@ def sample_delayed(self, embs: List[PromptEmbedding], embs_unmodified: List[Prom return self.sample(embs=embs, embs_neg=embs_neg, start_sample=intermediate, **kwargs, start_after_relative=delay_relative) def _get_eps_pred(self, t: Integer[torch.Tensor, 'n'], sample: Float[torch.Tensor, 'n ...'], model_output: Float[torch.Tensor, 'n ...']) -> Float[torch.Tensor, 'n ...']: - alpha_prod_t = self.pipe.scheduler.alphas_cumprod[t] + alpha_prod_t = broadcast_trailing_dims(self.pipe.scheduler.alphas_cumprod[t.to(self.pipe.scheduler.alphas_cumprod.device)].to(model_output.device), model_output) beta_prod_t = 1 - alpha_prod_t if self.pipe.scheduler.config.prediction_type == "epsilon": return (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) @@ -273,7 +273,8 @@ def _compute_time_ids(self, device, weight_dtype) -> torch.Tensor: def predict_eps(self, embs: List[PromptEmbedding], start_sample: Float[torch.Tensor, 'n c h w'], t_relative: Float[torch.Tensor, 'n']) -> Float[torch.Tensor, 'n c h w']: i_t = torch.round(t_relative * (self.num_inference_steps - 1)).to(torch.int64) - t = self.pipe.scheduler.timesteps[i_t.to(self.pipe.scheduler.timesteps.device)] + self.pipe.scheduler.set_timesteps(self.num_inference_steps) + t = self.pipe.scheduler.timesteps[i_t.to(self.pipe.scheduler.timesteps.device)].to(start_sample.device) p_embs = self._get_pipe_kwargs(embs, embs_neg=None, start_sample=None) add_time_ids = self._compute_time_ids(start_sample.device, start_sample.dtype) diff --git a/attribute_control/utils.py b/attribute_control/utils.py index e497e98..7045a66 100644 --- a/attribute_control/utils.py +++ b/attribute_control/utils.py @@ -4,6 +4,7 @@ from typing import Callable, Dict, List, Union, Tuple, Any, Optional import torch +from jaxtyping import Float from omegaconf import OmegaConf @@ -55,3 +56,10 @@ def getattr_recursive(obj: Any, path: str) -> Any: else: obj = getattr(obj, part) return obj + + +def broadcast_trailing_dims(tensor: Float[torch.Tensor, '(c)'], reference: Float[torch.Tensor, '(c) ...']) -> torch.Tensor: + num_trailing = len(reference.shape) - len(tensor.shape) + for _ in range(num_trailing): + tensor = tensor.unsqueeze(-1) + return tensor diff --git a/configs/learn_delta.yaml b/configs/learn_delta.yaml index a7e664f..65f357b 100644 --- a/configs/learn_delta.yaml +++ b/configs/learn_delta.yaml @@ -8,8 +8,8 @@ defaults: run_type: learn_delta max_steps: 1000 -batch_size: 2 -grad_accum_steps: 5 +batch_size: 1 +grad_accum_steps: 10 scale_batch_size: 4 scale_range: [.1, 5] randomize_scale_sign: true diff --git a/learn_delta.py b/learn_delta.py index d144791..71452a3 100644 --- a/learn_delta.py +++ b/learn_delta.py @@ -72,7 +72,7 @@ def main(cfg: DictConfig): data_iter = iter(dataloader) batch = next(data_iter) prompts_embedded = { k: [model.embed_prompt(v) for v in vs] for k, vs in batch.items() if 'prompt' in k } - t_relative = torch.rand((batch_size,)) + t_relative = torch.rand((batch_size,), device=cfg.device) if batch_size != 1: x_0 = model.sample(prompts_embedded['prompt_target'], embs_neg=None, guidance_scale=cfg.base_sample_settings.guidance_scale, output_type='latent') x_t = model.get_x_t(x_0, torch.randn_like(x_0), t_relative) diff --git a/requirements.txt b/requirements.txt index 3df10f3..5e5067d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ accelerate -diffusers>=0.25 +diffusers>=0.27 einops hydra-core jaxtyping