From 830df4abcc85ffdfe08b8f97f2c8351c86149af3 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 31 Oct 2024 21:39:07 +0900 Subject: [PATCH] Fix crashing if image is too tall or wide. --- library/sd3_models.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/library/sd3_models.py b/library/sd3_models.py index 0eca94e2f..15a5b1db4 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -868,7 +868,7 @@ def enable_scaled_pos_embed(self, use_scaled_pos_embed: bool, latent_sizes: Opti self.use_scaled_pos_embed = use_scaled_pos_embed if self.use_scaled_pos_embed: - # # remove pos_embed to free up memory up to 0.4 GB + # remove pos_embed to free up memory up to 0.4 GB self.pos_embed = None # sort latent sizes in ascending order @@ -977,7 +977,7 @@ def cropped_pos_embed(self, h, w, device=None, random_crop: bool = False): # patched size h = (h + 1) // p w = (w + 1) // p - if self.pos_embed is None: + if self.pos_embed is None: # should not happen return get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=device) assert self.pos_embed_max_size is not None assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size) @@ -1016,13 +1016,20 @@ def cropped_scaled_pos_embed(self, h, w, device=None, dtype=None, random_crop: b if patched_size is None: raise ValueError(f"Area {area} is too large for the given latent sizes {self.resolution_area_to_latent_size}.") - pos_embed_size = int(patched_size * MMDiT.POS_EMBED_MAX_RATIO) + pos_embed = self.resolution_pos_embeds[patched_size] + pos_embed_size = round(math.sqrt(pos_embed.shape[1])) if h > pos_embed_size or w > pos_embed_size: - # fallback to normal pos_embed + # # fallback to normal pos_embed + # return self.cropped_pos_embed(h * p, w * p, device=device, random_crop=random_crop) + # extend pos_embed size logger.warning( f"Using normal pos_embed for size {h}x{w} as it exceeds the scaled pos_embed size {pos_embed_size}. Image is too tall or wide." ) - return self.cropped_pos_embed(h, w, device=device, random_crop=random_crop) + pos_embed_size = max(h, w) + pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, pos_embed_size, sample_size=patched_size) + pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0) + self.resolution_pos_embeds[patched_size] = pos_embed + logger.info(f"Updated pos_embed for size {pos_embed_size}x{pos_embed_size}") if not random_crop: top = (pos_embed_size - h) // 2 @@ -1031,7 +1038,6 @@ def cropped_scaled_pos_embed(self, h, w, device=None, dtype=None, random_crop: b top = torch.randint(0, pos_embed_size - h + 1, (1,)).item() left = torch.randint(0, pos_embed_size - w + 1, (1,)).item() - pos_embed = self.resolution_pos_embeds[patched_size] if pos_embed.device != device: pos_embed = pos_embed.to(device) # which is better to update device, or transfer every time to device? -> 64x64 emb is 96*96*1536*4=56MB. It's okay to update device.