From 82daa98fe865c30a34638acc145d6f4ea8c193db Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 1 Nov 2024 21:43:47 +0900 Subject: [PATCH] remove duplicate resolution for scaled pos embed --- library/sd3_models.py | 3 ++- sd3_train.py | 1 + sd3_train_network.py | 1 + 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/library/sd3_models.py b/library/sd3_models.py index 15a5b1db4..b09a57dbd 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -871,7 +871,8 @@ def enable_scaled_pos_embed(self, use_scaled_pos_embed: bool, latent_sizes: Opti # remove pos_embed to free up memory up to 0.4 GB self.pos_embed = None - # sort latent sizes in ascending order + # remove duplcates and sort latent sizes in ascending order + latent_sizes = list(set(latent_sizes)) latent_sizes = sorted(latent_sizes) patched_sizes = [latent_size // self.patch_size for latent_size in latent_sizes] diff --git a/sd3_train.py b/sd3_train.py index 40f8c7e1f..f64e2da2c 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -366,6 +366,7 @@ def train(args): if args.enable_scaled_pos_embed: resolutions = train_dataset_group.get_resolutions() latent_sizes = [round(math.sqrt(res[0] * res[1])) // 8 for res in resolutions] # 8 is stride for latent + latent_sizes = list(set(latent_sizes)) # remove duplicates logger.info(f"Prepare scaled positional embeddings for resolutions: {resolutions}, sizes: {latent_sizes}") mmdit.enable_scaled_pos_embed(True, latent_sizes) diff --git a/sd3_train_network.py b/sd3_train_network.py index 9eeac05ca..0739e094d 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -73,6 +73,7 @@ def load_target_model(self, args, weight_dtype, accelerator): # set resolutions for positional embeddings if args.enable_scaled_pos_embed: latent_sizes = [round(math.sqrt(res[0] * res[1])) // 8 for res in self.resolutions] # 8 is stride for latent + latent_sizes = list(set(latent_sizes)) # remove duplicates logger.info(f"Prepare scaled positional embeddings for resolutions: {self.resolutions}, sizes: {latent_sizes}") mmdit.enable_scaled_pos_embed(True, latent_sizes)