Skip to content

Commit

Permalink
Fix crashing if image is too tall or wide.
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Oct 31, 2024
1 parent 9e23368 commit 830df4a
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions library/sd3_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,7 +868,7 @@ def enable_scaled_pos_embed(self, use_scaled_pos_embed: bool, latent_sizes: Opti
self.use_scaled_pos_embed = use_scaled_pos_embed

if self.use_scaled_pos_embed:
# # remove pos_embed to free up memory up to 0.4 GB
# remove pos_embed to free up memory up to 0.4 GB
self.pos_embed = None

# sort latent sizes in ascending order
Expand Down Expand Up @@ -977,7 +977,7 @@ def cropped_pos_embed(self, h, w, device=None, random_crop: bool = False):
# patched size
h = (h + 1) // p
w = (w + 1) // p
if self.pos_embed is None:
if self.pos_embed is None: # should not happen
return get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=device)
assert self.pos_embed_max_size is not None
assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size)
Expand Down Expand Up @@ -1016,13 +1016,20 @@ def cropped_scaled_pos_embed(self, h, w, device=None, dtype=None, random_crop: b
if patched_size is None:
raise ValueError(f"Area {area} is too large for the given latent sizes {self.resolution_area_to_latent_size}.")

pos_embed_size = int(patched_size * MMDiT.POS_EMBED_MAX_RATIO)
pos_embed = self.resolution_pos_embeds[patched_size]
pos_embed_size = round(math.sqrt(pos_embed.shape[1]))
if h > pos_embed_size or w > pos_embed_size:
# fallback to normal pos_embed
# # fallback to normal pos_embed
# return self.cropped_pos_embed(h * p, w * p, device=device, random_crop=random_crop)
# extend pos_embed size
logger.warning(
f"Using normal pos_embed for size {h}x{w} as it exceeds the scaled pos_embed size {pos_embed_size}. Image is too tall or wide."
)
return self.cropped_pos_embed(h, w, device=device, random_crop=random_crop)
pos_embed_size = max(h, w)
pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, pos_embed_size, sample_size=patched_size)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0)
self.resolution_pos_embeds[patched_size] = pos_embed
logger.info(f"Updated pos_embed for size {pos_embed_size}x{pos_embed_size}")

if not random_crop:
top = (pos_embed_size - h) // 2
Expand All @@ -1031,7 +1038,6 @@ def cropped_scaled_pos_embed(self, h, w, device=None, dtype=None, random_crop: b
top = torch.randint(0, pos_embed_size - h + 1, (1,)).item()
left = torch.randint(0, pos_embed_size - w + 1, (1,)).item()

pos_embed = self.resolution_pos_embeds[patched_size]
if pos_embed.device != device:
pos_embed = pos_embed.to(device)
# which is better to update device, or transfer every time to device? -> 64x64 emb is 96*96*1536*4=56MB. It's okay to update device.
Expand Down

0 comments on commit 830df4a

Please sign in to comment.