From f89468bb5e20a6c4701fd91928a1991d9354344e Mon Sep 17 00:00:00 2001 From: bghira Date: Thu, 25 Apr 2024 16:26:06 -0600 Subject: [PATCH 01/37] rounding never ends, fix --aspect_bucket_rounding not being applied correctly, and a unit test failure that indicates it now works --- helpers/multiaspect/image.py | 5 ++++- tests/test_image.py | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/helpers/multiaspect/image.py b/helpers/multiaspect/image.py index beaa196f..a4a563f2 100644 --- a/helpers/multiaspect/image.py +++ b/helpers/multiaspect/image.py @@ -471,13 +471,16 @@ def calculate_image_aspect_ratio(image, rounding: int = 2): float: The rounded aspect ratio of the image. """ to_round = StateTracker.get_args().aspect_bucket_rounding - if not isinstance(image, int): + if to_round is None: to_round = rounding if isinstance(image, Image.Image): + # An actual image was passed in. width, height = image.size elif isinstance(image, tuple): + # An image.size or a similar (W, H) tuple was provided. width, height = image elif isinstance(image, float): + # An externally-calculated aspect ratio was given to round. return round(image, to_round) else: width, height = image.size diff --git a/tests/test_image.py b/tests/test_image.py index 2cec4a1e..7bf1aa23 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -33,6 +33,7 @@ def test_aspect_ratio_calculation(self): """ Test that the aspect ratio calculation returns expected results. """ + StateTracker.set_args(MagicMock(aspect_bucket_rounding=2)) self.assertEqual( MultiaspectImage.calculate_image_aspect_ratio((1920, 1080)), 1.78 ) From d8f777a957e203038a85d0f5a34648c0bb9ae181 Mon Sep 17 00:00:00 2001 From: bghira Date: Thu, 25 Apr 2024 17:06:02 -0600 Subject: [PATCH 02/37] parquet: use global bucket precision --- helpers/metadata/backends/parquet.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/helpers/metadata/backends/parquet.py b/helpers/metadata/backends/parquet.py index 07c6275a..cbd7a057 100644 --- a/helpers/metadata/backends/parquet.py +++ b/helpers/metadata/backends/parquet.py @@ -355,7 +355,9 @@ def _process_for_bucket( image_metadata["original_size"][0] / image_metadata["original_size"][1] ) - aspect_ratio = round(aspect_ratio, aspect_ratio_rounding) + aspect_ratio = MultiaspectImage.calculate_image_aspect_ratio( + image_metadata["original_size"] + ) target_size, crop_coordinates, new_aspect_ratio = ( MultiaspectImage.prepare_image( image_metadata=image_metadata, From fa813a2e5e75e4bb0e515c33dbec5bcabe8bcb8a Mon Sep 17 00:00:00 2001 From: bghira Date: Thu, 25 Apr 2024 17:07:40 -0600 Subject: [PATCH 03/37] (#365) exit only after epoch has completed. reduce ambiguity on epochs vs steps by requiring one or the other, not both --- TUTORIAL.md | 7 +++---- helpers/arguments.py | 6 ++++++ train_sd21.py | 13 ++++++------- train_sd2x.sh | 7 ++++++- train_sdxl.py | 6 +++--- train_sdxl.sh | 7 ++++++- 6 files changed, 30 insertions(+), 16 deletions(-) diff --git a/TUTORIAL.md b/TUTORIAL.md index 63fc3e61..1795c6e9 100644 --- a/TUTORIAL.md +++ b/TUTORIAL.md @@ -225,9 +225,8 @@ Here's a breakdown of what each environment variable does: #### Data Locations -- `BASE_DIR`, `INSTANCE_DIR`, `OUTPUT_DIR`: Directories for the training data, instance data, and output models. +- `BASE_DIR`, `OUTPUT_DIR`: Directories for the training data, instance data, and output models. - `BASE_DIR` - Used for populating other variables, mostly. - - `INSTANCE_DIR` - Where your actual training data is. 
This can be anywhere, it does not need to be underneath `BASE_DIR`.
   - `OUTPUT_DIR` - Where the model pipeline results are stored during training, and after it completes.
 
 #### Training Parameters
 
 - `MAX_NUM_STEPS`, `NUM_EPOCHS`: Max number of steps or epochs for training.
   - If you use `MAX_NUM_STEPS`, it's recommended to set `NUM_EPOCHS` to `0`.
   - Similarly, if you use `NUM_EPOCHS`, it is recommended to set `MAX_NUM_STEPS` to `0`.
   - This simply signals to the trainer that you explicitly wish to use one or the other.
-  - If you supply `NUM_EPOCHS` and `MAX_NUM_STEPS` together, the training will stop running at whichever happens first.
+  - Don't supply `NUM_EPOCHS` and `MAX_NUM_STEPS` together; if both are set, training will refuse to start, so that there is no ambiguity about which limit takes priority.
 - `LR_SCHEDULE`, `LR_WARMUP_STEPS`: Learning rate schedule and warmup steps.
-  - `LR_SCHEDULE` - stick to `constant`, as it is most likely to be stable and less chaotic. However, `polynomial` and `constant_with_warmup` have potential of moving the model's local minima before settling in and reducing the loss. Experimentation can pay off here.
+  - `LR_SCHEDULE` - stick to `constant`, as it is most likely to be stable and less chaotic. However, `polynomial` and `constant_with_warmup` have the potential to move the model's local minima before settling in and reducing the loss. Experimentation can pay off here, especially with the `cosine` and `sine` schedulers, which take a different approach to learning rate scheduling.
 - `TRAIN_BATCH_SIZE`: Batch size for training. You want this **as high as you can get it** without running out of VRAM or making your training unnecessarily **slow** (eg. 300-400% increase in training runtime - yikes! 💸)
 
 ## Additional Notes
diff --git a/helpers/arguments.py b/helpers/arguments.py
index c3669c89..14d7b7d5 100644
--- a/helpers/arguments.py
+++ b/helpers/arguments.py
@@ -1346,6 +1346,12 @@ def parse_args(input_args=None):
         )
         sys.exit(1)
 
+    if args.max_train_steps is not None and args.num_train_epochs > 0:
+        logger.error(
+            "When using --max_train_steps (MAX_NUM_STEPS), you must set --num_train_epochs (NUM_EPOCHS) to 0."
+        )
+        sys.exit(1)
+
     if (
         args.pretrained_vae_model_name_or_path is not None
         and StateTracker.get_model_type() == "legacy"
diff --git a/train_sd21.py b/train_sd21.py
index 20f29226..49a0342a 100644
--- a/train_sd21.py
+++ b/train_sd21.py
@@ -888,7 +888,7 @@ def main():
             f"Resuming from epoch {first_epoch}, which leaves us with {total_steps_remaining_at_start}."
         )
     current_epoch = first_epoch
-    if current_epoch >= args.num_train_epochs:
+    if current_epoch >= args.num_train_epochs + 1:
         logger.info(
             f"Reached the end ({current_epoch} epochs) of our training run ({args.num_train_epochs} epochs). This run will do zero steps."
         )
@@ -960,10 +960,10 @@ def main():
     current_epoch_step = None
     for epoch in range(first_epoch, args.num_train_epochs):
-        if current_epoch >= args.num_train_epochs:
+        if current_epoch >= args.num_train_epochs + 1:
             # This might immediately end training, but that's useful for simply exporting the model.
             logger.info(
-                f"Reached the end ({current_epoch} epochs) of our training run ({args.num_train_epochs} epochs)."
+                f"Training run is complete ({args.num_train_epochs}/{args.num_epochs} epochs, {global_step}/{args.max_train_steps} steps)."
) break if first_epoch != epoch: @@ -1391,13 +1391,12 @@ def main(): SCHEDULER_NAME_MAP=SCHEDULER_NAME_MAP, ) - if global_step >= args.max_train_steps or epoch > args.num_train_epochs: + if global_step >= args.max_train_steps or epoch > args.num_train_epochs + 1: logger.info( - f"Training has completed.", - f"\n -> global_step = {global_step}, max_train_steps = {args.max_train_steps}, epoch = {epoch}, num_train_epochs = {args.num_train_epochs}", + f"Training run is complete ({args.num_train_epochs}/{args.num_epochs} epochs, {global_step}/{args.max_train_steps} steps)." ) break - if global_step >= args.max_train_steps or epoch > args.num_train_epochs: + if global_step >= args.max_train_steps or epoch > args.num_train_epochs + 1: logger.info( f"Exiting training loop. Beginning model unwind at epoch {epoch}, step {global_step}" ) diff --git a/train_sd2x.sh b/train_sd2x.sh index dab53b40..b2310403 100644 --- a/train_sd2x.sh +++ b/train_sd2x.sh @@ -299,10 +299,15 @@ if [ -n "$ASPECT_BUCKET_ROUNDING" ]; then export ASPECT_BUCKET_ROUNDING_ARGS="--aspect_bucket_rounding=${ASPECT_BUCKET_ROUNDING}" fi +export MAX_TRAIN_STEPS_ARGS="" +if [ -n "$MAX_NUM_STEPS" ] && [[ "$MAX_NUM_STEPS" != 0 ]]; then + export MAX_TRAIN_STEPS_ARGS="--max_train_steps=${MAX_NUM_STEPS}" +fi + accelerate launch ${ACCELERATE_EXTRA_ARGS} --mixed_precision="${MIXED_PRECISION}" --num_processes="${TRAINING_NUM_PROCESSES}" --num_machines="${TRAINING_NUM_MACHINES}" --dynamo_backend="${TRAINING_DYNAMO_BACKEND}" train_sd21.py \ --model_type="${MODEL_TYPE}" ${DORA_ARGS} --pretrained_model_name_or_path="${MODEL_NAME}" ${XFORMERS_ARG} ${GRADIENT_ARG} --set_grads_to_none --gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS} \ --resume_from_checkpoint="${RESUME_CHECKPOINT}" ${DELETE_ARGS} ${SNR_GAMMA_ARG} --data_backend_config="${DATALOADER_CONFIG}" ${OVERRIDE_DATALOADER_CONFIG_ARG} \ - --num_train_epochs=${NUM_EPOCHS} --max_train_steps=${MAX_NUM_STEPS} --metadata_update_interval=${METADATA_UPDATE_INTERVAL} \ + --num_train_epochs=${NUM_EPOCHS} ${MAX_TRAIN_STEPS_ARGS} --metadata_update_interval=${METADATA_UPDATE_INTERVAL} \ --learning_rate="${LEARNING_RATE}" --lr_scheduler="${LR_SCHEDULE}" --seed "${TRAINING_SEED}" --lr_warmup_steps="${LR_WARMUP_STEPS}" --lr_end="${LEARNING_RATE_END}" \ --output_dir="${OUTPUT_DIR}" \ --inference_scheduler_timestep_spacing="${INFERENCE_SCHEDULER_TIMESTEP_SPACING}" --training_scheduler_timestep_spacing="${TRAINING_SCHEDULER_TIMESTEP_SPACING}" \ diff --git a/train_sdxl.py b/train_sdxl.py index b348e1a1..88631143 100644 --- a/train_sdxl.py +++ b/train_sdxl.py @@ -984,7 +984,7 @@ def main(): current_epoch = first_epoch StateTracker.set_epoch(current_epoch) - if current_epoch > args.num_train_epochs: + if current_epoch > args.num_train_epochs + 1: logger.info( f"Reached the end ({current_epoch} epochs) of our training run ({args.num_train_epochs} epochs). This run will do zero steps." ) @@ -1054,10 +1054,10 @@ def main(): current_epoch_step = None for epoch in range(first_epoch, args.num_train_epochs + 1): - if current_epoch > args.num_train_epochs: + if current_epoch > args.num_train_epochs + 1: # This might immediately end training, but that's useful for simply exporting the model. logger.info( - f"Reached the end ({current_epoch} epochs) of our training run ({args.num_train_epochs} epochs)." + f"Training run is complete ({args.num_train_epochs}/{args.num_epochs} epochs, {global_step}/{args.max_train_steps} steps)." 
) break if first_epoch != epoch: diff --git a/train_sdxl.sh b/train_sdxl.sh index 15c20d45..b7bd28f1 100644 --- a/train_sdxl.sh +++ b/train_sdxl.sh @@ -280,11 +280,16 @@ if [ -n "$ASPECT_BUCKET_ROUNDING" ]; then export ASPECT_BUCKET_ROUNDING_ARGS="--aspect_bucket_rounding=${ASPECT_BUCKET_ROUNDING}" fi +export MAX_TRAIN_STEPS_ARGS="" +if [ -n "$MAX_NUM_STEPS" ] && [[ "$MAX_NUM_STEPS" != 0 ]]; then + export MAX_TRAIN_STEPS_ARGS="--max_train_steps=${MAX_NUM_STEPS}" +fi + # Run the training script. accelerate launch ${ACCELERATE_EXTRA_ARGS} --mixed_precision="${MIXED_PRECISION}" --num_processes="${TRAINING_NUM_PROCESSES}" --num_machines="${TRAINING_NUM_MACHINES}" --dynamo_backend="${TRAINING_DYNAMO_BACKEND}" train_sdxl.py \ --model_type="${MODEL_TYPE}" ${DORA_ARGS} --pretrained_model_name_or_path="${MODEL_NAME}" ${XFORMERS_ARG} ${GRADIENT_ARG} --set_grads_to_none --gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS} \ --resume_from_checkpoint="${RESUME_CHECKPOINT}" ${DELETE_ARGS} ${SNR_GAMMA_ARG} --data_backend_config="${DATALOADER_CONFIG}" \ - --num_train_epochs=${NUM_EPOCHS} --max_train_steps=${MAX_NUM_STEPS} --metadata_update_interval=${METADATA_UPDATE_INTERVAL} \ + --num_train_epochs=${NUM_EPOCHS} ${MAX_TRAIN_STEPS_ARGS} --metadata_update_interval=${METADATA_UPDATE_INTERVAL} \ ${OPTIMIZER_ARG} --learning_rate="${LEARNING_RATE}" --lr_scheduler="${LR_SCHEDULE}" --seed "${TRAINING_SEED}" --lr_warmup_steps="${LR_WARMUP_STEPS}" \ --output_dir="${OUTPUT_DIR}" ${BITFIT_ARGS} ${ASPECT_BUCKET_ROUNDING_ARGS} \ --inference_scheduler_timestep_spacing="${INFERENCE_SCHEDULER_TIMESTEP_SPACING}" --training_scheduler_timestep_spacing="${TRAINING_SCHEDULER_TIMESTEP_SPACING}" \ From 9ff3f04f4afee396e483d82872a36fb8ee714e55 Mon Sep 17 00:00:00 2001 From: bghira Date: Thu, 25 Apr 2024 17:09:09 -0600 Subject: [PATCH 04/37] (#365) update example configurations to have only max train steps set --- sd21-env.sh.example | 10 ++-------- sdxl-env.sh.example | 4 ++-- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/sd21-env.sh.example b/sd21-env.sh.example index dab4cd20..03f64f26 100644 --- a/sd21-env.sh.example +++ b/sd21-env.sh.example @@ -67,19 +67,13 @@ export TRACKER_RUN_NAME TRACKER_RUN_NAME="$(date +%s)" # Location of training data. export BASE_DIR="/notebooks/datasets" -export INSTANCE_DIR="${BASE_DIR}/training_data" export OUTPUT_DIR="${BASE_DIR}/models" export DATALOADER_CONFIG="multidatabackend_sd2x.json" -# Some data that we generate will be cached here. -export STATE_PATH="${BASE_DIR}/training_state.json" -# Store whether we've seen an image or not, to prevent repeats. -export SEEN_STATE_PATH="${BASE_DIR}/training_images_seen.json" - -# Max number of steps OR epochs can be used. But we default to Epochs. +# Max number of steps OR epochs can be used. Not both. export MAX_NUM_STEPS=30000 # Will likely overtrain, but that's fine. -export NUM_EPOCHS=25 +export NUM_EPOCHS=0 # Adjust this for your GPU memory size. export TRAIN_BATCH_SIZE=1 diff --git a/sdxl-env.sh.example b/sdxl-env.sh.example index 1649512c..2c719ae7 100644 --- a/sdxl-env.sh.example +++ b/sdxl-env.sh.example @@ -45,10 +45,10 @@ export DEBUG_EXTRA_ARGS="--report_to=wandb" export TRACKER_PROJECT_NAME="sdxl-training" export TRACKER_RUN_NAME="simpletuner-sdxl" -# Max number of steps OR epochs can be used. But we default to Epochs. +# Max number of steps OR epochs can be used. Not both. export MAX_NUM_STEPS=30000 # Will likely overtrain, but that's fine. 
-export NUM_EPOCHS=25 +export NUM_EPOCHS=0 # A convenient prefix for all of your training paths. export BASE_DIR="/notebooks/datasets" From ccbc2780bd19921f05d5466650994c97bed45209 Mon Sep 17 00:00:00 2001 From: bghira Date: Thu, 25 Apr 2024 19:14:29 -0600 Subject: [PATCH 05/37] look up metadata value during processing instead of calculating it again --- helpers/caching/vae.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/helpers/caching/vae.py b/helpers/caching/vae.py index e29c9c73..1759fab0 100644 --- a/helpers/caching/vae.py +++ b/helpers/caching/vae.py @@ -595,7 +595,11 @@ def _process_images_in_batch( # retrieve image data from Generator, image_data: filepath = image_paths.pop() image = image_data.pop() - aspect_bucket = MultiaspectImage.calculate_image_aspect_ratio(image) + aspect_bucket = ( + self.metadata_backend.get_metadata_attribute_by_filepath( + filepath=filepath, attribute="aspect_bucket" + ) + ) else: filepath, image, aspect_bucket = self.process_queue.get() if self.minimum_image_size is not None: From 40b004b34983fba822ec0a4fb166e816150b61d5 Mon Sep 17 00:00:00 2001 From: bghira Date: Thu, 25 Apr 2024 20:11:57 -0600 Subject: [PATCH 06/37] prevent overwrite of the text encoder value when loading the model and training the text encoder fixes #371 --- helpers/legacy/sd_files.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/helpers/legacy/sd_files.py b/helpers/legacy/sd_files.py index 3ada352a..2c1be614 100644 --- a/helpers/legacy/sd_files.py +++ b/helpers/legacy/sd_files.py @@ -134,14 +134,14 @@ def load_model_hook(models, input_dir): ) unet_ = None - text_encoder = None + text_encoder_ = None while len(models) > 0: model = models.pop() if isinstance(model, type(unwrap_model(accelerator, unet))): unet_ = model elif isinstance(model, type(unwrap_model(accelerator, text_encoder))): - text_encoder = model + text_encoder_ = model else: raise ValueError(f"unexpected save model: {model.__class__}") @@ -173,7 +173,7 @@ def load_model_hook(models, input_dir): _set_state_dict_into_text_encoder( lora_state_dict, prefix="text_encoder.", - text_encoder=text_encoder, + text_encoder=text_encoder_, ) logger.info("Completed loading LoRA weights.") From 0a4e0e2b3bb6641067addadc4273ca15e4e41335 Mon Sep 17 00:00:00 2001 From: bghira Date: Thu, 25 Apr 2024 20:18:15 -0600 Subject: [PATCH 07/37] move loop into lora block --- helpers/legacy/sd_files.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/helpers/legacy/sd_files.py b/helpers/legacy/sd_files.py index 2c1be614..8b7c14ff 100644 --- a/helpers/legacy/sd_files.py +++ b/helpers/legacy/sd_files.py @@ -133,19 +133,18 @@ def load_model_hook(models, input_dir): f"Could not find training_state.json in checkpoint dir {input_dir}" ) - unet_ = None - text_encoder_ = None - while len(models) > 0: - model = models.pop() - - if isinstance(model, type(unwrap_model(accelerator, unet))): - unet_ = model - elif isinstance(model, type(unwrap_model(accelerator, text_encoder))): - text_encoder_ = model - else: - raise ValueError(f"unexpected save model: {model.__class__}") - if "lora" in args.model_type: + unet_ = None + text_encoder_ = None + while len(models) > 0: + model = models.pop() + + if isinstance(model, type(unwrap_model(accelerator, unet))): + unet_ = model + elif isinstance(model, type(unwrap_model(accelerator, text_encoder))): + text_encoder_ = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") logger.info(f"Loading LoRA weights 
from Path: {input_dir}") lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir) From 1272d71f7cc96a64a1795e8367dfd028c3ed9154 Mon Sep 17 00:00:00 2001 From: bghira Date: Sun, 28 Apr 2024 09:28:05 -0600 Subject: [PATCH 08/37] write timesteps buffer using global_step as the x coordinate --- train_sd21.py | 6 +++--- train_sdxl.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/train_sd21.py b/train_sd21.py index 49a0342a..63a8a498 100644 --- a/train_sd21.py +++ b/train_sd21.py @@ -963,7 +963,7 @@ def main(): if current_epoch >= args.num_train_epochs + 1: # This might immediately end training, but that's useful for simply exporting the model. logger.info( - f"Training run is complete ({args.num_train_epochs}/{args.num_epochs} epochs, {global_step}/{args.max_train_steps} steps)." + f"Training run is complete ({args.num_train_epochs}/{args.num_train_epochs} epochs, {global_step}/{args.max_train_steps} steps)." ) break if first_epoch != epoch: @@ -1080,7 +1080,7 @@ def main(): # Prepare the data for the scatter plot for timestep in timesteps.tolist(): - timesteps_buffer.append((step, timestep)) + timesteps_buffer.append((global_step, timestep)) # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) @@ -1393,7 +1393,7 @@ def main(): if global_step >= args.max_train_steps or epoch > args.num_train_epochs + 1: logger.info( - f"Training run is complete ({args.num_train_epochs}/{args.num_epochs} epochs, {global_step}/{args.max_train_steps} steps)." + f"Training run is complete ({args.num_train_epochs}/{args.num_train_epochs} epochs, {global_step}/{args.max_train_steps} steps)." ) break if global_step >= args.max_train_steps or epoch > args.num_train_epochs + 1: diff --git a/train_sdxl.py b/train_sdxl.py index 88631143..e4323bec 100644 --- a/train_sdxl.py +++ b/train_sdxl.py @@ -1057,7 +1057,7 @@ def main(): if current_epoch > args.num_train_epochs + 1: # This might immediately end training, but that's useful for simply exporting the model. logger.info( - f"Training run is complete ({args.num_train_epochs}/{args.num_epochs} epochs, {global_step}/{args.max_train_steps} steps)." + f"Training run is complete ({args.num_train_epochs}/{args.num_train_epochs} epochs, {global_step}/{args.max_train_steps} steps)." 
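+                f"Training run is complete ({args.num_train_epochs}/{args.num_train_epochs} epochs, {global_step}/{args.max_train_steps} steps)."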
) break if first_epoch != epoch: @@ -1175,7 +1175,7 @@ def main(): # Prepare the data for the scatter plot for timestep in timesteps.tolist(): - timesteps_buffer.append((step, timestep)) + timesteps_buffer.append((global_step, timestep)) # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) From 38ba5011a8eecc73df1c4fb98c2758360f9f859c Mon Sep 17 00:00:00 2001 From: bghira Date: Sun, 28 Apr 2024 09:31:37 -0600 Subject: [PATCH 09/37] validations: fix df stage I eval (pt2) --- helpers/legacy/validation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/helpers/legacy/validation.py b/helpers/legacy/validation.py index 269c9c9f..9954df8c 100644 --- a/helpers/legacy/validation.py +++ b/helpers/legacy/validation.py @@ -491,7 +491,7 @@ def log_validations( validation_resolutions = ( get_validation_resolutions() - if "deepfloyd" not in args.model_type + if "deepfloyd-stage2" not in args.model_type else ["base-256"] ) logger.debug(f"Resolutions for validation: {validation_resolutions}") @@ -499,7 +499,7 @@ def log_validations( validation_images[validation_shortname] = [] for resolution in validation_resolutions: - if "deepfloyd" not in args.model_type: + if "deepfloyd-stage2" not in args.model_type: validation_resolution_width, validation_resolution_height = ( resolution ) From 15dbf85d2b82e8b724ab6da5428a28efcde457be Mon Sep 17 00:00:00 2001 From: bghira Date: Tue, 30 Apr 2024 18:59:43 -0600 Subject: [PATCH 10/37] metadata scan cannot find a single image dataset, maybe sleeping is helpful --- helpers/metadata/backends/base.py | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/helpers/metadata/backends/base.py b/helpers/metadata/backends/base.py index de000c07..766305c0 100644 --- a/helpers/metadata/backends/base.py +++ b/helpers/metadata/backends/base.py @@ -110,14 +110,22 @@ def _bucket_worker( for file in files: if str(file) not in existing_files_set: logger.debug(f"Processing file {file}.") - local_aspect_ratio_bucket_indices = self._process_for_bucket( - file, - local_aspect_ratio_bucket_indices, - metadata_updates=local_metadata_updates, - delete_problematic_images=self.delete_problematic_images, - statistics=statistics, + try: + local_aspect_ratio_bucket_indices = self._process_for_bucket( + file, + local_aspect_ratio_bucket_indices, + metadata_updates=local_metadata_updates, + delete_problematic_images=self.delete_problematic_images, + statistics=statistics, + ) + except Exception as e: + logger.error( + f"Error processing file {file}. Reason: {e}. Skipping." + ) + statistics["skipped"]["error"] += 1 + logger.debug( + f"Statistics: {statistics}, total: {sum([len(bucket) for bucket in local_aspect_ratio_bucket_indices.values()])}" ) - logger.debug(f"Statistics: {statistics}") processed_file_count += 1 # Successfully processed statistics["total_processed"] = processed_file_count @@ -146,6 +154,7 @@ def _bucket_worker( metadata_updates_queue.put(local_metadata_updates) # At the end of the _bucket_worker method metadata_updates_queue.put(("statistics", statistics)) + time.sleep(0.001) logger.debug(f"Bucket worker completed processing. 
Returning to main thread.") def compute_aspect_ratio_bucket_indices(self): @@ -745,9 +754,7 @@ def handle_vae_cache_inconsistencies(self, vae_cache, vae_cache_behavior: str): # Update any state or metadata post-processing self.save_cache() - def _recalculate_target_resolution( - self, original_aspect_ratio: float - ) -> tuple: + def _recalculate_target_resolution(self, original_aspect_ratio: float) -> tuple: """Given the original resolution, use our backend config to properly recalculate the size.""" resolution_type = StateTracker.get_data_backend_config(self.id)[ "resolution_type" @@ -803,7 +810,9 @@ def is_cache_inconsistent(self, vae_cache, cache_file, cache_content): target_resolution = tuple(metadata_target_size) recalculated_width, recalculated_height, recalculated_aspect_ratio = ( self._recalculate_target_resolution( - original_aspect_ratio=MultiaspectImage.calculate_image_aspect_ratio(original_resolution) + original_aspect_ratio=MultiaspectImage.calculate_image_aspect_ratio( + original_resolution + ) ) ) recalculated_target_resolution = (recalculated_width, recalculated_height) From 1149c70c49caee6f26316355ad8dd86ff4f68bfd Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Tue, 30 Apr 2024 20:29:45 -0700 Subject: [PATCH 11/37] Add install/rocm poetry config Barebones, based on the Mac config. --- install/rocm/pyproject.toml | 51 +++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 install/rocm/pyproject.toml diff --git a/install/rocm/pyproject.toml b/install/rocm/pyproject.toml new file mode 100644 index 00000000..38a1cc24 --- /dev/null +++ b/install/rocm/pyproject.toml @@ -0,0 +1,51 @@ +[tool.poetry] +name = "simpletuner" +version = "0.1.0" +description = "Stable Diffusion 2.x and XL tuner." +authors = ["bghira"] +license = "SUL" +readme = "README.md" + +[tool.poetry.dependencies] +python = ">=3.9,<3.13" +torch = {version = "^2.3", source = "pytorch-rocm"} +torchmetrics = {version = "^1", source = "pytorch-rocm"} +torchvision = {version = "*", source = "pytorch-rocm"} +triton = {version = "*", source = "pytorch-rocm"} +pytorch-triton-rocm = {version = "*", source = "pytorch-rocm"} +accelerate = "^0.26" +boto3 = "^1" +botocore = "^1" +clip-interrogator = "^0.6" +colorama = "^0.4" +compel = "^2" +dadaptation = "^3" +datasets = "^2" +diffusers = "^0.27" +iterutils = "^0.1" +numpy = "^1" +open-clip-torch = "^2" +opencv-python = "^4" +pandas = "^2" +peft = "^0.9" +pillow = "^10" +prodigyopt = "^1" +regex = "^2023.12.25" +requests = "^2" +safetensors = "^0.4" +scipy = "^1" +tensorboard = "^2" +torchsde = "^0.2" +transformers = "^4" +urllib3 = "<1.27" +wandb = "^0.16" + + +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api" + +[[tool.poetry.source]] +priority = "explicit" +name = "pytorch-rocm" +url = "https://download.pytorch.org/whl/rocm6.0" From 68e469c1386ed499a79dcd1953d7461dfe9bb347 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Tue, 30 Apr 2024 20:54:54 -0700 Subject: [PATCH 12/37] Update INSTALL.md for ROCm --- INSTALL.md | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/INSTALL.md b/INSTALL.md index 067fbe30..bc8ef45c 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -22,7 +22,7 @@ To install the Apple-specific requirements: poetry install --no-root -C install/apple ``` -### Linux +### Linux + Nvidia/CUDA The first command you'll run will install most of the dependencies: @@ -59,6 +59,16 @@ If the egg install for Xformers does not work, try including `xformers` on the f pip3 install 
--pre xformers torch torchvision torchaudio torchtriton --extra-index-url https://download.pytorch.org/whl/nightly/cu118 --force ``` +### Linux + AMD / ROCm +Due to `xformers` not supporting the ROCm platform, memory requirements for training will likely be higher than otherwise stated. + +To install the ROCm-specific requirements: + +```bash +poetry install --no-root -C install/rocm +``` + + ### All platforms 2. For SD2.1, copy `sd21-env.sh.example` to `env.sh` - be sure to fill out the details. Try to change as little as possible. @@ -89,4 +99,4 @@ For SDXL, run the `train_sdxl.sh` script, redirecting outputs to the log file: bash train_sdxl.sh > /path/to/training-$(date +%s).log 2>&1 ``` -> ⚠️ At this point, the commands will work, but further configuration is required. See [the tutorial](/TUTORIAL.md) for more information. \ No newline at end of file +> ⚠️ At this point, the commands will work, but further configuration is required. See [the tutorial](/TUTORIAL.md) for more information. From c688c1ffc524f9674a9f64a40974b56afce0d18c Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Tue, 30 Apr 2024 20:56:43 -0700 Subject: [PATCH 13/37] Update README.md for ROCm --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index d4c91d8b..a13a6fa4 100644 --- a/README.md +++ b/README.md @@ -65,7 +65,7 @@ EMA (exponential moving average) weights are a memory-heavy affair, but provide ### GPU vendors * NVIDIA - pretty much anything 3090 and up is a safe bet. YMMV. -* AMD - No one has reported anything, we don't know. +* AMD - LoRA is tested to work on a 7900 XTX. Lacking `xformers`, it will likely use more memory than Nvidia equivalents * Apple - LoRA and full u-net tuning are tested to work on an M3 Max with 128G memory, taking about **12G** of "Wired" memory and **4G** of system memory for SDXL. * You likely need a 24G or greater machine for machine learning with M-series hardware due to the lack of memory-efficient attention. From cd31a99322a341b7336a6b7a4f1a340513c1e220 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Tue, 30 Apr 2024 23:20:08 -0700 Subject: [PATCH 14/37] Add deepspeed to ROCm requirements 0.14 instead of 0.10 because the old versions don't work on Navi --- install/rocm/pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/install/rocm/pyproject.toml b/install/rocm/pyproject.toml index 38a1cc24..38ad68e0 100644 --- a/install/rocm/pyproject.toml +++ b/install/rocm/pyproject.toml @@ -21,6 +21,7 @@ colorama = "^0.4" compel = "^2" dadaptation = "^3" datasets = "^2" +deepspeed = "^0.14" diffusers = "^0.27" iterutils = "^0.1" numpy = "^1" From 337ddb7b26ecb566aa801966430d13fe5b4e480d Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Tue, 30 Apr 2024 23:22:55 -0700 Subject: [PATCH 15/37] Update README for ROCm UNet --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index a13a6fa4..d30a269b 100644 --- a/README.md +++ b/README.md @@ -65,7 +65,7 @@ EMA (exponential moving average) weights are a memory-heavy affair, but provide ### GPU vendors * NVIDIA - pretty much anything 3090 and up is a safe bet. YMMV. -* AMD - LoRA is tested to work on a 7900 XTX. Lacking `xformers`, it will likely use more memory than Nvidia equivalents +* AMD - SDXL LoRA and UNet are verified working on a 7900 XTX 24GB. 
Lacking `xformers`, it will likely use more memory than Nvidia equivalents * Apple - LoRA and full u-net tuning are tested to work on an M3 Max with 128G memory, taking about **12G** of "Wired" memory and **4G** of system memory for SDXL. * You likely need a 24G or greater machine for machine learning with M-series hardware due to the lack of memory-efficient attention. From 7f21979e092f38d2fabc75ca553887ece761cec2 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Wed, 1 May 2024 17:54:31 -0700 Subject: [PATCH 16/37] Add TorchAudio to ROCm for CUDA parity Should have all libs now except - `triton-library` - `bitsandbytes` - `xformers` --- install/rocm/pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/install/rocm/pyproject.toml b/install/rocm/pyproject.toml index 38ad68e0..af0a1128 100644 --- a/install/rocm/pyproject.toml +++ b/install/rocm/pyproject.toml @@ -9,6 +9,7 @@ readme = "README.md" [tool.poetry.dependencies] python = ">=3.9,<3.13" torch = {version = "^2.3", source = "pytorch-rocm"} +torchaudio = {version = "*", source = "pytorch-rocm"} torchmetrics = {version = "^1", source = "pytorch-rocm"} torchvision = {version = "*", source = "pytorch-rocm"} triton = {version = "*", source = "pytorch-rocm"} @@ -31,13 +32,13 @@ pandas = "^2" peft = "^0.9" pillow = "^10" prodigyopt = "^1" -regex = "^2023.12.25" requests = "^2" safetensors = "^0.4" scipy = "^1" tensorboard = "^2" torchsde = "^0.2" transformers = "^4" +# triton-library = "^1.0.0rc2" urllib3 = "<1.27" wandb = "^0.16" From 681316b64a3b38114ef68b68bfb660fef9ee9e47 Mon Sep 17 00:00:00 2001 From: bghira Date: Wed, 1 May 2024 22:07:19 -0600 Subject: [PATCH 17/37] remove obsoleted config line --- documentation/DEEPFLOYD.md | 1 - 1 file changed, 1 deletion(-) diff --git a/documentation/DEEPFLOYD.md b/documentation/DEEPFLOYD.md index 95dd4521..f908b308 100644 --- a/documentation/DEEPFLOYD.md +++ b/documentation/DEEPFLOYD.md @@ -108,7 +108,6 @@ export LEARNING_RATE_END=4e-6 #@param {type:"number"} # Where to store your results. export BASE_DIR="/training" -export INSTANCE_DIR="${BASE_DIR}/data" export OUTPUT_DIR="${BASE_DIR}/models/deepfloyd" export DATALOADER_CONFIG="multidatabackend_deepfloyd.json" From 09beb93bb1515b8b16059246b645665ab03ba877 Mon Sep 17 00:00:00 2001 From: bghira Date: Wed, 1 May 2024 22:07:57 -0600 Subject: [PATCH 18/37] parquet logging fix for aspect ratio referenced --- helpers/metadata/backends/parquet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/helpers/metadata/backends/parquet.py b/helpers/metadata/backends/parquet.py index cbd7a057..b1cb48fd 100644 --- a/helpers/metadata/backends/parquet.py +++ b/helpers/metadata/backends/parquet.py @@ -377,7 +377,7 @@ def _process_for_bucket( else: image_metadata["luminance"] = 0 logger.debug( - f"Image {image_path_str} has aspect ratio {aspect_ratio} and size {image_metadata['target_size']}." + f"Image {image_path_str} has aspect ratio {new_aspect_ratio} and size {image_metadata['target_size']}." 
) # Create a new bucket if it doesn't exist From 51e473db5e0160009c37dfbd265c4d1ae4b88de8 Mon Sep 17 00:00:00 2001 From: bghira Date: Wed, 1 May 2024 22:09:27 -0600 Subject: [PATCH 19/37] multiaspectimage: extract the cropping logic to a module --- helpers/multiaspect/image.py | 159 +++++++---------------------------- 1 file changed, 29 insertions(+), 130 deletions(-) diff --git a/helpers/multiaspect/image.py b/helpers/multiaspect/image.py index a4a563f2..3ad6e5d8 100644 --- a/helpers/multiaspect/image.py +++ b/helpers/multiaspect/image.py @@ -6,9 +6,20 @@ import logging, os, random from math import sqrt from helpers.training.state_tracker import StateTracker +from helpers.image_manipulation.cropping import ( + CornerCropping, + CenterCropping, + RandomCropping, +) logger = logging.getLogger("MultiaspectImage") logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO")) +crop_handlers = { + "corner": CornerCropping, + "centre": CenterCropping, + "center": CenterCropping, + "random": RandomCropping, +} class MultiaspectImage: @@ -116,6 +127,10 @@ def prepare_image( ) if crop: + crop_handler_cls = crop_handlers.get(crop_style) + if not crop_handler_cls: + raise ValueError(f"Unknown crop style: {crop_style}") + crop_handler = crop_handler_cls(image=image, image_metadata=image_metadata) if downsample_before_crop: logger.debug( f"Resizing image before crop, as its size is too large. Data backend: {id}, image size: {image.size}, target size: {target_width}x{target_height}" @@ -154,48 +169,13 @@ def prepare_image( else (target_width, target_height) ) + crop_result = crop_handler.crop(crop_width, crop_height) + if image: - if crop_style == "corner": - image, crop_coordinates = MultiaspectImage._crop_corner( - image, crop_width, crop_height - ) - elif crop_style in ["centre", "center"]: - image, crop_coordinates = MultiaspectImage._crop_center( - image, crop_width, crop_height - ) - elif crop_style == "face": - image, crop_coordinates = MultiaspectImage._crop_face( - image, crop_width, crop_height - ) - elif crop_style == "random": - image, crop_coordinates = MultiaspectImage._crop_random( - image, crop_width, crop_height - ) - else: - raise ValueError(f"Unknown crop style: {crop_style}") + image, crop_coordinates = crop_result + logger.debug(f"After cropping, our image size: {image.size}") elif image_metadata: - if crop_style == "corner": - _, crop_coordinates = MultiaspectImage._crop_corner( - image_metadata=image_metadata, - crop_width=crop_width, - crop_height=crop_height, - ) - elif crop_style in ["centre", "center"]: - _, crop_coordinates = MultiaspectImage._crop_center( - image_metadata=image_metadata, - crop_width=crop_width, - crop_height=crop_height, - ) - elif crop_style == "random" or crop_style == "face": - _, crop_coordinates = MultiaspectImage._crop_random( - image_metadata=image_metadata, - crop_width=crop_width, - crop_height=crop_height, - ) - else: - raise ValueError(f"Unknown crop style: {crop_style}") - - logger.debug(f"After cropping, our image size: {image.size}") + _, crop_coordinates = crop_result else: # Resize unconditionally if cropping is not enabled if image: @@ -209,96 +189,6 @@ def prepare_image( elif image_metadata: return (target_width, target_height), crop_coordinates, new_aspect_ratio - @staticmethod - def _crop_corner( - image: Image = None, - target_width=None, - target_height=None, - image_metadata: dict = None, - ): - """Crop the image from the bottom-right corner.""" - if image: - original_width, original_height = image.size - elif image_metadata: - 
original_width, original_height = image_metadata["original_size"] - left = max(0, original_width - target_width) - top = max(0, original_height - target_height) - right = original_width - bottom = original_height - if image: - return image.crop((left, top, right, bottom)), (left, top) - elif image_metadata: - return image_metadata, (left, top) - - @staticmethod - def _crop_center( - image: Image = None, - target_width=None, - target_height=None, - image_metadata: dict = None, - ): - """Crop the image from the center.""" - original_width, original_height = image.size - left = (original_width - target_width) / 2 - top = (original_height - target_height) / 2 - right = (original_width + target_width) / 2 - bottom = (original_height + target_height) / 2 - if image: - return image.crop((left, top, right, bottom)), (left, top) - elif image_metadata: - return image_metadata, (left, top) - - @staticmethod - def _crop_random( - image: Image = None, - target_width=None, - target_height=None, - image_metadata: dict = None, - ): - """Crop the image from a random position.""" - original_width, original_height = image.size - left = random.randint(0, max(0, original_width - target_width)) - top = random.randint(0, max(0, original_height - target_height)) - right = left + target_width - bottom = top + target_height - if image: - return image.crop((left, top, right, bottom)), (left, top) - elif image_metadata: - return image_metadata, (left, top) - - @staticmethod - def _crop_face( - image: Image, - target_width: int, - target_height: int, - ): - """Crop the image to include a face, or the most 'interesting' part of the image, without a face.""" - # Import modules - import cv2 - import numpy as np - - # Detect a face in the image - face_cascade = cv2.CascadeClassifier( - cv2.data.haarcascades + "haarcascade_frontalface_default.xml" - ) - image = image.convert("RGB") - image = np.array(image) - gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) - faces = face_cascade.detectMultiScale(gray, 1.1, 4) - if len(faces) > 0: - # Get the largest face - face = max(faces, key=lambda f: f[2] * f[3]) - x, y, w, h = face - left = max(0, x - 0.5 * w) - top = max(0, y - 0.5 * h) - right = min(image.shape[1], x + 1.5 * w) - bottom = min(image.shape[0], y + 1.5 * h) - image = Image.fromarray(image) - return image.crop((left, top, right, bottom)), (left, top) - else: - # Crop the image from a random position - return MultiaspectImage._crop_random(image, target_width, target_height) - @staticmethod def _round_to_nearest_multiple(value): """Round a value to the nearest multiple.""" @@ -439,6 +329,15 @@ def calculate_new_size_by_pixel_edge(aspect_ratio: float, resolution: int): def calculate_new_size_by_pixel_area(aspect_ratio: float, megapixels: float): if type(aspect_ratio) != float: raise ValueError(f"Aspect ratio must be a float, not {type(aspect_ratio)}") + # Special case for 1024px (1.0) megapixel images + if aspect_ratio == 1.0 and megapixels == 1.0: + return 1024, 1024, 1.0 + # Special case for 768px (0.75mp) images + if aspect_ratio == 1.0 and megapixels == 0.75: + return 768, 768, 1.0 + # Special case for 512px (0.5mp) images + if aspect_ratio == 1.0 and megapixels == 0.5: + return 512, 512, 1.0 total_pixels = max(megapixels * 1e6, 1e6) W_initial = int(round((total_pixels * aspect_ratio) ** 0.5)) H_initial = int(round((total_pixels / aspect_ratio) ** 0.5)) From 950e755e4ac54bf3ca7f0d57e2d6f534319581d3 Mon Sep 17 00:00:00 2001 From: bghira Date: Wed, 1 May 2024 22:10:42 -0600 Subject: [PATCH 20/37] add crop module --- 
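Notes: each cropper keeps the return convention of the old `MultiaspectImage._crop_*`
helpers, a `(cropped_image, (left, top))` tuple. A minimal usage sketch of the new
module (the 1024x768 input size here is illustrative only):

    from PIL import Image
    from helpers.image_manipulation.cropping import CenterCropping

    # Build a dummy 1024x768 image and take a centred 512x512 crop of it.
    image = Image.new("RGB", (1024, 768), "white")
    cropped, (left, top) = CenterCropping(image=image).crop(512, 512)
    assert cropped.size == (512, 512)
    assert (left, top) == (256, 128)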
helpers/image_manipulation/cropping.py | 88 ++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 helpers/image_manipulation/cropping.py diff --git a/helpers/image_manipulation/cropping.py b/helpers/image_manipulation/cropping.py new file mode 100644 index 00000000..3fdf2de6 --- /dev/null +++ b/helpers/image_manipulation/cropping.py @@ -0,0 +1,88 @@ +from PIL import Image + + +class BaseCropping: + def __init__(self, image: Image = None, image_metadata: dict = None): + self.image = image + self.image_metadata = image_metadata + if self.image: + self.original_width, self.original_height = self.image.size + elif self.image_metadata: + self.original_width, self.original_height = self.image_metadata[ + "original_size" + ] + + def crop(self, target_width, target_height): + raise NotImplementedError("Subclasses must implement this method") + + +class CornerCropping(BaseCropping): + def crop(self, target_width, target_height): + left = max(0, self.original_width - target_width) + top = max(0, self.original_height - target_height) + right = self.original_width + bottom = self.original_height + if self.image: + return self.image.crop((left, top, right, bottom)), (left, top) + elif self.image_metadata: + return self.image_metadata, (left, top) + + +class CenterCropping(BaseCropping): + def crop(self, target_width, target_height): + left = (self.original_width - target_width) / 2 + top = (self.original_height - target_height) / 2 + right = (self.original_width + target_width) / 2 + bottom = (self.original_height + target_height) / 2 + if self.image: + return self.image.crop((left, top, right, bottom)), (left, top) + elif self.image_metadata: + return self.image_metadata, (left, top) + + +class RandomCropping(BaseCropping): + def crop(self, target_width, target_height): + import random + + left = random.randint(0, max(0, self.original_width - target_width)) + top = random.randint(0, max(0, self.original_height - target_height)) + right = left + target_width + bottom = top + target_height + if self.image: + return self.image.crop((left, top, right, bottom)), (left, top) + elif self.image_metadata: + return self.image_metadata, (left, top) + + +class FaceCropping(RandomCropping): + def crop( + self, + image: Image, + target_width: int, + target_height: int, + ): + # Import modules + import cv2 + import numpy as np + + # Detect a face in the image + face_cascade = cv2.CascadeClassifier( + cv2.data.haarcascades + "haarcascade_frontalface_default.xml" + ) + image = image.convert("RGB") + image = np.array(image) + gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + faces = face_cascade.detectMultiScale(gray, 1.1, 4) + if len(faces) > 0: + # Get the largest face + face = max(faces, key=lambda f: f[2] * f[3]) + x, y, w, h = face + left = max(0, x - 0.5 * w) + top = max(0, y - 0.5 * h) + right = min(image.shape[1], x + 1.5 * w) + bottom = min(image.shape[0], y + 1.5 * h) + image = Image.fromarray(image) + return image.crop((left, top, right, bottom)), (left, top) + else: + # Crop the image from a random position + return super.crop(image, target_width, target_height) From 06678b704a3876a5fc1df9777f386a095b39edc4 Mon Sep 17 00:00:00 2001 From: bghira Date: Wed, 1 May 2024 22:13:34 -0600 Subject: [PATCH 21/37] remove dead code --- helpers/multiaspect/image.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/helpers/multiaspect/image.py b/helpers/multiaspect/image.py index 3ad6e5d8..ef2e0c80 100644 --- a/helpers/multiaspect/image.py +++ b/helpers/multiaspect/image.py @@ 
-385,17 +385,3 @@ def calculate_image_aspect_ratio(image, rounding: int = 2): width, height = image.size aspect_ratio = round(width / height, to_round) return aspect_ratio - - @staticmethod - def determine_bucket_for_aspect_ratio(aspect_ratio): - """ - Determine the correct bucket for a given aspect ratio. - - Args: - aspect_ratio (float): The aspect ratio of an image. - - Returns: - str: The bucket corresponding to the aspect ratio. - """ - # The logic for determining the bucket can be based on the aspect ratio directly - return str(aspect_ratio) From b9a92fe28f76fe8e28049967a79798d5a69b0442 Mon Sep 17 00:00:00 2001 From: bghira Date: Wed, 1 May 2024 22:13:48 -0600 Subject: [PATCH 22/37] crop tests reorganise --- tests/test_cropping.py | 22 +++++++++++++--------- tests/test_image.py | 18 ------------------ 2 files changed, 13 insertions(+), 27 deletions(-) diff --git a/tests/test_cropping.py b/tests/test_cropping.py index f5a73f8b..ab21b0a3 100644 --- a/tests/test_cropping.py +++ b/tests/test_cropping.py @@ -5,16 +5,17 @@ ) # Adjust import according to your project structure -class TestMultiaspectImage(unittest.TestCase): +class TestCropping(unittest.TestCase): def setUp(self): # Creating a sample image for testing self.sample_image = Image.new("RGB", (500, 300), "white") def test_crop_corner(self): target_width, target_height = 300, 200 - cropped_image, (left, top) = MultiaspectImage._crop_corner( - self.sample_image, target_width, target_height - ) + from helpers.image_manipulation.cropping import CornerCropping + + cropper = CornerCropping(self.sample_image) + cropped_image, (left, top) = cropper.crop(target_width, target_height) # Check if cropped coordinates are within original image bounds self.assertTrue(0 <= left < self.sample_image.width) @@ -24,10 +25,11 @@ def test_crop_corner(self): self.assertEqual(cropped_image.size, (target_width, target_height)) def test_crop_center(self): + from helpers.image_manipulation.cropping import CenterCropping + + cropper = CenterCropping(self.sample_image) target_width, target_height = 300, 200 - cropped_image, (left, top) = MultiaspectImage._crop_center( - self.sample_image, target_width, target_height - ) + cropped_image, (left, top) = cropper.crop(target_width, target_height) # Similar checks as above self.assertTrue(0 <= left < self.sample_image.width) @@ -37,9 +39,11 @@ def test_crop_center(self): self.assertEqual(cropped_image.size, (target_width, target_height)) def test_crop_random(self): + from helpers.image_manipulation.cropping import RandomCropping + target_width, target_height = 300, 200 - cropped_image, (left, top) = MultiaspectImage._crop_random( - self.sample_image, target_width, target_height + cropped_image, (left, top) = RandomCropping(self.sample_image).crop( + target_width, target_height ) # Similar checks as above diff --git a/tests/test_image.py b/tests/test_image.py index 7bf1aa23..17a383dc 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -160,24 +160,6 @@ def test_image_size_consistency(self): f"Output sizes are not consistent for {resolution} MP", ) - def test_crop_corner(self): - cropped_image, _ = MultiaspectImage._crop_corner( - self.test_image, self.resolution, self.resolution - ) - self.assertEqual(cropped_image.size, (self.resolution, self.resolution)) - - def test_crop_center(self): - cropped_image, _ = MultiaspectImage._crop_center( - self.test_image, self.resolution, self.resolution - ) - self.assertEqual(cropped_image.size, (self.resolution, self.resolution)) - - def test_crop_random(self): - 
cropped_image, _ = MultiaspectImage._crop_random( - self.test_image, self.resolution, self.resolution - ) - self.assertEqual(cropped_image.size, (self.resolution, self.resolution)) - def test_prepare_image_valid(self): with patch("helpers.training.state_tracker.StateTracker.get_args") as mock_args: mock_args.return_value = Mock( From 4bcce5c853f23094249a30c2293d7ef1a6153a04 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Wed, 1 May 2024 21:23:24 -0700 Subject: [PATCH 23/37] Remove `triton` from ROCm depends Incompatible with `pytorch-triton-rocm` --- install/rocm/pyproject.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/install/rocm/pyproject.toml b/install/rocm/pyproject.toml index af0a1128..473ba97e 100644 --- a/install/rocm/pyproject.toml +++ b/install/rocm/pyproject.toml @@ -12,7 +12,6 @@ torch = {version = "^2.3", source = "pytorch-rocm"} torchaudio = {version = "*", source = "pytorch-rocm"} torchmetrics = {version = "^1", source = "pytorch-rocm"} torchvision = {version = "*", source = "pytorch-rocm"} -triton = {version = "*", source = "pytorch-rocm"} pytorch-triton-rocm = {version = "*", source = "pytorch-rocm"} accelerate = "^0.26" boto3 = "^1" @@ -38,7 +37,6 @@ scipy = "^1" tensorboard = "^2" torchsde = "^0.2" transformers = "^4" -# triton-library = "^1.0.0rc2" urllib3 = "<1.27" wandb = "^0.16" From a4a780e3dd85aa87667de8fcc57dce50b58520ad Mon Sep 17 00:00:00 2001 From: bghira Date: Thu, 2 May 2024 01:01:28 -0600 Subject: [PATCH 24/37] wip: training sample handler --- helpers/image_manipulation/cropping.py | 10 +- helpers/image_manipulation/training_sample.py | 204 ++++++++++++++++++ helpers/multiaspect/image.py | 20 +- 3 files changed, 221 insertions(+), 13 deletions(-) create mode 100644 helpers/image_manipulation/training_sample.py diff --git a/helpers/image_manipulation/cropping.py b/helpers/image_manipulation/cropping.py index 3fdf2de6..24a694c2 100644 --- a/helpers/image_manipulation/cropping.py +++ b/helpers/image_manipulation/cropping.py @@ -57,7 +57,7 @@ def crop(self, target_width, target_height): class FaceCropping(RandomCropping): def crop( self, - image: Image, + image: Image.Image, target_width: int, target_height: int, ): @@ -86,3 +86,11 @@ def crop( else: # Crop the image from a random position return super.crop(image, target_width, target_height) + + +crop_handlers = { + "corner": CornerCropping, + "centre": CenterCropping, + "center": CenterCropping, + "random": RandomCropping, +} diff --git a/helpers/image_manipulation/training_sample.py b/helpers/image_manipulation/training_sample.py new file mode 100644 index 00000000..ca59f92e --- /dev/null +++ b/helpers/image_manipulation/training_sample.py @@ -0,0 +1,204 @@ +from PIL import Image +from PIL.ImageOps import exif_transpose +from helpers.multiaspect.image import MultiaspectImage, resize_helpers +from helpers.multiaspect.image import crop_handlers +from helpers.training.state_tracker import StateTracker +import logging + +logger = logging.getLogger(__name__) + + +class TrainingSample: + def __init__(self, image: Image.Image, data_backend_id: str, metadata: dict = None): + """ + Initializes a new TrainingSample instance with a provided PIL.Image object and a data backend identifier. + + Args: + image (Image.Image): A PIL Image object. + data_backend_id (str): Identifier for the data backend used for additional operations. + metadata (dict): Optional metadata associated with the image. 
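+
+        Raises:
+            Exception: If neither the image nor the provided metadata yields an original size.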
+ """ + self.image = image + self.data_backend_id = data_backend_id + self.metadata = metadata if metadata else {} + if hasattr(image, "size"): + self.original_size = self.image.size + elif metadata is not None: + self.original_size = metadata.get("original_size") + + if not self.original_size: + raise Exception("Original size not found in metadata.") + + # Torchvision transforms turn the pixels into a Tensor and normalize them for the VAE. + self.transforms = MultiaspectImage.get_image_transforms() + # EXIT, RGB conversions. + self.correct_image() + + # Backend config details + self.data_backend_config = StateTracker.get_data_backend_config(data_backend_id) + self.crop_enabled = self.data_backend_config.get("crop", False) + self.crop_style = self.data_backend_config.get("crop_style", "random") + self.crop_aspect = self.data_backend_config.get("crop_aspect", "random") + crop_handler_cls = crop_handlers.get(self.crop_style) + if not crop_handler_cls: + raise ValueError(f"Unknown crop style: {self.crop_style}") + self.cropper = crop_handler_cls(image=self.image, image_metadata=metadata) + self.target_size_calculator = resize_helpers.get(self.resolution_type) + if self.target_size_calculator is None: + raise ValueError(f"Unknown resolution type: {self.resolution_type}") + self.resolution = self.data_backend_config.get("resolution") + self.resolution_type = self.data_backend_config.get("resolution_type") + if self.resolution_type == "pixel": + self.target_area = self.resolution + # Store the pixel value, eg. 1024 + self.pixel_resolution = self.resolution + # Store the megapixel value, eg. 1.0 + self.megapixel_resolution = self.resolution / 1e6 + elif self.resolution_type == "area": + self.target_area = self.resolution * 1e6 # Convert megapixels to pixels + # Store the pixel value, eg. 1024 + self.pixel_resolution = self.resolution * 1e6 + # Store the megapixel value, eg. 1.0 + self.megapixel_resolution = self.resolution + else: + raise Exception(f"Unknown resolution type: {self.resolution_type}") + self.target_downsample_size = self.data_backend_config.get( + "target_downsample_size", None + ) + self.maximum_image_size = self.data_backend_config.get( + "maximum_image_size", None + ) + + def prepare(self, return_tensor: bool = False): + """ + Perform initial image preparations such as converting to RGB and applying EXIF transformations. + + Args: + image (Image.Image): The image to prepare. + + Returns: + (image, crop_coordinates, aspect_ratio) + """ + self.crop() + if not self.crop_enabled: + self.resize() + + image = self.image + if return_tensor: + # Return normalised tensor. + image = self.transforms(image) + return image, self.crop_coordinates, self.aspect_ratio + + def area(self) -> int: + """ + Calculate the area of the image. + + Returns: + int: The area of the image. + """ + if self.image is not None: + return self.image.size[0] * self.image.size[1] + if self.original_size: + return self.original_size[0] * self.original_size[1] + + def should_downsample_before_crop(self) -> bool: + """ + Returns: + bool: True if the image should be downsampled before cropping, False otherwise. + """ + if ( + not self.crop_enabled + or not self.maximum_image_size + or not self.target_downsample_size + ): + return False + if self.data_backend_config.get("resolution_type") == "pixel": + return ( + self.image.size[0] > self.pixel_resolution + or self.image.size[1] > self.pixel_resolution + ) + elif self.data_backend_config.get("resolution_type") == "area": + logger.debug( + f"Image is too large? 
{self.area() > self.target_area} (image area: {self.area()}, target area: {self.target_area})" + ) + return self.area() > self.target_area + else: + raise ValueError( + f"Unknown resolution type: {self.data_backend_config.get('resolution_type')}" + ) + + def downsample_before_crop(self): + """ + Downsample the image before cropping, to preserve scene details. + """ + if self.image and self.should_downsample_before_crop(): + width, height, _ = self.calculate_target_size( + self.image, downsample_before_crop=True + ) + self.image = self.resize((width, height)) + return self + + def calculate_target_size(self, downsample_before_crop: bool = False): + if downsample_before_crop and self.target_downsample_size is not None: + self.target_size = self.target_size_calculator( + self.image, self.target_downsample_size + ) + else: + self.target_size = self.target_size_calculator(self.image, self.resolution) + self.aspect_ratio = MultiaspectImage.calculate_image_aspect_ratio( + self.target_size + ) + + return self.target_size[0], self.target_size[1], self.aspect_ratio + + def correct_image(self): + """ + Apply a series of transformations to the image to "correct" it. + """ + if self.image: + # Convert image to RGB to remove any alpha channel and apply EXIF data transformations + self.image = self.image.convert("RGB") + self.image = exif_transpose(self.image) + self.original_size = self.image.size + return self + + def crop(self): + """ + Crop the image using the detected crop handler class. + """ + if not self.crop_enabled: + return self + + # Too-big of an image, resize before we crop. + self.downsample_before_crop() + width, height, aspect_ratio = self.calculate_target_size( + downsample_before_crop=False + ) + self.image, self.crop_coordinates = self.cropper.crop(width, height) + return self + + def resize(self, target_size: tuple = None): + """ + Resize the image to a new size. + + Args: + target_size (tuple): The target size as (width, height). + """ + if target_size is None: + target_width, target_height, aspect_ratio = self.calculate_target_size() + target_size = (target_width, target_height) + if self.image: + self.image = self.image.resize(target_size, Image.LANCZOS) + self.aspect_ratio = MultiaspectImage.calculate_image_aspect_ratio( + self.image.size + ) + return self + + def get_image(self): + """ + Returns the current state of the image. + + Returns: + Image.Image: The current image. 
+ """ + return self.image diff --git a/helpers/multiaspect/image.py b/helpers/multiaspect/image.py index ef2e0c80..c7871bbb 100644 --- a/helpers/multiaspect/image.py +++ b/helpers/multiaspect/image.py @@ -6,20 +6,10 @@ import logging, os, random from math import sqrt from helpers.training.state_tracker import StateTracker -from helpers.image_manipulation.cropping import ( - CornerCropping, - CenterCropping, - RandomCropping, -) +from helpers.image_manipulation.cropping import crop_handlers logger = logging.getLogger("MultiaspectImage") logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO")) -crop_handlers = { - "corner": CornerCropping, - "centre": CenterCropping, - "center": CenterCropping, - "random": RandomCropping, -} class MultiaspectImage: @@ -198,7 +188,7 @@ def _round_to_nearest_multiple(value): @staticmethod def _resize_image( - input_image: Image, + input_image: Image.Image, target_width: int, target_height: int, image_metadata: dict = None, @@ -385,3 +375,9 @@ def calculate_image_aspect_ratio(image, rounding: int = 2): width, height = image.size aspect_ratio = round(width / height, to_round) return aspect_ratio + + +resize_helpers = { + "pixel": MultiaspectImage.calculate_new_size_by_pixel_edge, + "area": MultiaspectImage.calculate_new_size_by_pixel_area, +} From ad441e9086ae6635bdde87f23b0aa88a75bd25d4 Mon Sep 17 00:00:00 2001 From: bghira Date: Thu, 2 May 2024 11:38:01 -0600 Subject: [PATCH 25/37] fix typo in pixel alignment warning --- helpers/arguments.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/helpers/arguments.py b/helpers/arguments.py index 14d7b7d5..1726eca9 100644 --- a/helpers/arguments.py +++ b/helpers/arguments.py @@ -1372,7 +1372,7 @@ def parse_args(input_args=None): deepfloyd_pixel_alignment = 8 if args.aspect_bucket_alignment != deepfloyd_pixel_alignment: logger.warning( - f"Overriding aspect bucket alignment pixel interval to {deepfloyd_pixel_alignment}px instead of{args.aspect_bucket_alignment}px." + f"Overriding aspect bucket alignment pixel interval to {deepfloyd_pixel_alignment}px instead of {args.aspect_bucket_alignment}px." ) args.aspect_bucket_alignment = deepfloyd_pixel_alignment From d2f768d9ec82de518a63318167f6ef65ebcfee83 Mon Sep 17 00:00:00 2001 From: bghira Date: Thu, 2 May 2024 11:38:42 -0600 Subject: [PATCH 26/37] training sample (wip) should set up resolution_type before the target size is calculated. we will keep target sizes up to date on the object. 
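For illustration, the intended initialisation order is now roughly:

    self.resolution = self.data_backend_config.get("resolution")
    self.resolution_type = self.data_backend_config.get("resolution_type")
    self.target_size_calculator = resize_helpers.get(self.resolution_type)

so that resize_helpers is looked up with a populated resolution_type.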
--- helpers/image_manipulation/training_sample.py | 23 ++++++++++++------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/helpers/image_manipulation/training_sample.py b/helpers/image_manipulation/training_sample.py index ca59f92e..06a775e2 100644 --- a/helpers/image_manipulation/training_sample.py +++ b/helpers/image_manipulation/training_sample.py @@ -39,15 +39,16 @@ def __init__(self, image: Image.Image, data_backend_id: str, metadata: dict = No self.crop_enabled = self.data_backend_config.get("crop", False) self.crop_style = self.data_backend_config.get("crop_style", "random") self.crop_aspect = self.data_backend_config.get("crop_aspect", "random") + self.crop_coordinates = (0, 0) crop_handler_cls = crop_handlers.get(self.crop_style) if not crop_handler_cls: raise ValueError(f"Unknown crop style: {self.crop_style}") self.cropper = crop_handler_cls(image=self.image, image_metadata=metadata) + self.resolution = self.data_backend_config.get("resolution") + self.resolution_type = self.data_backend_config.get("resolution_type") self.target_size_calculator = resize_helpers.get(self.resolution_type) if self.target_size_calculator is None: raise ValueError(f"Unknown resolution type: {self.resolution_type}") - self.resolution = self.data_backend_config.get("resolution") - self.resolution_type = self.data_backend_config.get("resolution_type") if self.resolution_type == "pixel": self.target_area = self.resolution # Store the pixel value, eg. 1024 @@ -132,19 +133,25 @@ def downsample_before_crop(self): Downsample the image before cropping, to preserve scene details. """ if self.image and self.should_downsample_before_crop(): - width, height, _ = self.calculate_target_size( - self.image, downsample_before_crop=True - ) + width, height, _ = self.calculate_target_size(downsample_before_crop=True) self.image = self.resize((width, height)) return self def calculate_target_size(self, downsample_before_crop: bool = False): + self.aspect_ratio = MultiaspectImage.calculate_image_aspect_ratio( + self.original_size + ) if downsample_before_crop and self.target_downsample_size is not None: - self.target_size = self.target_size_calculator( - self.image, self.target_downsample_size + target_width, target_height, self.aspect_ratio = ( + self.target_size_calculator( + self.aspect_ratio, self.target_downsample_size + ) ) else: - self.target_size = self.target_size_calculator(self.image, self.resolution) + target_width, target_height, self.aspect_ratio = ( + self.target_size_calculator(self.aspect_ratio, self.resolution) + ) + self.target_size = (target_width, target_height) self.aspect_ratio = MultiaspectImage.calculate_image_aspect_ratio( self.target_size ) From 7044f86c5b3ffa4af58dc7448ed28092cfb0a165 Mon Sep 17 00:00:00 2001 From: bghira Date: Thu, 2 May 2024 11:40:31 -0600 Subject: [PATCH 27/37] use the training sample wrapper for image preparation during metadata/aspect bucket scan --- helpers/metadata/backends/json.py | 11 ++++------- helpers/metadata/backends/parquet.py | 14 +++++++------- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/helpers/metadata/backends/json.py b/helpers/metadata/backends/json.py index f24ad149..bb127554 100644 --- a/helpers/metadata/backends/json.py +++ b/helpers/metadata/backends/json.py @@ -2,6 +2,7 @@ from helpers.multiaspect.image import MultiaspectImage from helpers.data_backend.base import BaseDataBackend from helpers.metadata.backends.base import MetadataBackend +from helpers.image_manipulation.training_sample import TrainingSample from 
pathlib import Path import json, logging, os, time, re from multiprocessing import Manager @@ -217,14 +218,10 @@ def _process_for_bucket( statistics["skipped"]["too_small"] += 1 return aspect_ratio_bucket_indices image_metadata["original_size"] = image.size - image, crop_coordinates, new_aspect_ratio = ( - MultiaspectImage.prepare_image( - image=image, - resolution=self.resolution, - resolution_type=self.resolution_type, - id=self.data_backend.id, - ) + training_sample = TrainingSample( + image=image, data_backend_id=self.id, metadata=image_metadata ) + image, crop_coordinates, new_aspect_ratio = training_sample.prepare() image_metadata["crop_coordinates"] = crop_coordinates image_metadata["target_size"] = image.size # Round to avoid excessive unique buckets diff --git a/helpers/metadata/backends/parquet.py b/helpers/metadata/backends/parquet.py index b1cb48fd..23cbe91c 100644 --- a/helpers/metadata/backends/parquet.py +++ b/helpers/metadata/backends/parquet.py @@ -1,6 +1,7 @@ from helpers.training.state_tracker import StateTracker from helpers.multiaspect.image import MultiaspectImage from helpers.data_backend.base import BaseDataBackend +from helpers.image_manipulation.training_sample import TrainingSample from helpers.metadata.backends.base import MetadataBackend from tqdm import tqdm import json, logging, os, time @@ -358,13 +359,12 @@ def _process_for_bucket( aspect_ratio = MultiaspectImage.calculate_image_aspect_ratio( image_metadata["original_size"] ) - target_size, crop_coordinates, new_aspect_ratio = ( - MultiaspectImage.prepare_image( - image_metadata=image_metadata, - resolution=self.resolution, - resolution_type=self.resolution_type, - id=self.data_backend.id, - ) + training_sample = TrainingSample( + image=None, data_backend_id=self.id, metadata=image_metadata + ) + target_size, crop_coordinates, new_aspect_ratio = training_sample.prepare() + print( + f"Prepared training sample: {target_size}, {crop_coordinates}, {new_aspect_ratio}" ) image_metadata["crop_coordinates"] = crop_coordinates image_metadata["target_size"] = target_size From f5a31c89dd7b7d776a32fdce32c5b7f86a788981 Mon Sep 17 00:00:00 2001 From: bghira Date: Thu, 2 May 2024 11:41:29 -0600 Subject: [PATCH 28/37] base metadata backend should no longer call function to convert to str --- helpers/metadata/backends/base.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/helpers/metadata/backends/base.py b/helpers/metadata/backends/base.py index 766305c0..651c9bec 100644 --- a/helpers/metadata/backends/base.py +++ b/helpers/metadata/backends/base.py @@ -740,9 +740,7 @@ def handle_vae_cache_inconsistencies(self, vae_cache, vae_cache_behavior: str): continue if vae_cache_behavior == "sync": # Sync aspect buckets with the cache - expected_bucket = MultiaspectImage.determine_bucket_for_aspect_ratio( - self._get_aspect_ratio_from_tensor(cache_content) - ) + expected_bucket = str(self._get_aspect_ratio_from_tensor(cache_content)) self._modify_cache_entry_bucket(cache_file, expected_bucket) elif vae_cache_behavior == "recreate": # Delete the cache file if it doesn't match the aspect bucket indices @@ -837,9 +835,7 @@ def is_cache_inconsistent(self, vae_cache, cache_file, cache_content): ) actual_aspect_ratio = self._get_aspect_ratio_from_tensor(cache_content) - expected_bucket = MultiaspectImage.determine_bucket_for_aspect_ratio( - recalculated_aspect_ratio - ) + expected_bucket = str(recalculated_aspect_ratio) logger.debug( f"Expected bucket for {cache_file}: {expected_bucket} vs actual 
{actual_aspect_ratio}" ) From 32632a9747ccf71658fa440718c3976837e653a3 Mon Sep 17 00:00:00 2001 From: bghira Date: Thu, 2 May 2024 11:42:09 -0600 Subject: [PATCH 29/37] update debug logging for training loop to have better phrasing, more information for DF training issues --- train_sd21.py | 10 ++++------ train_sdxl.py | 4 +--- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/train_sd21.py b/train_sd21.py index 63a8a498..34b1c005 100644 --- a/train_sd21.py +++ b/train_sd21.py @@ -1036,9 +1036,7 @@ def main(): with accelerator.accumulate( training_models ), torch.autograd.detect_anomaly(): - training_logger.debug( - f"Sending latent batch from pinned memory to device" - ) + training_logger.debug(f"Sending latent batch to GPU") latents = batch["latent_batch"].to( accelerator.device, dtype=weight_dtype ) @@ -1121,10 +1119,10 @@ def main(): f"\n -> Noise device: {noise.device}" f"\n -> Timesteps device: {timesteps.device}" f"\n -> Encoder hidden states device: {encoder_hidden_states.device}" - f"\n -> Latents dtype: {latents.dtype}" - f"\n -> Noise dtype: {noise.dtype}" + f"\n -> Latents dtype: {latents.dtype}, shape: {latents.shape if hasattr(latents, 'shape') else 'None'}" + f"\n -> Noise dtype: {noise.dtype}, shape: {noise.shape if hasattr(noise, 'shape') else 'None'}" f"\n -> Timesteps dtype: {timesteps.dtype}" - f"\n -> Encoder hidden states dtype: {encoder_hidden_states.dtype}" + f"\n -> Encoder hidden states dtype: {encoder_hidden_states.dtype}, shape: {encoder_hidden_states.shape if hasattr(encoder_hidden_states, 'shape') else 'None'}" ) if unwrap_model(accelerator, unet).config.in_channels == channels * 2: # deepfloyd stage ii requires the inputs to be doubled. note that we're working in pixels, not latents. diff --git a/train_sdxl.py b/train_sdxl.py index e4323bec..2a62845c 100644 --- a/train_sdxl.py +++ b/train_sdxl.py @@ -1130,9 +1130,7 @@ def main(): training_luminance_values.append(batch["batch_luminance"]) with accelerator.accumulate(training_models): - training_logger.debug( - f"Sending latent batch from pinned memory to device." 
- ) + training_logger.debug(f"Sending latent batch to GPU.") latents = batch["latent_batch"].to( accelerator.device, dtype=weight_dtype ) From 2035fa4216f63817ab0aed25351bae96b7207f78 Mon Sep 17 00:00:00 2001 From: bghira Date: Thu, 2 May 2024 11:43:51 -0600 Subject: [PATCH 30/37] update PIL type hint --- toolkit/captioning/caption_with_cogvlm.py | 2 +- toolkit/captioning/caption_with_cogvlm_remote.py | 10 +++++----- toolkit/captioning/caption_with_llava.py | 2 +- toolkit/datasets/analyze_aspect_ratios_json.py | 2 +- toolkit/datasets/csv_to_s3.py | 4 ++-- toolkit/datasets/dataset_from_laion.py | 2 +- toolkit/datasets/enhance_with_controlnet.py | 2 +- 7 files changed, 12 insertions(+), 12 deletions(-) diff --git a/toolkit/captioning/caption_with_cogvlm.py b/toolkit/captioning/caption_with_cogvlm.py index c979f2b1..eb92f2cb 100644 --- a/toolkit/captioning/caption_with_cogvlm.py +++ b/toolkit/captioning/caption_with_cogvlm.py @@ -114,7 +114,7 @@ def load_filter_list(filter_list_path): def eval_image( - image: Image, + image: Image.Image, model, tokenizer, torch_dtype, diff --git a/toolkit/captioning/caption_with_cogvlm_remote.py b/toolkit/captioning/caption_with_cogvlm_remote.py index 30c1d831..344cd4ac 100644 --- a/toolkit/captioning/caption_with_cogvlm_remote.py +++ b/toolkit/captioning/caption_with_cogvlm_remote.py @@ -118,7 +118,7 @@ def parse_args(): def eval_image( - image: Image, + image: Image.Image, model, tokenizer, torch_dtype, @@ -141,7 +141,7 @@ def eval_image( return tokenizer.decode(outputs[0]) -def eval_image_with_ooba(image: Image, query: str) -> str: +def eval_image_with_ooba(image: Image.Image, query: str) -> str: CONTEXT = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n" img_str = base64.b64encode(image).decode("utf-8") @@ -192,7 +192,7 @@ def encode_images(accelerator, vae, images, image_transform): from math import sqrt -def round_to_nearest_multiple(value, multiple): +def round_to_nearest_multiple(value, multiple: int = 64): """Round a value to the nearest multiple.""" rounded = round(value / multiple) * multiple return max(rounded, multiple) # Ensure it's at least the value of 'multiple' @@ -206,8 +206,8 @@ def calculate_new_size_by_pixel_area(W: int, H: int, megapixels: float = 1.0): H_new = int(round(sqrt(total_pixels / aspect_ratio))) # Ensure they are divisible by 64 - W_new = round_to_nearest_multiple(W_new, 64) - H_new = round_to_nearest_multiple(H_new, 64) + W_new = round_to_nearest_multiple(W_new) + H_new = round_to_nearest_multiple(H_new) return W_new, H_new diff --git a/toolkit/captioning/caption_with_llava.py b/toolkit/captioning/caption_with_llava.py index 41f05bbe..8bf8d1fe 100644 --- a/toolkit/captioning/caption_with_llava.py +++ b/toolkit/captioning/caption_with_llava.py @@ -256,7 +256,7 @@ def process_and_evaluate_image(args, image_path: str, model, processor): else: image = Image.open(image_path) - def resize_for_condition_image(input_image: Image, resolution: int): + def resize_for_condition_image(input_image: Image.Image, resolution: int): if resolution == 0: return input_image input_image = input_image.convert("RGB") diff --git a/toolkit/datasets/analyze_aspect_ratios_json.py b/toolkit/datasets/analyze_aspect_ratios_json.py index 94183cc4..2695661a 100644 --- a/toolkit/datasets/analyze_aspect_ratios_json.py +++ b/toolkit/datasets/analyze_aspect_ratios_json.py @@ -39,7 +39,7 @@ # for file in files_to_delete: # import os # os.remove(file) -def 
_resize_for_condition_image(self, input_image: Image, resolution: int): +def _resize_for_condition_image(self, input_image: Image.Image, resolution: int): input_image = input_image.convert("RGB") W, H = input_image.size k = float(resolution) / min(H, W) diff --git a/toolkit/datasets/csv_to_s3.py b/toolkit/datasets/csv_to_s3.py index a3aceccc..493f9bf3 100644 --- a/toolkit/datasets/csv_to_s3.py +++ b/toolkit/datasets/csv_to_s3.py @@ -74,7 +74,7 @@ def shuffle_words_in_filename(filename): return "_".join(words) + ext -def resize_for_condition_image(input_image: Image, resolution: int): +def resize_for_condition_image(input_image: Image.Image, resolution: int): if resolution == 0: return input_image input_image = input_image.convert("RGB") @@ -469,7 +469,7 @@ def list_all_s3_objects(s3_client, bucket_name): return existing_files -def upload_pil_to_s3(image: Image, filename, args, s3_client): +def upload_pil_to_s3(image: Image.Image, filename, args, s3_client): """Upload a PIL Image directly to S3 bucket""" if object_exists_in_s3(s3_client, args.aws_bucket_name, filename): return diff --git a/toolkit/datasets/dataset_from_laion.py b/toolkit/datasets/dataset_from_laion.py index eb6ff0be..71d86065 100644 --- a/toolkit/datasets/dataset_from_laion.py +++ b/toolkit/datasets/dataset_from_laion.py @@ -113,7 +113,7 @@ def load_csv(file): timeouts = (conn_timeout, read_timeout) -def _resize_for_condition_image(input_image: Image, resolution: int): +def _resize_for_condition_image(input_image: Image.Image, resolution: int): input_image = input_image.convert("RGB") W, H = input_image.size aspect_ratio = round(W / H, 2) diff --git a/toolkit/datasets/enhance_with_controlnet.py b/toolkit/datasets/enhance_with_controlnet.py index 6e0b2216..4e487693 100644 --- a/toolkit/datasets/enhance_with_controlnet.py +++ b/toolkit/datasets/enhance_with_controlnet.py @@ -4,7 +4,7 @@ from diffusers.utils import load_image -def resize_for_condition_image(input_image: Image, resolution: int): +def resize_for_condition_image(input_image: Image.Image, resolution: int): input_image = input_image.convert("RGB") W, H = input_image.size k = float(resolution) / min(H, W) From 554c6d0d2be2f0de6a2ee971d8be987f7c5f12c9 Mon Sep 17 00:00:00 2001 From: bghira Date: Thu, 2 May 2024 12:58:31 -0600 Subject: [PATCH 31/37] disable torch anomaly detector --- train_sd21.py | 8 ++------ train_sdxl.py | 3 +-- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/train_sd21.py b/train_sd21.py index 34b1c005..1e0bfb31 100644 --- a/train_sd21.py +++ b/train_sd21.py @@ -104,9 +104,8 @@ tokenizer = None -torch.autograd.set_detect_anomaly(True) # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.26.0.dev0") +check_min_version("0.27.0.dev0") SCHEDULER_NAME_MAP = { @@ -916,7 +915,6 @@ def main(): } }, ) - torch.autograd.set_detect_anomaly(True) logger.info("***** Running training *****") total_num_batches = sum( [ @@ -1033,9 +1031,7 @@ def main(): # Add the current batch of training data's avg luminance to a list. 
training_luminance_values.append(batch["batch_luminance"]) - with accelerator.accumulate( - training_models - ), torch.autograd.detect_anomaly(): + with accelerator.accumulate(training_models): training_logger.debug(f"Sending latent batch to GPU") latents = batch["latent_batch"].to( accelerator.device, dtype=weight_dtype diff --git a/train_sdxl.py b/train_sdxl.py index 2a62845c..21ecb700 100644 --- a/train_sdxl.py +++ b/train_sdxl.py @@ -86,9 +86,8 @@ from diffusers.utils.import_utils import is_xformers_available from transformers.utils import ContextManagers -torch.autograd.set_detect_anomaly(True) # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.26.0.dev0") +check_min_version("0.27.0.dev0") SCHEDULER_NAME_MAP = { "euler": EulerDiscreteScheduler, From 63dc4a11252b92bbf81cc1108c03f626d3ba20a6 Mon Sep 17 00:00:00 2001 From: bghira Date: Thu, 2 May 2024 14:35:18 -0600 Subject: [PATCH 32/37] nuke prepare_image from existence --- helpers/multiaspect/image.py | 157 ----------------------------------- tests/test_image.py | 139 ------------------------------- 2 files changed, 296 deletions(-) diff --git a/helpers/multiaspect/image.py b/helpers/multiaspect/image.py index c7871bbb..44f8ce0a 100644 --- a/helpers/multiaspect/image.py +++ b/helpers/multiaspect/image.py @@ -22,163 +22,6 @@ def get_image_transforms(): ] ) - @staticmethod - def prepare_image( - resolution: float, - image: Image = None, - image_metadata: dict = None, - resolution_type: str = "pixel", - id: str = "foo", - ): - if image: - if not hasattr(image, "convert"): - raise Exception( - f"Unknown data received instead of PIL.Image object: {type(image)}" - ) - # Strip transparency - image = image.convert("RGB") - # Rotate, maybe. - logger.debug(f"Processing image filename: {image}") - logger.debug(f"Image size before EXIF transform: {image.size}") - image = exif_transpose(image) - logger.debug(f"Image size after EXIF transform: {image.size}") - image_size = image.size - elif image_metadata: - image_size = ( - image_metadata["original_size"][0], - image_metadata["original_size"][1], - ) - original_width, original_height = image_size - original_resolution = resolution - # Convert 'resolution' from eg. "1 megapixel" to "1024 pixels" - if resolution_type == "area": - original_resolution = original_resolution * 1e3 - # Make resolution a multiple of StateTracker.get_args().aspect_bucket_alignment - original_resolution = MultiaspectImage._round_to_nearest_multiple( - original_resolution - ) - - # Downsample before we handle, if necessary. 
- downsample_before_crop = False - crop = StateTracker.get_data_backend_config(data_backend_id=id).get( - "crop", False - ) - maximum_image_size = StateTracker.get_data_backend_config( - data_backend_id=id - ).get("maximum_image_size", None) - target_downsample_size = StateTracker.get_data_backend_config( - data_backend_id=id - ).get("target_downsample_size", None) - logger.debug( - f"Dataset: {id}, maximum_image_size: {maximum_image_size}, target_downsample_size: {target_downsample_size}" - ) - if crop and maximum_image_size and target_downsample_size: - if MultiaspectImage.is_image_too_large( - image_size, maximum_image_size, resolution_type=resolution_type - ): - # Override the target resolution with the target downsample size - logger.debug( - f"Overriding resolution {resolution} with target downsample size: {target_downsample_size}" - ) - resolution = target_downsample_size - downsample_before_crop = True - - # Calculate new size - original_aspect_ratio = MultiaspectImage.calculate_image_aspect_ratio( - (original_width, original_height) - ) - if resolution_type == "pixel": - (target_width, target_height, new_aspect_ratio) = ( - MultiaspectImage.calculate_new_size_by_pixel_edge( - original_aspect_ratio, resolution - ) - ) - elif resolution_type == "area": - (target_width, target_height, new_aspect_ratio) = ( - MultiaspectImage.calculate_new_size_by_pixel_area( - original_aspect_ratio, resolution - ) - ) - # Convert 'resolution' from eg. "1 megapixel" to "1024 pixels" - resolution = resolution * 1e3 - # Make resolution a multiple of StateTracker.get_args().aspect_bucket_alignment - resolution = MultiaspectImage._round_to_nearest_multiple(resolution) - logger.debug( - f"After area resize, our image will be {target_width}x{target_height} with an overridden resolution of {resolution} pixels." - ) - else: - raise ValueError(f"Unknown resolution type: {resolution_type}") - - crop_style = StateTracker.get_data_backend_config(data_backend_id=id).get( - "crop_style", "random" - ) - crop_aspect = StateTracker.get_data_backend_config(data_backend_id=id).get( - "crop_aspect", "square" - ) - - if crop: - crop_handler_cls = crop_handlers.get(crop_style) - if not crop_handler_cls: - raise ValueError(f"Unknown crop style: {crop_style}") - crop_handler = crop_handler_cls(image=image, image_metadata=image_metadata) - if downsample_before_crop: - logger.debug( - f"Resizing image before crop, as its size is too large. Data backend: {id}, image size: {image.size}, target size: {target_width}x{target_height}" - ) - if image: - image = MultiaspectImage._resize_image( - image, target_width, target_height - ) - elif image_metadata: - image_metadata = MultiaspectImage._resize_image( - None, target_width, target_height, image_metadata - ) - if resolution_type == "area": - # Convert original_resolution back from eg. 1024 pixels to 1.0 mp - original_megapixel_resolution = original_resolution / 1e3 - (target_width, target_height, new_aspect_ratio) = ( - MultiaspectImage.calculate_new_size_by_pixel_area( - original_aspect_ratio, - original_megapixel_resolution, - ) - ) - elif resolution_type == "pixel": - (target_width, target_height, new_aspect_ratio) = ( - MultiaspectImage.calculate_new_size_by_pixel_edge( - original_aspect_ratio, original_resolution - ) - ) - logger.debug( - f"Recalculated target_width and target_height {target_width}x{target_height} based on original_resolution: {original_resolution}" - ) - - logger.debug(f"We are cropping the image. 
Data backend: {id}") - crop_width, crop_height = ( - (original_resolution, original_resolution) - if crop_aspect == "square" - else (target_width, target_height) - ) - - crop_result = crop_handler.crop(crop_width, crop_height) - - if image: - image, crop_coordinates = crop_result - logger.debug(f"After cropping, our image size: {image.size}") - elif image_metadata: - _, crop_coordinates = crop_result - else: - # Resize unconditionally if cropping is not enabled - if image: - image = MultiaspectImage._resize_image( - image, target_width, target_height - ) - crop_coordinates = (0, 0) - - if image: - return image, crop_coordinates, new_aspect_ratio - elif image_metadata: - return (target_width, target_height), crop_coordinates, new_aspect_ratio - @staticmethod def _round_to_nearest_multiple(value): """Round a value to the nearest multiple.""" diff --git a/tests/test_image.py b/tests/test_image.py index 17a383dc..736d6a4e 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -41,145 +41,6 @@ def test_aspect_ratio_calculation(self): MultiaspectImage.calculate_image_aspect_ratio((1080, 1920)), 0.56 ) - def test_image_resize(self): - """ - Test that images are resized to the expected dimensions. - """ - # Define target resolutions and expected output sizes - tests = [ - (1024, "pixel", (1792, 1024), Image.new("RGB", (1920, 1080))), - ( - 1.0, - "area", - (1344, 768), - Image.new("RGB", (1920, 1080)), - ), # Assuming target is 1 megapixel - ( - 256, - "pixel", - (448, 256), - Image.new("RGB", (1920, 1080)), - ), # From log, smaller side to 256, aspect ratio approximated - ( - 256, - "pixel", - (448, 256), - Image.new("RGB", (3840, 2160)), - ), # From log, taller image with smaller side to 256 - ] - with patch("helpers.training.state_tracker.StateTracker.get_args") as mock_args: - mock_args.return_value = Mock( - resolution_type="pixel", - resolution=self.resolution, - crop_style="random", - aspect_bucket_rounding=2, - aspect_bucket_alignment=64, - ) - - for resolution, resolution_type, expected_size, test_image in tests: - resized_image, _, _ = MultiaspectImage.prepare_image( - resolution=resolution, - image=test_image, - resolution_type=resolution_type, - id="test", - ) - - # Verify the size of the resized image - self.assertEqual(resized_image.size, expected_size) - - def test_image_size_consistency(self): - """ - Test that `prepare_image` returns consistent size for images with similar aspect ratios. 
- """ - # Generate random input aspect ratios and resolutions: - input_aspect_ratios = [random.uniform(0.5, 2.0) for _ in range(10)] - # Sizes should follow the list of resolutions, with between 2-4 images in each aspect - input_sizes = [] - for aspect_ratio in input_aspect_ratios: - count = 0 - for resolution in range(5, 50, 5): - count += 1 - width = resolution * 100 - height = int(width / aspect_ratio) - input_sizes.append((width, height)) - - # Sort into bucket dictionary using MultiaspectImage.calculate_image_aspect_ratio - input_sizes_dict = {} - for size in input_sizes: - aspect_ratio = size[0] / size[1] - if aspect_ratio not in input_sizes_dict: - input_sizes_dict[aspect_ratio] = [] - input_sizes_dict[aspect_ratio].append(size) - - resolutions = range( - 5, 20, 5 - ) # Using a simplified resolution from the logs for the test - with patch("helpers.training.state_tracker.StateTracker.get_args") as mock_args: - mock_args.return_value = Mock( - resolution_type="pixel", - resolution=self.resolution, - crop_style="random", - aspect_bucket_rounding=2, - aspect_bucket_alignment=64, - ) - for aspect_ratio in set(input_sizes_dict.keys()): - for resolution in resolutions: - resolution = resolution / 10 # Convert to megapixels - output_sizes = [] - new_aspect_ratios = [] - for size in input_sizes_dict[aspect_ratio]: - should_use_real_image = random.choice([True, False]) - image = ( - Image.new("RGB", size) if should_use_real_image else None - ) # Creating a dummy PIL image with the given size - image_metadata = ( - None if should_use_real_image else {"original_size": size} - ) - function_result, _, new_aspect_ratio = ( - MultiaspectImage.prepare_image( - image=image, - image_metadata=image_metadata, - resolution=resolution, - resolution_type="area", - ) - ) - if hasattr(function_result, "size"): - output_size = function_result.size - else: - output_size = function_result - output_sizes.append(output_size) - new_aspect_ratios.append(new_aspect_ratio) - - # Check if all output sizes are the same, indicating consistent resizing/cropping - self.assertTrue( - all(size == output_sizes[0] for size in output_sizes), - f"Output sizes are not consistent for {resolution} MP", - ) - self.assertTrue( - all(size == new_aspect_ratios[0] for size in new_aspect_ratios), - f"Output sizes are not consistent for {resolution} MP", - ) - - def test_prepare_image_valid(self): - with patch("helpers.training.state_tracker.StateTracker.get_args") as mock_args: - mock_args.return_value = Mock( - resolution_type="pixel", - resolution=self.resolution, - crop_style="random", - aspect_bucket_rounding=2, - aspect_bucket_alignment=64, - ) - prepared_img, crop_coordinates, aspect_ratio = ( - MultiaspectImage.prepare_image( - image=self.test_image, resolution=self.resolution - ) - ) - self.assertIsInstance(prepared_img, Image.Image) - - def test_prepare_image_invalid(self): - with self.assertRaises(Exception): - MultiaspectImage.prepare_image(None, self.resolution) - def test_resize_for_condition_image_valid(self): resized_img = MultiaspectImage._resize_image( self.test_image, self.resolution, self.resolution From de23593c3de16d988fd81320b8d4325347397260 Mon Sep 17 00:00:00 2001 From: bghira Date: Thu, 2 May 2024 14:38:19 -0600 Subject: [PATCH 33/37] mps: increase allowed batch size to 16 since anomaly detector eats a lot of compute, and is now disabled. 
---
 helpers/arguments.py         | 2 +-
 helpers/multiaspect/image.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/helpers/arguments.py b/helpers/arguments.py
index 1726eca9..960f0bcb 100644
--- a/helpers/arguments.py
+++ b/helpers/arguments.py
@@ -1332,7 +1332,7 @@ def parse_args(input_args=None):
         logger.warning(
             "MPS may benefit from the use of --unet_attention_slice for memory savings at the cost of speed."
         )
-        if args.train_batch_size > 12:
+        if args.train_batch_size > 16:
             logger.error(
                 "An M3 Max 128G will use 12 seconds per step at a batch size of 1 and 65 seconds per step at a batch size of 12."
                 " Any higher values will result in NDArray size errors or other unstable training results and crashes."
diff --git a/helpers/multiaspect/image.py b/helpers/multiaspect/image.py
index 44f8ce0a..469c4a6d 100644
--- a/helpers/multiaspect/image.py
+++ b/helpers/multiaspect/image.py
@@ -169,7 +169,7 @@ def calculate_new_size_by_pixel_area(aspect_ratio: float, megapixels: float):
         if aspect_ratio == 1.0 and megapixels == 0.75:
             return 768, 768, 1.0
         # Special case for 512px (0.25mp) images
-        if aspect_ratio == 1.0 and megapixels == 0.5:
+        if aspect_ratio == 1.0 and megapixels == 0.25:
             return 512, 512, 1.0
         total_pixels = max(megapixels * 1e6, 1e6)
         W_initial = int(round((total_pixels * aspect_ratio) ** 0.5))

From e48fe870aa1651c883136b39d48129f7d05fb53d Mon Sep 17 00:00:00 2001
From: bghira
Date: Thu, 2 May 2024 14:38:57 -0600
Subject: [PATCH 34/37] training sample should use a structured return for
 prepared samples

---
 helpers/image_manipulation/training_sample.py | 49 ++++++++++++++++---
 1 file changed, 43 insertions(+), 6 deletions(-)

diff --git a/helpers/image_manipulation/training_sample.py b/helpers/image_manipulation/training_sample.py
index 06a775e2..8eef0b73 100644
--- a/helpers/image_manipulation/training_sample.py
+++ b/helpers/image_manipulation/training_sample.py
@@ -9,7 +9,9 @@ class TrainingSample:
-    def __init__(self, image: Image.Image, data_backend_id: str, metadata: dict = None):
+    def __init__(
+        self, image: Image.Image, data_backend_id: str, image_metadata: dict = None
+    ):
         """
         Initializes a new TrainingSample instance with a provided PIL.Image object and a data backend identifier.

@@ -20,11 +22,14 @@ def __init__(self, image: Image.Image, data_backend_id: str, metadata: dict = No
         """
         self.image = image
         self.data_backend_id = data_backend_id
-        self.metadata = metadata if metadata else {}
+        self.image_metadata = image_metadata if image_metadata else {}
         if hasattr(image, "size"):
             self.original_size = self.image.size
-        elif metadata is not None:
-            self.original_size = metadata.get("original_size")
+        elif image_metadata is not None:
+            self.original_size = image_metadata.get("original_size")
+            print(
+                f"Metadata for training sample given instead of image? {image_metadata}"
{image_metadata}" + ) if not self.original_size: raise Exception("Original size not found in metadata.") @@ -43,7 +48,7 @@ def __init__(self, image: Image.Image, data_backend_id: str, metadata: dict = No crop_handler_cls = crop_handlers.get(self.crop_style) if not crop_handler_cls: raise ValueError(f"Unknown crop style: {self.crop_style}") - self.cropper = crop_handler_cls(image=self.image, image_metadata=metadata) + self.cropper = crop_handler_cls(image=self.image, image_metadata=image_metadata) self.resolution = self.data_backend_config.get("resolution") self.resolution_type = self.data_backend_config.get("resolution_type") self.target_size_calculator = resize_helpers.get(self.resolution_type) @@ -88,7 +93,14 @@ def prepare(self, return_tensor: bool = False): if return_tensor: # Return normalised tensor. image = self.transforms(image) - return image, self.crop_coordinates, self.aspect_ratio + return PreparedSample( + image=image, + original_size=self.original_size, + crop_coordinates=self.crop_coordinates, + aspect_ratio=self.aspect_ratio, + image_metadata=self.image_metadata, + target_size=self.target_size, + ) def area(self) -> int: """ @@ -209,3 +221,28 @@ def get_image(self): Image.Image: The current image. """ return self.image + + +class PreparedSample: + def __init__( + self, + image: Image.Image, + image_metadata: dict, + original_size: tuple, + target_size: tuple, + aspect_ratio: float, + crop_coordinates: tuple, + ): + """ + Initializes a new PreparedSample instance with a provided PIL.Image object and optional metadata. + + Args: + image (Image.Image): A PIL Image object. + metadata (dict): Optional metadata associated with the image. + """ + self.image = image + self.image_metadata = image_metadata if image_metadata else {} + self.original_size = original_size + self.target_size = target_size + self.aspect_ratio = aspect_ratio + self.crop_coordinates = crop_coordinates From 51bb2a70fdfe07dd166979ad2427f73eb056dc74 Mon Sep 17 00:00:00 2001 From: bghira Date: Thu, 2 May 2024 14:39:42 -0600 Subject: [PATCH 35/37] switch to using PreparedSample --- helpers/metadata/backends/json.py | 18 ++++++++------- helpers/metadata/backends/parquet.py | 22 ++++++++++--------- helpers/training/collate.py | 33 +++++++++------------------- 3 files changed, 32 insertions(+), 41 deletions(-) diff --git a/helpers/metadata/backends/json.py b/helpers/metadata/backends/json.py index bb127554..c2b83e7f 100644 --- a/helpers/metadata/backends/json.py +++ b/helpers/metadata/backends/json.py @@ -219,22 +219,24 @@ def _process_for_bucket( return aspect_ratio_bucket_indices image_metadata["original_size"] = image.size training_sample = TrainingSample( - image=image, data_backend_id=self.id, metadata=image_metadata + image=image, data_backend_id=self.id, image_metadata=image_metadata ) - image, crop_coordinates, new_aspect_ratio = training_sample.prepare() - image_metadata["crop_coordinates"] = crop_coordinates + prepared_sample = training_sample.prepare() + image_metadata["crop_coordinates"] = prepared_sample.crop_coordinates image_metadata["target_size"] = image.size # Round to avoid excessive unique buckets - image_metadata["aspect_ratio"] = new_aspect_ratio + image_metadata["aspect_ratio"] = prepared_sample.aspect_ratio image_metadata["luminance"] = calculate_luminance(image) logger.debug( - f"Image {image_path_str} has aspect ratio {new_aspect_ratio} and size {image.size}." + f"Image {image_path_str} has aspect ratio {prepared_sample.aspect_ratio} and size {image.size}." 
) # Create a new bucket if it doesn't exist - if str(new_aspect_ratio) not in aspect_ratio_bucket_indices: - aspect_ratio_bucket_indices[str(new_aspect_ratio)] = [] - aspect_ratio_bucket_indices[str(new_aspect_ratio)].append(image_path_str) + if str(prepared_sample.aspect_ratio) not in aspect_ratio_bucket_indices: + aspect_ratio_bucket_indices[str(prepared_sample.aspect_ratio)] = [] + aspect_ratio_bucket_indices[str(prepared_sample.aspect_ratio)].append( + image_path_str + ) # Instead of directly updating, just fill the provided dictionary if metadata_updates is not None: metadata_updates[image_path_str] = image_metadata diff --git a/helpers/metadata/backends/parquet.py b/helpers/metadata/backends/parquet.py index 23cbe91c..f14d3d9b 100644 --- a/helpers/metadata/backends/parquet.py +++ b/helpers/metadata/backends/parquet.py @@ -360,15 +360,15 @@ def _process_for_bucket( image_metadata["original_size"] ) training_sample = TrainingSample( - image=None, data_backend_id=self.id, metadata=image_metadata + image=None, data_backend_id=self.id, image_metadata=image_metadata ) - target_size, crop_coordinates, new_aspect_ratio = training_sample.prepare() + prepared_sample = training_sample.prepare() print( - f"Prepared training sample: {target_size}, {crop_coordinates}, {new_aspect_ratio}" + f"Prepared training sample: {prepared_sample.target_size}, {prepared_sample.crop_coordinates}, {prepared_sample.aspect_ratio}" ) - image_metadata["crop_coordinates"] = crop_coordinates - image_metadata["target_size"] = target_size - image_metadata["aspect_ratio"] = new_aspect_ratio + image_metadata["crop_coordinates"] = prepared_sample.crop_coordinates + image_metadata["target_size"] = prepared_sample.target_size + image_metadata["aspect_ratio"] = prepared_sample.aspect_ratio luminance_column = self.parquet_config.get("luminance_column", None) if luminance_column: image_metadata["luminance"] = database_image_metadata[ @@ -377,13 +377,15 @@ def _process_for_bucket( else: image_metadata["luminance"] = 0 logger.debug( - f"Image {image_path_str} has aspect ratio {new_aspect_ratio} and size {image_metadata['target_size']}." + f"Image {image_path_str} has aspect ratio {prepared_sample.aspect_ratio} and size {image_metadata['target_size']}." 
) # Create a new bucket if it doesn't exist - if str(new_aspect_ratio) not in aspect_ratio_bucket_indices: - aspect_ratio_bucket_indices[str(new_aspect_ratio)] = [] - aspect_ratio_bucket_indices[str(new_aspect_ratio)].append(image_path_str) + if str(prepared_sample.aspect_ratio) not in aspect_ratio_bucket_indices: + aspect_ratio_bucket_indices[str(prepared_sample.aspect_ratio)] = [] + aspect_ratio_bucket_indices[str(prepared_sample.aspect_ratio)].append( + image_path_str + ) # Instead of directly updating, just fill the provided dictionary if metadata_updates is not None: metadata_updates[image_path_str] = image_metadata diff --git a/helpers/training/collate.py b/helpers/training/collate.py index e3126987..101e52b4 100644 --- a/helpers/training/collate.py +++ b/helpers/training/collate.py @@ -1,10 +1,10 @@ import torch, logging, concurrent.futures, numpy as np +from PIL import Image from os import environ from helpers.training.state_tracker import StateTracker from helpers.training.multi_process import rank_info +from helpers.image_manipulation.training_sample import TrainingSample from helpers.multiaspect.image import MultiaspectImage -from helpers.image_manipulation.brightness import calculate_batch_luminance -from accelerate.logging import get_logger from concurrent.futures import ThreadPoolExecutor logger = logging.getLogger("collate_fn") @@ -75,30 +75,17 @@ def fetch_pixel_values(fp, data_backend_id: str): debug_log( f" -> pull pixels for fp {fp} from cache via data backend {data_backend_id}" ) - pixels = StateTracker.get_data_backend(data_backend_id)["data_backend"].read_image( + image = StateTracker.get_data_backend(data_backend_id)["data_backend"].read_image( fp ) - """ - def prepare_image( - resolution: float, - image: Image = None, - image_metadata: dict = None, - resolution_type: str = "pixel", - id: str = "foo", - ): - - """ - backend_config = StateTracker.get_data_backend_config(data_backend_id) - reformed_image, _, _ = MultiaspectImage.prepare_image( - resolution=backend_config["resolution"], - image=pixels, - image_metadata=None, - resolution_type=backend_config["resolution_type"], - id=data_backend_id, + training_sample = TrainingSample( + image=image, + data_backend_id=data_backend_id, + image_metadata=StateTracker.get_data_backend(data_backend_id)[ + "metadata_backend" + ].get_metadata_by_filepath(fp), ) - image_transform = MultiaspectImage.get_image_transforms()(reformed_image) - - return image_transform + return training_sample.prepare(return_tensor=True).image def fetch_latent(fp, data_backend_id: str): From 2202ab0aa1f6a039a1d7a694fa2e6a73b9c0f958 Mon Sep 17 00:00:00 2001 From: bghira Date: Thu, 2 May 2024 15:46:09 -0600 Subject: [PATCH 36/37] trainingsample: fix square and aspect-preserving crops. --- helpers/image_manipulation/training_sample.py | 25 ++++- tests/test_training_sample.py | 103 ++++++++++++++++++ 2 files changed, 124 insertions(+), 4 deletions(-) create mode 100644 tests/test_training_sample.py diff --git a/helpers/image_manipulation/training_sample.py b/helpers/image_manipulation/training_sample.py index 8eef0b73..44af3b36 100644 --- a/helpers/image_manipulation/training_sample.py +++ b/helpers/image_manipulation/training_sample.py @@ -27,7 +27,7 @@ def __init__( self.original_size = self.image.size elif image_metadata is not None: self.original_size = image_metadata.get("original_size") - print( + logger.debug( f"Metadata for training sample given instead of image? 
{image_metadata}" ) @@ -43,7 +43,7 @@ def __init__( self.data_backend_config = StateTracker.get_data_backend_config(data_backend_id) self.crop_enabled = self.data_backend_config.get("crop", False) self.crop_style = self.data_backend_config.get("crop_style", "random") - self.crop_aspect = self.data_backend_config.get("crop_aspect", "random") + self.crop_aspect = self.data_backend_config.get("crop_aspect", "square") self.crop_coordinates = (0, 0) crop_handler_cls = crop_handlers.get(self.crop_style) if not crop_handler_cls: @@ -146,10 +146,18 @@ def downsample_before_crop(self): """ if self.image and self.should_downsample_before_crop(): width, height, _ = self.calculate_target_size(downsample_before_crop=True) - self.image = self.resize((width, height)) + logger.debug( + f"Downsampling image from {self.image.size} to {width}x{height} before cropping." + ) + self.resize((width, height)) return self def calculate_target_size(self, downsample_before_crop: bool = False): + # Square crops are always {self.pixel_resolution}x{self.pixel_resolution} + if self.crop_aspect == "square" and not downsample_before_crop: + self.aspect_ratio = 1.0 + self.target_size = (self.pixel_resolution, self.pixel_resolution) + return self.target_size[0], self.target_size[1], self.aspect_ratio self.aspect_ratio = MultiaspectImage.calculate_image_aspect_ratio( self.original_size ) @@ -187,13 +195,22 @@ def crop(self): """ if not self.crop_enabled: return self + logger.debug( + f"Cropping image with {self.crop_style} style and {self.crop_aspect}." + ) # Too-big of an image, resize before we crop. self.downsample_before_crop() width, height, aspect_ratio = self.calculate_target_size( downsample_before_crop=False ) + logger.debug( + f"Pre-crop size: {self.image.size if hasattr(self.image, 'size') else 'Unknown'}." + ) self.image, self.crop_coordinates = self.cropper.crop(width, height) + logger.debug( + f"Post-crop size: {self.image.size if hasattr(self.image, 'size') else 'Unknown'}." 
+ ) return self def resize(self, target_size: tuple = None): @@ -207,7 +224,7 @@ def resize(self, target_size: tuple = None): target_width, target_height, aspect_ratio = self.calculate_target_size() target_size = (target_width, target_height) if self.image: - self.image = self.image.resize(target_size, Image.LANCZOS) + self.image = self.image.resize(target_size, Image.Resampling.LANCZOS) self.aspect_ratio = MultiaspectImage.calculate_image_aspect_ratio( self.image.size ) diff --git a/tests/test_training_sample.py b/tests/test_training_sample.py new file mode 100644 index 00000000..11c030dc --- /dev/null +++ b/tests/test_training_sample.py @@ -0,0 +1,103 @@ +import unittest +from PIL import Image +import numpy as np +from helpers.image_manipulation.training_sample import TrainingSample +from helpers.training.state_tracker import StateTracker +from helpers.multiaspect.image import resize_helpers, crop_handlers +from unittest.mock import MagicMock + + +class TestTrainingSample(unittest.TestCase): + + def setUp(self): + # Create a simple image for testing + self.image = Image.new("RGB", (1024, 768), "white") + self.data_backend_id = "test_backend" + self.image_metadata = {"original_size": (1024, 768)} + + # Assume StateTracker and other helpers are correctly set up to return meaningful values + StateTracker.get_args = MagicMock() + StateTracker.get_args.return_value = MagicMock(aspect_bucket_alignment=8) + StateTracker.get_data_backend_config = MagicMock( + return_value={ + "crop": True, + "crop_style": "center", + "crop_aspect": "square", + "resolution": 512, + "resolution_type": "pixel", + "target_downsample_size": 256, + "maximum_image_size": 1024, + "aspect_bucket_alignment": 8, + } + ) + + def test_image_initialization(self): + """Test that the image is correctly initialized and converted.""" + sample = TrainingSample(self.image, self.data_backend_id, self.image_metadata) + self.assertEqual(sample.original_size, (1024, 768)) + + def test_image_downsample(self): + """Test that downsampling is correctly applied before cropping.""" + sample = TrainingSample(self.image, self.data_backend_id, self.image_metadata) + print(f"Before size: {sample.image.size}") + sample.prepare() + print(f"After size: {sample.image.size}") + self.assertLessEqual( + sample.image.size[0], 512 + ) # Assuming downsample before crop applies + + def test_no_crop(self): + """Test handling when cropping is disabled.""" + StateTracker.get_data_backend_config = lambda x: { + "crop": False, + "crop_style": "random", + "resolution": 512, + "resolution_type": "pixel", + } + sample = TrainingSample(self.image, self.data_backend_id, self.image_metadata) + original_size = sample.image.size + sample.prepare() + self.assertNotEqual(sample.image.size, original_size) # Ensure resizing occurs + + def test_crop_coordinates(self): + """Test that cropping returns correct coordinates.""" + sample = TrainingSample(self.image, self.data_backend_id, self.image_metadata) + sample.prepare() + self.assertIsNotNone(sample.crop_coordinates) # Crop coordinates should be set + + def test_aspect_ratio_square_up(self): + """Test that the aspect ratio is preserved after processing.""" + sample = TrainingSample(self.image, self.data_backend_id, self.image_metadata) + original_aspect = sample.original_size[0] / sample.original_size[1] + sample.prepare() + processed_aspect = sample.image.size[0] / sample.image.size[1] + self.assertEqual(processed_aspect, 1.0) + + def test_return_tensor(self): + """Test tensor conversion if requested.""" + sample = 
TrainingSample(self.image, self.data_backend_id, self.image_metadata) + prepared_sample = sample.prepare(return_tensor=True) + # Check if returned object is a tensor (mock or check type if actual tensor transformation is applied) + self.assertTrue( + isinstance(prepared_sample.aspect_ratio, float) + ) # Placeholder check + + +# Helper mock classes and functions +class MockCropper: + def __init__(self, image, image_metadata): + self.image = image + self.image_metadata = image_metadata + + def crop(self, width, height): + return self.image.crop((0, 0, width, height)), (0, 0, width, height) + + +def mock_resize_helper(aspect_ratio, resolution): + # Simulates resizing logic + width, height = resolution, int(resolution / aspect_ratio) + return width, height, aspect_ratio + + +if __name__ == "__main__": + unittest.main() From cbf0f07bb9fce4e18452490a5115bf41f5be0852 Mon Sep 17 00:00:00 2001 From: bghira Date: Thu, 2 May 2024 15:46:35 -0600 Subject: [PATCH 37/37] remove prints --- tests/test_training_sample.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_training_sample.py b/tests/test_training_sample.py index 11c030dc..1d9f943f 100644 --- a/tests/test_training_sample.py +++ b/tests/test_training_sample.py @@ -39,9 +39,7 @@ def test_image_initialization(self): def test_image_downsample(self): """Test that downsampling is correctly applied before cropping.""" sample = TrainingSample(self.image, self.data_backend_id, self.image_metadata) - print(f"Before size: {sample.image.size}") sample.prepare() - print(f"After size: {sample.image.size}") self.assertLessEqual( sample.image.size[0], 512 ) # Assuming downsample before crop applies
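To close out the series, an end-to-end sketch of the API these patches land
(illustrative; it assumes StateTracker already holds a data backend config for
a backend named "example_backend"):

    from PIL import Image

    from helpers.image_manipulation.training_sample import TrainingSample

    sample = TrainingSample(
        image=Image.new("RGB", (1920, 1080)),
        data_backend_id="example_backend",
    )
    prepared = sample.prepare()  # now a PreparedSample, not a tuple
    print(prepared.target_size, prepared.aspect_ratio, prepared.crop_coordinates)

Because the new test module ends with unittest.main(), it can also be run
directly: python tests/test_training_sample.py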