Skip to content

Commit

Permalink
remove duplicate resolution for scaled pos embed
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Nov 1, 2024
1 parent 9aa6f52 commit 82daa98
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 1 deletion.
3 changes: 2 additions & 1 deletion library/sd3_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,7 +871,8 @@ def enable_scaled_pos_embed(self, use_scaled_pos_embed: bool, latent_sizes: Opti
# remove pos_embed to free up memory up to 0.4 GB
self.pos_embed = None

# sort latent sizes in ascending order
# remove duplicates and sort latent sizes in ascending order
latent_sizes = list(set(latent_sizes))
latent_sizes = sorted(latent_sizes)

patched_sizes = [latent_size // self.patch_size for latent_size in latent_sizes]
Expand Down
1 change: 1 addition & 0 deletions sd3_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,7 @@ def train(args):
if args.enable_scaled_pos_embed:
resolutions = train_dataset_group.get_resolutions()
latent_sizes = [round(math.sqrt(res[0] * res[1])) // 8 for res in resolutions] # 8 is stride for latent
latent_sizes = list(set(latent_sizes)) # remove duplicates
logger.info(f"Prepare scaled positional embeddings for resolutions: {resolutions}, sizes: {latent_sizes}")
mmdit.enable_scaled_pos_embed(True, latent_sizes)

Expand Down
1 change: 1 addition & 0 deletions sd3_train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def load_target_model(self, args, weight_dtype, accelerator):
# set resolutions for positional embeddings
if args.enable_scaled_pos_embed:
latent_sizes = [round(math.sqrt(res[0] * res[1])) // 8 for res in self.resolutions] # 8 is stride for latent
latent_sizes = list(set(latent_sizes)) # remove duplicates
logger.info(f"Prepare scaled positional embeddings for resolutions: {self.resolutions}, sizes: {latent_sizes}")
mmdit.enable_scaled_pos_embed(True, latent_sizes)

Expand Down

0 comments on commit 82daa98

Please sign in to comment.