Fix #2: Device and Shape Mismatch on new Diffusers Versions
stefan-baumann committed Apr 2, 2024
1 parent d21e918 commit 97516f1
Showing 5 changed files with 17 additions and 8 deletions.
9 changes: 5 additions & 4 deletions attribute_control/model/model.py
@@ -13,7 +13,7 @@
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
 
-from ..utils import reduce_tensors_recursively
+from ..utils import reduce_tensors_recursively, broadcast_trailing_dims
 from .. import PromptEmbedding, EmbeddingDelta

@@ -138,7 +138,7 @@ def set_timesteps_custom(num_inference_steps: int = None, timesteps: torch.Tenso
             return set_timesteps_orig(num_inference_steps=num_inference_steps, device=device)
         self.pipe.scheduler.set_timesteps = set_timesteps_custom
 
-        self.num_inference_steps = num_inference_steps # self.pipe.scheduler.num_inference_steps
+        self.num_inference_steps = num_inference_steps
 
     @abstractmethod
     def _get_pipe_kwargs(self, embs: List[PromptEmbedding], start_sample: Optional[Float[torch.Tensor, 'n c h w']], **kwargs):
@@ -158,7 +158,7 @@ def sample_delayed(self, embs: List[PromptEmbedding], embs_unmodified: List[Prom
         return self.sample(embs=embs, embs_neg=embs_neg, start_sample=intermediate, **kwargs, start_after_relative=delay_relative)
 
     def _get_eps_pred(self, t: Integer[torch.Tensor, 'n'], sample: Float[torch.Tensor, 'n ...'], model_output: Float[torch.Tensor, 'n ...']) -> Float[torch.Tensor, 'n ...']:
-        alpha_prod_t = self.pipe.scheduler.alphas_cumprod[t]
+        alpha_prod_t = broadcast_trailing_dims(self.pipe.scheduler.alphas_cumprod[t.to(self.pipe.scheduler.alphas_cumprod.device)].to(model_output.device), model_output)
         beta_prod_t = 1 - alpha_prod_t
         if self.pipe.scheduler.config.prediction_type == "epsilon":
             return (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
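This hunk is the core of the fix: on newer diffusers versions the scheduler's alphas_cumprod table can live on a different device than the model output, and indexing it with t returns a 1-D (n,) tensor that cannot broadcast against the (n, c, h, w) arithmetic below. A minimal sketch of the shape half of the problem (tensor contents are illustrative, not from the repository):

import torch

alphas_cumprod = torch.linspace(0.9999, 0.001, 1000)  # scheduler coefficient table
t = torch.tensor([10, 500])                           # one timestep per sample
model_output = torch.randn(2, 4, 64, 64)              # (n, c, h, w)

alpha_prod_t = alphas_cumprod[t]                      # shape (2,)
# (1 - alpha_prod_t) ** 0.5 * model_output            # RuntimeError: (2,) vs (2, 4, 64, 64)
alpha_prod_t = alpha_prod_t.view(-1, 1, 1, 1)         # pad trailing dims, as broadcast_trailing_dims does
out = (1 - alpha_prod_t) ** 0.5 * model_output        # broadcasts per sample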
@@ -273,7 +273,8 @@ def _compute_time_ids(self, device, weight_dtype) -> torch.Tensor:
 
     def predict_eps(self, embs: List[PromptEmbedding], start_sample: Float[torch.Tensor, 'n c h w'], t_relative: Float[torch.Tensor, 'n']) -> Float[torch.Tensor, 'n c h w']:
         i_t = torch.round(t_relative * (self.num_inference_steps - 1)).to(torch.int64)
-        t = self.pipe.scheduler.timesteps[i_t.to(self.pipe.scheduler.timesteps.device)]
+        self.pipe.scheduler.set_timesteps(self.num_inference_steps)
+        t = self.pipe.scheduler.timesteps[i_t.to(self.pipe.scheduler.timesteps.device)].to(start_sample.device)
 
         p_embs = self._get_pipe_kwargs(embs, embs_neg=None, start_sample=None)
         add_time_ids = self._compute_time_ids(start_sample.device, start_sample.dtype)
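predict_eps additionally re-populates the scheduler's timestep table before indexing into it, and moves the result onto the sample's device. A plausible failure mode this guards against, since on recent diffusers versions scheduler.timesteps reflects whatever step count was set last (step counts here are illustrative):

import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler()
scheduler.set_timesteps(20)   # e.g. a previous sampling call used 20 steps
i_t = torch.tensor([49])      # index computed assuming 50 inference steps
# scheduler.timesteps[i_t]    # IndexError: the table only holds 20 entries
scheduler.set_timesteps(50)   # re-sync the table to the expected length
t = scheduler.timesteps[i_t]  # valid again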
8 changes: 8 additions & 0 deletions attribute_control/utils.py
@@ -4,6 +4,7 @@
 
 from typing import Callable, Dict, List, Union, Tuple, Any, Optional
 import torch
+from jaxtyping import Float
 from omegaconf import OmegaConf
 
@@ -55,3 +56,10 @@ def getattr_recursive(obj: Any, path: str) -> Any:
         else:
             obj = getattr(obj, part)
     return obj
+
+
+def broadcast_trailing_dims(tensor: Float[torch.Tensor, '(c)'], reference: Float[torch.Tensor, '(c) ...']) -> torch.Tensor:
+    num_trailing = len(reference.shape) - len(tensor.shape)
+    for _ in range(num_trailing):
+        tensor = tensor.unsqueeze(-1)
+    return tensor
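A quick usage sketch of the new helper (shapes are illustrative): it unsqueezes trailing singleton dimensions onto the smaller tensor until it can broadcast against the reference.

import torch
from attribute_control.utils import broadcast_trailing_dims

coeff = torch.rand(3)                      # e.g. alphas_cumprod[t], shape (3,)
ref = torch.randn(3, 4, 8, 8)              # e.g. a model output
out = broadcast_trailing_dims(coeff, ref)  # shape (3, 1, 1, 1)
assert (out * ref).shape == ref.shape      # elementwise broadcast now works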
4 changes: 2 additions & 2 deletions configs/learn_delta.yaml
@@ -8,8 +8,8 @@ defaults:
 run_type: learn_delta
 
 max_steps: 1000
-batch_size: 2
-grad_accum_steps: 5
+batch_size: 1
+grad_accum_steps: 10
 scale_batch_size: 4
 scale_range: [.1, 5]
 randomize_scale_sign: true
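Note that the effective batch size is unchanged: batch_size × grad_accum_steps was 2 × 5 = 10 and is now 1 × 10 = 10, so the split presumably trades more accumulation steps for a smaller per-step memory footprint.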
2 changes: 1 addition & 1 deletion learn_delta.py
@@ -72,7 +72,7 @@ def main(cfg: DictConfig):
     data_iter = iter(dataloader)
     batch = next(data_iter)
     prompts_embedded = { k: [model.embed_prompt(v) for v in vs] for k, vs in batch.items() if 'prompt' in k }
-    t_relative = torch.rand((batch_size,))
+    t_relative = torch.rand((batch_size,), device=cfg.device)
     if batch_size != 1:
         x_0 = model.sample(prompts_embedded['prompt_target'], embs_neg=None, guidance_scale=cfg.base_sample_settings.guidance_scale, output_type='latent')
         x_t = model.get_x_t(x_0, torch.randn_like(x_0), t_relative)
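The training-script counterpart of the device fix: torch.rand allocates on the CPU by default, so mixing t_relative with CUDA latents further down would raise a cross-device error. A minimal reproduction (assuming a CUDA device is available; device stands in for the script's cfg.device):

import torch

device = 'cuda'
x_0 = torch.randn(2, 4, 64, 64, device=device)
t_relative = torch.rand((2,))                  # lives on the CPU by default
# x_0 * t_relative.view(-1, 1, 1, 1)           # RuntimeError: cuda:0 vs cpu
t_relative = torch.rand((2,), device=device)   # allocate where it will be used
x_t = x_0 * t_relative.view(-1, 1, 1, 1)       # fine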
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,5 +1,5 @@
 accelerate
-diffusers>=0.25
+diffusers>=0.27
 einops
 hydra-core
 jaxtyping
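The diffusers floor moves from 0.25 to 0.27, presumably the release series that introduced the scheduler timestep and device behavior the rest of this commit adapts to.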
