diff --git a/INSTALL.md b/INSTALL.md index 067fbe30..bc8ef45c 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -22,7 +22,7 @@ To install the Apple-specific requirements: poetry install --no-root -C install/apple ``` -### Linux +### Linux + Nvidia/CUDA The first command you'll run will install most of the dependencies: @@ -59,6 +59,16 @@ If the egg install for Xformers does not work, try including `xformers` on the f pip3 install --pre xformers torch torchvision torchaudio torchtriton --extra-index-url https://download.pytorch.org/whl/nightly/cu118 --force ``` +### Linux + AMD / ROCm +Due to `xformers` not supporting the ROCm platform, memory requirements for training will likely be higher than otherwise stated. + +To install the ROCm-specific requirements: + +```bash +poetry install --no-root -C install/rocm +``` + + ### All platforms 2. For SD2.1, copy `sd21-env.sh.example` to `env.sh` - be sure to fill out the details. Try to change as little as possible. @@ -89,4 +99,4 @@ For SDXL, run the `train_sdxl.sh` script, redirecting outputs to the log file: bash train_sdxl.sh > /path/to/training-$(date +%s).log 2>&1 ``` -> ⚠️ At this point, the commands will work, but further configuration is required. See [the tutorial](/TUTORIAL.md) for more information. \ No newline at end of file +> ⚠️ At this point, the commands will work, but further configuration is required. See [the tutorial](/TUTORIAL.md) for more information. diff --git a/README.md b/README.md index d4c91d8b..d30a269b 100644 --- a/README.md +++ b/README.md @@ -65,7 +65,7 @@ EMA (exponential moving average) weights are a memory-heavy affair, but provide ### GPU vendors * NVIDIA - pretty much anything 3090 and up is a safe bet. YMMV. -* AMD - No one has reported anything, we don't know. +* AMD - SDXL LoRA and UNet are verified working on a 7900 XTX 24GB. Lacking `xformers`, it will likely use more memory than Nvidia equivalents * Apple - LoRA and full u-net tuning are tested to work on an M3 Max with 128G memory, taking about **12G** of "Wired" memory and **4G** of system memory for SDXL. * You likely need a 24G or greater machine for machine learning with M-series hardware due to the lack of memory-efficient attention. diff --git a/TUTORIAL.md b/TUTORIAL.md index 63fc3e61..1795c6e9 100644 --- a/TUTORIAL.md +++ b/TUTORIAL.md @@ -225,9 +225,8 @@ Here's a breakdown of what each environment variable does: #### Data Locations -- `BASE_DIR`, `INSTANCE_DIR`, `OUTPUT_DIR`: Directories for the training data, instance data, and output models. +- `BASE_DIR`, `OUTPUT_DIR`: Directories for the training data, instance data, and output models. - `BASE_DIR` - Used for populating other variables, mostly. - - `INSTANCE_DIR` - Where your actual training data is. This can be anywhere, it does not need to be underneath `BASE_DIR`. - `OUTPUT_DIR` - Where the model pipeline results are stored during training, and after it completes. #### Training Parameters @@ -236,9 +235,9 @@ Here's a breakdown of what each environment variable does: - If you use `MAX_NUM_STEPS`, it's recommended to set `NUM_EPOCHS` to `0`. - Similarly, if you use `NUM_EPOCHS`, it is recommended to set `MAX_NUM_STEPS` to `0`. - This simply signals to the trainer that you explicitly wish to use one or the other. - - If you supply `NUM_EPOCHS` and `MAX_NUM_STEPS` together, the training will stop running at whichever happens first. 
+ - Don't supply `NUM_EPOCHS` and `MAX_NUM_STEPS` together; the trainer will refuse to start, so that there is no ambiguity about which limit you expect to take priority. - `LR_SCHEDULE`, `LR_WARMUP_STEPS`: Learning rate schedule and warmup steps. - - `LR_SCHEDULE` - stick to `constant`, as it is most likely to be stable and less chaotic. However, `polynomial` and `constant_with_warmup` have potential of moving the model's local minima before settling in and reducing the loss. Experimentation can pay off here. + - `LR_SCHEDULE` - stick to `constant`, as it is most likely to be stable and less chaotic. However, `polynomial` and `constant_with_warmup` have the potential of moving the model's local minima before settling in and reducing the loss. Experimentation can pay off here, especially with the `cosine` and `sine` schedulers, which vary the learning rate smoothly over the course of training. - `TRAIN_BATCH_SIZE`: Batch size for training. You want this **as high as you can get it** without running out of VRAM or making your training unnecessarily **slow** (eg. 300-400% increase in training runtime - yikes! 💸) ## Additional Notes diff --git a/documentation/DEEPFLOYD.md b/documentation/DEEPFLOYD.md index 95dd4521..f908b308 100644 --- a/documentation/DEEPFLOYD.md +++ b/documentation/DEEPFLOYD.md @@ -108,7 +108,6 @@ export LEARNING_RATE_END=4e-6 #@param {type:"number"} # Where to store your results. export BASE_DIR="/training" -export INSTANCE_DIR="${BASE_DIR}/data" export OUTPUT_DIR="${BASE_DIR}/models/deepfloyd" export DATALOADER_CONFIG="multidatabackend_deepfloyd.json" diff --git a/helpers/arguments.py b/helpers/arguments.py index c3669c89..960f0bcb 100644 --- a/helpers/arguments.py +++ b/helpers/arguments.py @@ -1332,7 +1332,7 @@ def parse_args(input_args=None): logger.warning( "MPS may benefit from the use of --unet_attention_slice for memory savings at the cost of speed." ) - if args.train_batch_size > 12: + if args.train_batch_size > 16: logger.error( "An M3 Max 128G will use 12 seconds per step at a batch size of 1 and 65 seconds per step at a batch size of 12." " Any higher values will result in NDArray size errors or other unstable training results and crashes." ) sys.exit(1) + if args.max_train_steps is not None and args.num_train_epochs > 0: + logger.error( + "When using --max_train_steps (MAX_NUM_STEPS), you must set --num_train_epochs (NUM_EPOCHS) to 0." + ) + sys.exit(1) + if ( args.pretrained_vae_model_name_or_path is not None and StateTracker.get_model_type() == "legacy" @@ -1366,7 +1372,7 @@ deepfloyd_pixel_alignment = 8 if args.aspect_bucket_alignment != deepfloyd_pixel_alignment: logger.warning( - f"Overriding aspect bucket alignment pixel interval to {deepfloyd_pixel_alignment}px instead of{args.aspect_bucket_alignment}px." + f"Overriding aspect bucket alignment pixel interval to {deepfloyd_pixel_alignment}px instead of {args.aspect_bucket_alignment}px."
) args.aspect_bucket_alignment = deepfloyd_pixel_alignment diff --git a/helpers/caching/vae.py b/helpers/caching/vae.py index e29c9c73..1759fab0 100644 --- a/helpers/caching/vae.py +++ b/helpers/caching/vae.py @@ -595,7 +595,11 @@ def _process_images_in_batch( # retrieve image data from Generator, image_data: filepath = image_paths.pop() image = image_data.pop() - aspect_bucket = MultiaspectImage.calculate_image_aspect_ratio(image) + aspect_bucket = ( + self.metadata_backend.get_metadata_attribute_by_filepath( + filepath=filepath, attribute="aspect_bucket" + ) + ) else: filepath, image, aspect_bucket = self.process_queue.get() if self.minimum_image_size is not None: diff --git a/helpers/image_manipulation/cropping.py b/helpers/image_manipulation/cropping.py new file mode 100644 index 00000000..24a694c2 --- /dev/null +++ b/helpers/image_manipulation/cropping.py @@ -0,0 +1,96 @@ +from PIL import Image + + +class BaseCropping: + def __init__(self, image: Image = None, image_metadata: dict = None): + self.image = image + self.image_metadata = image_metadata + if self.image: + self.original_width, self.original_height = self.image.size + elif self.image_metadata: + self.original_width, self.original_height = self.image_metadata[ + "original_size" + ] + + def crop(self, target_width, target_height): + raise NotImplementedError("Subclasses must implement this method") + + +class CornerCropping(BaseCropping): + def crop(self, target_width, target_height): + left = max(0, self.original_width - target_width) + top = max(0, self.original_height - target_height) + right = self.original_width + bottom = self.original_height + if self.image: + return self.image.crop((left, top, right, bottom)), (left, top) + elif self.image_metadata: + return self.image_metadata, (left, top) + + +class CenterCropping(BaseCropping): + def crop(self, target_width, target_height): + left = (self.original_width - target_width) / 2 + top = (self.original_height - target_height) / 2 + right = (self.original_width + target_width) / 2 + bottom = (self.original_height + target_height) / 2 + if self.image: + return self.image.crop((left, top, right, bottom)), (left, top) + elif self.image_metadata: + return self.image_metadata, (left, top) + + +class RandomCropping(BaseCropping): + def crop(self, target_width, target_height): + import random + + left = random.randint(0, max(0, self.original_width - target_width)) + top = random.randint(0, max(0, self.original_height - target_height)) + right = left + target_width + bottom = top + target_height + if self.image: + return self.image.crop((left, top, right, bottom)), (left, top) + elif self.image_metadata: + return self.image_metadata, (left, top) + + +class FaceCropping(RandomCropping): + def crop( + self, + image: Image.Image, + target_width: int, + target_height: int, + ): + # Import modules + import cv2 + import numpy as np + + # Detect a face in the image + face_cascade = cv2.CascadeClassifier( + cv2.data.haarcascades + "haarcascade_frontalface_default.xml" + ) + image = image.convert("RGB") + image = np.array(image) + gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + faces = face_cascade.detectMultiScale(gray, 1.1, 4) + if len(faces) > 0: + # Get the largest face + face = max(faces, key=lambda f: f[2] * f[3]) + x, y, w, h = face + left = max(0, x - 0.5 * w) + top = max(0, y - 0.5 * h) + right = min(image.shape[1], x + 1.5 * w) + bottom = min(image.shape[0], y + 1.5 * h) + image = Image.fromarray(image) + return image.crop((left, top, right, bottom)), (left, top) + else: + # Crop 
the image from a random position + return super().crop(target_width, target_height) + + +crop_handlers = { + "corner": CornerCropping, + "centre": CenterCropping, + "center": CenterCropping, + "random": RandomCropping, +} diff --git a/helpers/image_manipulation/training_sample.py b/helpers/image_manipulation/training_sample.py new file mode 100644 index 00000000..44af3b36 --- /dev/null +++ b/helpers/image_manipulation/training_sample.py @@ -0,0 +1,265 @@ +from PIL import Image +from PIL.ImageOps import exif_transpose +from helpers.multiaspect.image import MultiaspectImage, resize_helpers +from helpers.multiaspect.image import crop_handlers +from helpers.training.state_tracker import StateTracker +import logging + +logger = logging.getLogger(__name__) + + +class TrainingSample: + def __init__( + self, image: Image.Image, data_backend_id: str, image_metadata: dict = None + ): + """ + Initializes a new TrainingSample instance with a provided PIL.Image object and a data backend identifier. + + Args: + image (Image.Image): A PIL Image object. + data_backend_id (str): Identifier for the data backend used for additional operations. + metadata (dict): Optional metadata associated with the image. + """ + self.image = image + self.data_backend_id = data_backend_id + self.image_metadata = image_metadata if image_metadata else {} + if hasattr(image, "size"): + self.original_size = self.image.size + elif image_metadata is not None: + self.original_size = image_metadata.get("original_size") + logger.debug( + f"Metadata for training sample given instead of image? {image_metadata}" + ) + + if not self.original_size: + raise Exception("Original size not found in metadata.") + + # Torchvision transforms turn the pixels into a Tensor and normalize them for the VAE. + self.transforms = MultiaspectImage.get_image_transforms() + # EXIF transpose and RGB conversion. + self.correct_image() + + # Backend config details + self.data_backend_config = StateTracker.get_data_backend_config(data_backend_id) + self.crop_enabled = self.data_backend_config.get("crop", False) + self.crop_style = self.data_backend_config.get("crop_style", "random") + self.crop_aspect = self.data_backend_config.get("crop_aspect", "square") + self.crop_coordinates = (0, 0) + crop_handler_cls = crop_handlers.get(self.crop_style) + if not crop_handler_cls: + raise ValueError(f"Unknown crop style: {self.crop_style}") + self.cropper = crop_handler_cls(image=self.image, image_metadata=image_metadata) + self.resolution = self.data_backend_config.get("resolution") + self.resolution_type = self.data_backend_config.get("resolution_type") + self.target_size_calculator = resize_helpers.get(self.resolution_type) + if self.target_size_calculator is None: + raise ValueError(f"Unknown resolution type: {self.resolution_type}") + if self.resolution_type == "pixel": + self.target_area = self.resolution + # Store the pixel value, eg. 1024 + self.pixel_resolution = self.resolution + # Store the megapixel value, eg. 1.0 + self.megapixel_resolution = self.resolution / 1e6 + elif self.resolution_type == "area": + self.target_area = self.resolution * 1e6 # Convert megapixels to pixels + # Store the pixel value, eg. 1024 + self.pixel_resolution = self.resolution * 1e6 + # Store the megapixel value, eg. 
1.0 + self.megapixel_resolution = self.resolution + else: + raise Exception(f"Unknown resolution type: {self.resolution_type}") + self.target_downsample_size = self.data_backend_config.get( + "target_downsample_size", None + ) + self.maximum_image_size = self.data_backend_config.get( + "maximum_image_size", None + ) + + def prepare(self, return_tensor: bool = False): + """ + Perform initial image preparations such as converting to RGB and applying EXIF transformations. + + Args: + image (Image.Image): The image to prepare. + + Returns: + (image, crop_coordinates, aspect_ratio) + """ + self.crop() + if not self.crop_enabled: + self.resize() + + image = self.image + if return_tensor: + # Return normalised tensor. + image = self.transforms(image) + return PreparedSample( + image=image, + original_size=self.original_size, + crop_coordinates=self.crop_coordinates, + aspect_ratio=self.aspect_ratio, + image_metadata=self.image_metadata, + target_size=self.target_size, + ) + + def area(self) -> int: + """ + Calculate the area of the image. + + Returns: + int: The area of the image. + """ + if self.image is not None: + return self.image.size[0] * self.image.size[1] + if self.original_size: + return self.original_size[0] * self.original_size[1] + + def should_downsample_before_crop(self) -> bool: + """ + Returns: + bool: True if the image should be downsampled before cropping, False otherwise. + """ + if ( + not self.crop_enabled + or not self.maximum_image_size + or not self.target_downsample_size + ): + return False + if self.data_backend_config.get("resolution_type") == "pixel": + return ( + self.image.size[0] > self.pixel_resolution + or self.image.size[1] > self.pixel_resolution + ) + elif self.data_backend_config.get("resolution_type") == "area": + logger.debug( + f"Image is too large? {self.area() > self.target_area} (image area: {self.area()}, target area: {self.target_area})" + ) + return self.area() > self.target_area + else: + raise ValueError( + f"Unknown resolution type: {self.data_backend_config.get('resolution_type')}" + ) + + def downsample_before_crop(self): + """ + Downsample the image before cropping, to preserve scene details. + """ + if self.image and self.should_downsample_before_crop(): + width, height, _ = self.calculate_target_size(downsample_before_crop=True) + logger.debug( + f"Downsampling image from {self.image.size} to {width}x{height} before cropping." 
+ ) + self.resize((width, height)) + return self + + def calculate_target_size(self, downsample_before_crop: bool = False): + # Square crops are always {self.pixel_resolution}x{self.pixel_resolution} + if self.crop_aspect == "square" and not downsample_before_crop: + self.aspect_ratio = 1.0 + self.target_size = (self.pixel_resolution, self.pixel_resolution) + return self.target_size[0], self.target_size[1], self.aspect_ratio + self.aspect_ratio = MultiaspectImage.calculate_image_aspect_ratio( + self.original_size + ) + if downsample_before_crop and self.target_downsample_size is not None: + target_width, target_height, self.aspect_ratio = ( + self.target_size_calculator( + self.aspect_ratio, self.target_downsample_size + ) + ) + else: + target_width, target_height, self.aspect_ratio = ( + self.target_size_calculator(self.aspect_ratio, self.resolution) + ) + self.target_size = (target_width, target_height) + self.aspect_ratio = MultiaspectImage.calculate_image_aspect_ratio( + self.target_size + ) + + return self.target_size[0], self.target_size[1], self.aspect_ratio + + def correct_image(self): + """ + Apply a series of transformations to the image to "correct" it. + """ + if self.image: + # Convert image to RGB to remove any alpha channel and apply EXIF data transformations + self.image = self.image.convert("RGB") + self.image = exif_transpose(self.image) + self.original_size = self.image.size + return self + + def crop(self): + """ + Crop the image using the detected crop handler class. + """ + if not self.crop_enabled: + return self + logger.debug( + f"Cropping image with {self.crop_style} style and {self.crop_aspect}." + ) + + # Too-big of an image, resize before we crop. + self.downsample_before_crop() + width, height, aspect_ratio = self.calculate_target_size( + downsample_before_crop=False + ) + logger.debug( + f"Pre-crop size: {self.image.size if hasattr(self.image, 'size') else 'Unknown'}." + ) + self.image, self.crop_coordinates = self.cropper.crop(width, height) + logger.debug( + f"Post-crop size: {self.image.size if hasattr(self.image, 'size') else 'Unknown'}." + ) + return self + + def resize(self, target_size: tuple = None): + """ + Resize the image to a new size. + + Args: + target_size (tuple): The target size as (width, height). + """ + if target_size is None: + target_width, target_height, aspect_ratio = self.calculate_target_size() + target_size = (target_width, target_height) + if self.image: + self.image = self.image.resize(target_size, Image.Resampling.LANCZOS) + self.aspect_ratio = MultiaspectImage.calculate_image_aspect_ratio( + self.image.size + ) + return self + + def get_image(self): + """ + Returns the current state of the image. + + Returns: + Image.Image: The current image. + """ + return self.image + + +class PreparedSample: + def __init__( + self, + image: Image.Image, + image_metadata: dict, + original_size: tuple, + target_size: tuple, + aspect_ratio: float, + crop_coordinates: tuple, + ): + """ + Initializes a new PreparedSample instance with a provided PIL.Image object and optional metadata. + + Args: + image (Image.Image): A PIL Image object. + metadata (dict): Optional metadata associated with the image. 
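+ original_size (tuple): The (width, height) of the image before any processing. + target_size (tuple): The (width, height) the prepared sample was resized or cropped to. + aspect_ratio (float): The aspect ratio of the prepared sample. + crop_coordinates (tuple): The (left, top) offset used when the sample was cropped.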
+ """ + self.image = image + self.image_metadata = image_metadata if image_metadata else {} + self.original_size = original_size + self.target_size = target_size + self.aspect_ratio = aspect_ratio + self.crop_coordinates = crop_coordinates diff --git a/helpers/legacy/sd_files.py b/helpers/legacy/sd_files.py index 3ada352a..8b7c14ff 100644 --- a/helpers/legacy/sd_files.py +++ b/helpers/legacy/sd_files.py @@ -133,19 +133,18 @@ def load_model_hook(models, input_dir): f"Could not find training_state.json in checkpoint dir {input_dir}" ) - unet_ = None - text_encoder = None - while len(models) > 0: - model = models.pop() - - if isinstance(model, type(unwrap_model(accelerator, unet))): - unet_ = model - elif isinstance(model, type(unwrap_model(accelerator, text_encoder))): - text_encoder = model - else: - raise ValueError(f"unexpected save model: {model.__class__}") - if "lora" in args.model_type: + unet_ = None + text_encoder_ = None + while len(models) > 0: + model = models.pop() + + if isinstance(model, type(unwrap_model(accelerator, unet))): + unet_ = model + elif isinstance(model, type(unwrap_model(accelerator, text_encoder))): + text_encoder_ = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") logger.info(f"Loading LoRA weights from Path: {input_dir}") lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir) @@ -173,7 +172,7 @@ def load_model_hook(models, input_dir): _set_state_dict_into_text_encoder( lora_state_dict, prefix="text_encoder.", - text_encoder=text_encoder, + text_encoder=text_encoder_, ) logger.info("Completed loading LoRA weights.") diff --git a/helpers/legacy/validation.py b/helpers/legacy/validation.py index 269c9c9f..9954df8c 100644 --- a/helpers/legacy/validation.py +++ b/helpers/legacy/validation.py @@ -491,7 +491,7 @@ def log_validations( validation_resolutions = ( get_validation_resolutions() - if "deepfloyd" not in args.model_type + if "deepfloyd-stage2" not in args.model_type else ["base-256"] ) logger.debug(f"Resolutions for validation: {validation_resolutions}") @@ -499,7 +499,7 @@ def log_validations( validation_images[validation_shortname] = [] for resolution in validation_resolutions: - if "deepfloyd" not in args.model_type: + if "deepfloyd-stage2" not in args.model_type: validation_resolution_width, validation_resolution_height = ( resolution ) diff --git a/helpers/metadata/backends/base.py b/helpers/metadata/backends/base.py index de000c07..651c9bec 100644 --- a/helpers/metadata/backends/base.py +++ b/helpers/metadata/backends/base.py @@ -110,14 +110,22 @@ def _bucket_worker( for file in files: if str(file) not in existing_files_set: logger.debug(f"Processing file {file}.") - local_aspect_ratio_bucket_indices = self._process_for_bucket( - file, - local_aspect_ratio_bucket_indices, - metadata_updates=local_metadata_updates, - delete_problematic_images=self.delete_problematic_images, - statistics=statistics, + try: + local_aspect_ratio_bucket_indices = self._process_for_bucket( + file, + local_aspect_ratio_bucket_indices, + metadata_updates=local_metadata_updates, + delete_problematic_images=self.delete_problematic_images, + statistics=statistics, + ) + except Exception as e: + logger.error( + f"Error processing file {file}. Reason: {e}. Skipping." 
+ ) + statistics["skipped"]["error"] += 1 + logger.debug( + f"Statistics: {statistics}, total: {sum([len(bucket) for bucket in local_aspect_ratio_bucket_indices.values()])}" ) - logger.debug(f"Statistics: {statistics}") processed_file_count += 1 # Successfully processed statistics["total_processed"] = processed_file_count @@ -146,6 +154,7 @@ def _bucket_worker( metadata_updates_queue.put(local_metadata_updates) # At the end of the _bucket_worker method metadata_updates_queue.put(("statistics", statistics)) + time.sleep(0.001) logger.debug(f"Bucket worker completed processing. Returning to main thread.") def compute_aspect_ratio_bucket_indices(self): @@ -731,9 +740,7 @@ def handle_vae_cache_inconsistencies(self, vae_cache, vae_cache_behavior: str): continue if vae_cache_behavior == "sync": # Sync aspect buckets with the cache - expected_bucket = MultiaspectImage.determine_bucket_for_aspect_ratio( - self._get_aspect_ratio_from_tensor(cache_content) - ) + expected_bucket = str(self._get_aspect_ratio_from_tensor(cache_content)) self._modify_cache_entry_bucket(cache_file, expected_bucket) elif vae_cache_behavior == "recreate": # Delete the cache file if it doesn't match the aspect bucket indices @@ -745,9 +752,7 @@ def handle_vae_cache_inconsistencies(self, vae_cache, vae_cache_behavior: str): # Update any state or metadata post-processing self.save_cache() - def _recalculate_target_resolution( - self, original_aspect_ratio: float - ) -> tuple: + def _recalculate_target_resolution(self, original_aspect_ratio: float) -> tuple: """Given the original resolution, use our backend config to properly recalculate the size.""" resolution_type = StateTracker.get_data_backend_config(self.id)[ "resolution_type" @@ -803,7 +808,9 @@ def is_cache_inconsistent(self, vae_cache, cache_file, cache_content): target_resolution = tuple(metadata_target_size) recalculated_width, recalculated_height, recalculated_aspect_ratio = ( self._recalculate_target_resolution( - original_aspect_ratio=MultiaspectImage.calculate_image_aspect_ratio(original_resolution) + original_aspect_ratio=MultiaspectImage.calculate_image_aspect_ratio( + original_resolution + ) ) ) recalculated_target_resolution = (recalculated_width, recalculated_height) @@ -828,9 +835,7 @@ def is_cache_inconsistent(self, vae_cache, cache_file, cache_content): ) actual_aspect_ratio = self._get_aspect_ratio_from_tensor(cache_content) - expected_bucket = MultiaspectImage.determine_bucket_for_aspect_ratio( - recalculated_aspect_ratio - ) + expected_bucket = str(recalculated_aspect_ratio) logger.debug( f"Expected bucket for {cache_file}: {expected_bucket} vs actual {actual_aspect_ratio}" ) diff --git a/helpers/metadata/backends/json.py b/helpers/metadata/backends/json.py index f24ad149..c2b83e7f 100644 --- a/helpers/metadata/backends/json.py +++ b/helpers/metadata/backends/json.py @@ -2,6 +2,7 @@ from helpers.multiaspect.image import MultiaspectImage from helpers.data_backend.base import BaseDataBackend from helpers.metadata.backends.base import MetadataBackend +from helpers.image_manipulation.training_sample import TrainingSample from pathlib import Path import json, logging, os, time, re from multiprocessing import Manager @@ -217,27 +218,25 @@ def _process_for_bucket( statistics["skipped"]["too_small"] += 1 return aspect_ratio_bucket_indices image_metadata["original_size"] = image.size - image, crop_coordinates, new_aspect_ratio = ( - MultiaspectImage.prepare_image( - image=image, - resolution=self.resolution, - resolution_type=self.resolution_type, - 
id=self.data_backend.id, - ) + training_sample = TrainingSample( + image=image, data_backend_id=self.id, image_metadata=image_metadata ) - image_metadata["crop_coordinates"] = crop_coordinates + prepared_sample = training_sample.prepare() + image_metadata["crop_coordinates"] = prepared_sample.crop_coordinates image_metadata["target_size"] = image.size # Round to avoid excessive unique buckets - image_metadata["aspect_ratio"] = new_aspect_ratio + image_metadata["aspect_ratio"] = prepared_sample.aspect_ratio image_metadata["luminance"] = calculate_luminance(image) logger.debug( - f"Image {image_path_str} has aspect ratio {new_aspect_ratio} and size {image.size}." + f"Image {image_path_str} has aspect ratio {prepared_sample.aspect_ratio} and size {image.size}." ) # Create a new bucket if it doesn't exist - if str(new_aspect_ratio) not in aspect_ratio_bucket_indices: - aspect_ratio_bucket_indices[str(new_aspect_ratio)] = [] - aspect_ratio_bucket_indices[str(new_aspect_ratio)].append(image_path_str) + if str(prepared_sample.aspect_ratio) not in aspect_ratio_bucket_indices: + aspect_ratio_bucket_indices[str(prepared_sample.aspect_ratio)] = [] + aspect_ratio_bucket_indices[str(prepared_sample.aspect_ratio)].append( + image_path_str + ) # Instead of directly updating, just fill the provided dictionary if metadata_updates is not None: metadata_updates[image_path_str] = image_metadata diff --git a/helpers/metadata/backends/parquet.py b/helpers/metadata/backends/parquet.py index 07c6275a..f14d3d9b 100644 --- a/helpers/metadata/backends/parquet.py +++ b/helpers/metadata/backends/parquet.py @@ -1,6 +1,7 @@ from helpers.training.state_tracker import StateTracker from helpers.multiaspect.image import MultiaspectImage from helpers.data_backend.base import BaseDataBackend +from helpers.image_manipulation.training_sample import TrainingSample from helpers.metadata.backends.base import MetadataBackend from tqdm import tqdm import json, logging, os, time @@ -355,18 +356,19 @@ def _process_for_bucket( image_metadata["original_size"][0] / image_metadata["original_size"][1] ) - aspect_ratio = round(aspect_ratio, aspect_ratio_rounding) - target_size, crop_coordinates, new_aspect_ratio = ( - MultiaspectImage.prepare_image( - image_metadata=image_metadata, - resolution=self.resolution, - resolution_type=self.resolution_type, - id=self.data_backend.id, - ) + aspect_ratio = MultiaspectImage.calculate_image_aspect_ratio( + image_metadata["original_size"] + ) + training_sample = TrainingSample( + image=None, data_backend_id=self.id, image_metadata=image_metadata ) - image_metadata["crop_coordinates"] = crop_coordinates - image_metadata["target_size"] = target_size - image_metadata["aspect_ratio"] = new_aspect_ratio + prepared_sample = training_sample.prepare() + print( + f"Prepared training sample: {prepared_sample.target_size}, {prepared_sample.crop_coordinates}, {prepared_sample.aspect_ratio}" + ) + image_metadata["crop_coordinates"] = prepared_sample.crop_coordinates + image_metadata["target_size"] = prepared_sample.target_size + image_metadata["aspect_ratio"] = prepared_sample.aspect_ratio luminance_column = self.parquet_config.get("luminance_column", None) if luminance_column: image_metadata["luminance"] = database_image_metadata[ @@ -375,13 +377,15 @@ def _process_for_bucket( else: image_metadata["luminance"] = 0 logger.debug( - f"Image {image_path_str} has aspect ratio {aspect_ratio} and size {image_metadata['target_size']}." 
+ f"Image {image_path_str} has aspect ratio {prepared_sample.aspect_ratio} and size {image_metadata['target_size']}." ) # Create a new bucket if it doesn't exist - if str(new_aspect_ratio) not in aspect_ratio_bucket_indices: - aspect_ratio_bucket_indices[str(new_aspect_ratio)] = [] - aspect_ratio_bucket_indices[str(new_aspect_ratio)].append(image_path_str) + if str(prepared_sample.aspect_ratio) not in aspect_ratio_bucket_indices: + aspect_ratio_bucket_indices[str(prepared_sample.aspect_ratio)] = [] + aspect_ratio_bucket_indices[str(prepared_sample.aspect_ratio)].append( + image_path_str + ) # Instead of directly updating, just fill the provided dictionary if metadata_updates is not None: metadata_updates[image_path_str] = image_metadata diff --git a/helpers/multiaspect/image.py b/helpers/multiaspect/image.py index beaa196f..469c4a6d 100644 --- a/helpers/multiaspect/image.py +++ b/helpers/multiaspect/image.py @@ -6,6 +6,7 @@ import logging, os, random from math import sqrt from helpers.training.state_tracker import StateTracker +from helpers.image_manipulation.cropping import crop_handlers logger = logging.getLogger("MultiaspectImage") logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO")) @@ -21,284 +22,6 @@ def get_image_transforms(): ] ) - @staticmethod - def prepare_image( - resolution: float, - image: Image = None, - image_metadata: dict = None, - resolution_type: str = "pixel", - id: str = "foo", - ): - if image: - if not hasattr(image, "convert"): - raise Exception( - f"Unknown data received instead of PIL.Image object: {type(image)}" - ) - # Strip transparency - image = image.convert("RGB") - # Rotate, maybe. - logger.debug(f"Processing image filename: {image}") - logger.debug(f"Image size before EXIF transform: {image.size}") - image = exif_transpose(image) - logger.debug(f"Image size after EXIF transform: {image.size}") - image_size = image.size - elif image_metadata: - image_size = ( - image_metadata["original_size"][0], - image_metadata["original_size"][1], - ) - original_width, original_height = image_size - original_resolution = resolution - # Convert 'resolution' from eg. "1 megapixel" to "1024 pixels" - if resolution_type == "area": - original_resolution = original_resolution * 1e3 - # Make resolution a multiple of StateTracker.get_args().aspect_bucket_alignment - original_resolution = MultiaspectImage._round_to_nearest_multiple( - original_resolution - ) - - # Downsample before we handle, if necessary. 
- downsample_before_crop = False - crop = StateTracker.get_data_backend_config(data_backend_id=id).get( - "crop", False - ) - maximum_image_size = StateTracker.get_data_backend_config( - data_backend_id=id - ).get("maximum_image_size", None) - target_downsample_size = StateTracker.get_data_backend_config( - data_backend_id=id - ).get("target_downsample_size", None) - logger.debug( - f"Dataset: {id}, maximum_image_size: {maximum_image_size}, target_downsample_size: {target_downsample_size}" - ) - if crop and maximum_image_size and target_downsample_size: - if MultiaspectImage.is_image_too_large( - image_size, maximum_image_size, resolution_type=resolution_type - ): - # Override the target resolution with the target downsample size - logger.debug( - f"Overriding resolution {resolution} with target downsample size: {target_downsample_size}" - ) - resolution = target_downsample_size - downsample_before_crop = True - - # Calculate new size - original_aspect_ratio = MultiaspectImage.calculate_image_aspect_ratio( - (original_width, original_height) - ) - if resolution_type == "pixel": - (target_width, target_height, new_aspect_ratio) = ( - MultiaspectImage.calculate_new_size_by_pixel_edge( - original_aspect_ratio, resolution - ) - ) - elif resolution_type == "area": - (target_width, target_height, new_aspect_ratio) = ( - MultiaspectImage.calculate_new_size_by_pixel_area( - original_aspect_ratio, resolution - ) - ) - # Convert 'resolution' from eg. "1 megapixel" to "1024 pixels" - resolution = resolution * 1e3 - # Make resolution a multiple of StateTracker.get_args().aspect_bucket_alignment - resolution = MultiaspectImage._round_to_nearest_multiple(resolution) - logger.debug( - f"After area resize, our image will be {target_width}x{target_height} with an overridden resolution of {resolution} pixels." - ) - else: - raise ValueError(f"Unknown resolution type: {resolution_type}") - - crop_style = StateTracker.get_data_backend_config(data_backend_id=id).get( - "crop_style", "random" - ) - crop_aspect = StateTracker.get_data_backend_config(data_backend_id=id).get( - "crop_aspect", "square" - ) - - if crop: - if downsample_before_crop: - logger.debug( - f"Resizing image before crop, as its size is too large. Data backend: {id}, image size: {image.size}, target size: {target_width}x{target_height}" - ) - if image: - image = MultiaspectImage._resize_image( - image, target_width, target_height - ) - elif image_metadata: - image_metadata = MultiaspectImage._resize_image( - None, target_width, target_height, image_metadata - ) - if resolution_type == "area": - # Convert original_resolution back from eg. 1024 pixels to 1.0 mp - original_megapixel_resolution = original_resolution / 1e3 - (target_width, target_height, new_aspect_ratio) = ( - MultiaspectImage.calculate_new_size_by_pixel_area( - original_aspect_ratio, - original_megapixel_resolution, - ) - ) - elif resolution_type == "pixel": - (target_width, target_height, new_aspect_ratio) = ( - MultiaspectImage.calculate_new_size_by_pixel_edge( - original_aspect_ratio, original_resolution - ) - ) - logger.debug( - f"Recalculated target_width and target_height {target_width}x{target_height} based on original_resolution: {original_resolution}" - ) - - logger.debug(f"We are cropping the image. 
Data backend: {id}") - crop_width, crop_height = ( - (original_resolution, original_resolution) - if crop_aspect == "square" - else (target_width, target_height) - ) - - if image: - if crop_style == "corner": - image, crop_coordinates = MultiaspectImage._crop_corner( - image, crop_width, crop_height - ) - elif crop_style in ["centre", "center"]: - image, crop_coordinates = MultiaspectImage._crop_center( - image, crop_width, crop_height - ) - elif crop_style == "face": - image, crop_coordinates = MultiaspectImage._crop_face( - image, crop_width, crop_height - ) - elif crop_style == "random": - image, crop_coordinates = MultiaspectImage._crop_random( - image, crop_width, crop_height - ) - else: - raise ValueError(f"Unknown crop style: {crop_style}") - elif image_metadata: - if crop_style == "corner": - _, crop_coordinates = MultiaspectImage._crop_corner( - image_metadata=image_metadata, - crop_width=crop_width, - crop_height=crop_height, - ) - elif crop_style in ["centre", "center"]: - _, crop_coordinates = MultiaspectImage._crop_center( - image_metadata=image_metadata, - crop_width=crop_width, - crop_height=crop_height, - ) - elif crop_style == "random" or crop_style == "face": - _, crop_coordinates = MultiaspectImage._crop_random( - image_metadata=image_metadata, - crop_width=crop_width, - crop_height=crop_height, - ) - else: - raise ValueError(f"Unknown crop style: {crop_style}") - - logger.debug(f"After cropping, our image size: {image.size}") - else: - # Resize unconditionally if cropping is not enabled - if image: - image = MultiaspectImage._resize_image( - image, target_width, target_height - ) - crop_coordinates = (0, 0) - - if image: - return image, crop_coordinates, new_aspect_ratio - elif image_metadata: - return (target_width, target_height), crop_coordinates, new_aspect_ratio - - @staticmethod - def _crop_corner( - image: Image = None, - target_width=None, - target_height=None, - image_metadata: dict = None, - ): - """Crop the image from the bottom-right corner.""" - if image: - original_width, original_height = image.size - elif image_metadata: - original_width, original_height = image_metadata["original_size"] - left = max(0, original_width - target_width) - top = max(0, original_height - target_height) - right = original_width - bottom = original_height - if image: - return image.crop((left, top, right, bottom)), (left, top) - elif image_metadata: - return image_metadata, (left, top) - - @staticmethod - def _crop_center( - image: Image = None, - target_width=None, - target_height=None, - image_metadata: dict = None, - ): - """Crop the image from the center.""" - original_width, original_height = image.size - left = (original_width - target_width) / 2 - top = (original_height - target_height) / 2 - right = (original_width + target_width) / 2 - bottom = (original_height + target_height) / 2 - if image: - return image.crop((left, top, right, bottom)), (left, top) - elif image_metadata: - return image_metadata, (left, top) - - @staticmethod - def _crop_random( - image: Image = None, - target_width=None, - target_height=None, - image_metadata: dict = None, - ): - """Crop the image from a random position.""" - original_width, original_height = image.size - left = random.randint(0, max(0, original_width - target_width)) - top = random.randint(0, max(0, original_height - target_height)) - right = left + target_width - bottom = top + target_height - if image: - return image.crop((left, top, right, bottom)), (left, top) - elif image_metadata: - return image_metadata, (left, top) - - 
@staticmethod - def _crop_face( - image: Image, - target_width: int, - target_height: int, - ): - """Crop the image to include a face, or the most 'interesting' part of the image, without a face.""" - # Import modules - import cv2 - import numpy as np - - # Detect a face in the image - face_cascade = cv2.CascadeClassifier( - cv2.data.haarcascades + "haarcascade_frontalface_default.xml" - ) - image = image.convert("RGB") - image = np.array(image) - gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) - faces = face_cascade.detectMultiScale(gray, 1.1, 4) - if len(faces) > 0: - # Get the largest face - face = max(faces, key=lambda f: f[2] * f[3]) - x, y, w, h = face - left = max(0, x - 0.5 * w) - top = max(0, y - 0.5 * h) - right = min(image.shape[1], x + 1.5 * w) - bottom = min(image.shape[0], y + 1.5 * h) - image = Image.fromarray(image) - return image.crop((left, top, right, bottom)), (left, top) - else: - # Crop the image from a random position - return MultiaspectImage._crop_random(image, target_width, target_height) - @staticmethod def _round_to_nearest_multiple(value): """Round a value to the nearest multiple.""" @@ -308,7 +31,7 @@ def _round_to_nearest_multiple(value): @staticmethod def _resize_image( - input_image: Image, + input_image: Image.Image, target_width: int, target_height: int, image_metadata: dict = None, @@ -439,6 +162,15 @@ def calculate_new_size_by_pixel_edge(aspect_ratio: float, resolution: int): def calculate_new_size_by_pixel_area(aspect_ratio: float, megapixels: float): if type(aspect_ratio) != float: raise ValueError(f"Aspect ratio must be a float, not {type(aspect_ratio)}") + # Special case for 1024px (1.0) megapixel images + if aspect_ratio == 1.0 and megapixels == 1.0: + return 1024, 1024, 1.0 + # Special case for 768px (0.75mp) images + if aspect_ratio == 1.0 and megapixels == 0.75: + return 768, 768, 1.0 + # Special case for 512px (0.5mp) images + if aspect_ratio == 1.0 and megapixels == 0.25: + return 512, 512, 1.0 total_pixels = max(megapixels * 1e6, 1e6) W_initial = int(round((total_pixels * aspect_ratio) ** 0.5)) H_initial = int(round((total_pixels / aspect_ratio) ** 0.5)) @@ -471,29 +203,24 @@ def calculate_image_aspect_ratio(image, rounding: int = 2): float: The rounded aspect ratio of the image. """ to_round = StateTracker.get_args().aspect_bucket_rounding - if not isinstance(image, int): + if to_round is None: to_round = rounding if isinstance(image, Image.Image): + # An actual image was passed in. width, height = image.size elif isinstance(image, tuple): + # An image.size or a similar (W, H) tuple was provided. width, height = image elif isinstance(image, float): + # An externally-calculated aspect ratio was given to round. return round(image, to_round) else: width, height = image.size aspect_ratio = round(width / height, to_round) return aspect_ratio - @staticmethod - def determine_bucket_for_aspect_ratio(aspect_ratio): - """ - Determine the correct bucket for a given aspect ratio. - - Args: - aspect_ratio (float): The aspect ratio of an image. - Returns: - str: The bucket corresponding to the aspect ratio. 
- """ - # The logic for determining the bucket can be based on the aspect ratio directly - return str(aspect_ratio) +resize_helpers = { + "pixel": MultiaspectImage.calculate_new_size_by_pixel_edge, + "area": MultiaspectImage.calculate_new_size_by_pixel_area, +} diff --git a/helpers/training/collate.py b/helpers/training/collate.py index e3126987..101e52b4 100644 --- a/helpers/training/collate.py +++ b/helpers/training/collate.py @@ -1,10 +1,10 @@ import torch, logging, concurrent.futures, numpy as np +from PIL import Image from os import environ from helpers.training.state_tracker import StateTracker from helpers.training.multi_process import rank_info +from helpers.image_manipulation.training_sample import TrainingSample from helpers.multiaspect.image import MultiaspectImage -from helpers.image_manipulation.brightness import calculate_batch_luminance -from accelerate.logging import get_logger from concurrent.futures import ThreadPoolExecutor logger = logging.getLogger("collate_fn") @@ -75,30 +75,17 @@ def fetch_pixel_values(fp, data_backend_id: str): debug_log( f" -> pull pixels for fp {fp} from cache via data backend {data_backend_id}" ) - pixels = StateTracker.get_data_backend(data_backend_id)["data_backend"].read_image( + image = StateTracker.get_data_backend(data_backend_id)["data_backend"].read_image( fp ) - """ - def prepare_image( - resolution: float, - image: Image = None, - image_metadata: dict = None, - resolution_type: str = "pixel", - id: str = "foo", - ): - - """ - backend_config = StateTracker.get_data_backend_config(data_backend_id) - reformed_image, _, _ = MultiaspectImage.prepare_image( - resolution=backend_config["resolution"], - image=pixels, - image_metadata=None, - resolution_type=backend_config["resolution_type"], - id=data_backend_id, + training_sample = TrainingSample( + image=image, + data_backend_id=data_backend_id, + image_metadata=StateTracker.get_data_backend(data_backend_id)[ + "metadata_backend" + ].get_metadata_by_filepath(fp), ) - image_transform = MultiaspectImage.get_image_transforms()(reformed_image) - - return image_transform + return training_sample.prepare(return_tensor=True).image def fetch_latent(fp, data_backend_id: str): diff --git a/install/rocm/pyproject.toml b/install/rocm/pyproject.toml new file mode 100644 index 00000000..473ba97e --- /dev/null +++ b/install/rocm/pyproject.toml @@ -0,0 +1,51 @@ +[tool.poetry] +name = "simpletuner" +version = "0.1.0" +description = "Stable Diffusion 2.x and XL tuner." 
+authors = ["bghira"] +license = "SUL" +readme = "README.md" + +[tool.poetry.dependencies] +python = ">=3.9,<3.13" +torch = {version = "^2.3", source = "pytorch-rocm"} +torchaudio = {version = "*", source = "pytorch-rocm"} +torchmetrics = {version = "^1", source = "pytorch-rocm"} +torchvision = {version = "*", source = "pytorch-rocm"} +pytorch-triton-rocm = {version = "*", source = "pytorch-rocm"} +accelerate = "^0.26" +boto3 = "^1" +botocore = "^1" +clip-interrogator = "^0.6" +colorama = "^0.4" +compel = "^2" +dadaptation = "^3" +datasets = "^2" +deepspeed = "^0.14" +diffusers = "^0.27" +iterutils = "^0.1" +numpy = "^1" +open-clip-torch = "^2" +opencv-python = "^4" +pandas = "^2" +peft = "^0.9" +pillow = "^10" +prodigyopt = "^1" +requests = "^2" +safetensors = "^0.4" +scipy = "^1" +tensorboard = "^2" +torchsde = "^0.2" +transformers = "^4" +urllib3 = "<1.27" +wandb = "^0.16" + + +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api" + +[[tool.poetry.source]] +priority = "explicit" +name = "pytorch-rocm" +url = "https://download.pytorch.org/whl/rocm6.0" diff --git a/sd21-env.sh.example b/sd21-env.sh.example index dab4cd20..03f64f26 100644 --- a/sd21-env.sh.example +++ b/sd21-env.sh.example @@ -67,19 +67,13 @@ export TRACKER_RUN_NAME TRACKER_RUN_NAME="$(date +%s)" # Location of training data. export BASE_DIR="/notebooks/datasets" -export INSTANCE_DIR="${BASE_DIR}/training_data" export OUTPUT_DIR="${BASE_DIR}/models" export DATALOADER_CONFIG="multidatabackend_sd2x.json" -# Some data that we generate will be cached here. -export STATE_PATH="${BASE_DIR}/training_state.json" -# Store whether we've seen an image or not, to prevent repeats. -export SEEN_STATE_PATH="${BASE_DIR}/training_images_seen.json" - -# Max number of steps OR epochs can be used. But we default to Epochs. +# Max number of steps OR epochs can be used. Not both. export MAX_NUM_STEPS=30000 # Will likely overtrain, but that's fine. -export NUM_EPOCHS=25 +export NUM_EPOCHS=0 # Adjust this for your GPU memory size. export TRAIN_BATCH_SIZE=1 diff --git a/sdxl-env.sh.example b/sdxl-env.sh.example index 1649512c..2c719ae7 100644 --- a/sdxl-env.sh.example +++ b/sdxl-env.sh.example @@ -45,10 +45,10 @@ export DEBUG_EXTRA_ARGS="--report_to=wandb" export TRACKER_PROJECT_NAME="sdxl-training" export TRACKER_RUN_NAME="simpletuner-sdxl" -# Max number of steps OR epochs can be used. But we default to Epochs. +# Max number of steps OR epochs can be used. Not both. export MAX_NUM_STEPS=30000 # Will likely overtrain, but that's fine. -export NUM_EPOCHS=25 +export NUM_EPOCHS=0 # A convenient prefix for all of your training paths. 
export BASE_DIR="/notebooks/datasets" diff --git a/tests/test_cropping.py b/tests/test_cropping.py index f5a73f8b..ab21b0a3 100644 --- a/tests/test_cropping.py +++ b/tests/test_cropping.py @@ -5,16 +5,17 @@ ) # Adjust import according to your project structure -class TestMultiaspectImage(unittest.TestCase): +class TestCropping(unittest.TestCase): def setUp(self): # Creating a sample image for testing self.sample_image = Image.new("RGB", (500, 300), "white") def test_crop_corner(self): target_width, target_height = 300, 200 - cropped_image, (left, top) = MultiaspectImage._crop_corner( - self.sample_image, target_width, target_height - ) + from helpers.image_manipulation.cropping import CornerCropping + + cropper = CornerCropping(self.sample_image) + cropped_image, (left, top) = cropper.crop(target_width, target_height) # Check if cropped coordinates are within original image bounds self.assertTrue(0 <= left < self.sample_image.width) @@ -24,10 +25,11 @@ def test_crop_corner(self): self.assertEqual(cropped_image.size, (target_width, target_height)) def test_crop_center(self): + from helpers.image_manipulation.cropping import CenterCropping + + cropper = CenterCropping(self.sample_image) target_width, target_height = 300, 200 - cropped_image, (left, top) = MultiaspectImage._crop_center( - self.sample_image, target_width, target_height - ) + cropped_image, (left, top) = cropper.crop(target_width, target_height) # Similar checks as above self.assertTrue(0 <= left < self.sample_image.width) @@ -37,9 +39,11 @@ def test_crop_center(self): self.assertEqual(cropped_image.size, (target_width, target_height)) def test_crop_random(self): + from helpers.image_manipulation.cropping import RandomCropping + target_width, target_height = 300, 200 - cropped_image, (left, top) = MultiaspectImage._crop_random( - self.sample_image, target_width, target_height + cropped_image, (left, top) = RandomCropping(self.sample_image).crop( + target_width, target_height ) # Similar checks as above diff --git a/tests/test_image.py b/tests/test_image.py index 2cec4a1e..736d6a4e 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -33,6 +33,7 @@ def test_aspect_ratio_calculation(self): """ Test that the aspect ratio calculation returns expected results. """ + StateTracker.set_args(MagicMock(aspect_bucket_rounding=2)) self.assertEqual( MultiaspectImage.calculate_image_aspect_ratio((1920, 1080)), 1.78 ) @@ -40,163 +41,6 @@ def test_aspect_ratio_calculation(self): MultiaspectImage.calculate_image_aspect_ratio((1080, 1920)), 0.56 ) - def test_image_resize(self): - """ - Test that images are resized to the expected dimensions. 
- """ - # Define target resolutions and expected output sizes - tests = [ - (1024, "pixel", (1792, 1024), Image.new("RGB", (1920, 1080))), - ( - 1.0, - "area", - (1344, 768), - Image.new("RGB", (1920, 1080)), - ), # Assuming target is 1 megapixel - ( - 256, - "pixel", - (448, 256), - Image.new("RGB", (1920, 1080)), - ), # From log, smaller side to 256, aspect ratio approximated - ( - 256, - "pixel", - (448, 256), - Image.new("RGB", (3840, 2160)), - ), # From log, taller image with smaller side to 256 - ] - with patch("helpers.training.state_tracker.StateTracker.get_args") as mock_args: - mock_args.return_value = Mock( - resolution_type="pixel", - resolution=self.resolution, - crop_style="random", - aspect_bucket_rounding=2, - aspect_bucket_alignment=64, - ) - - for resolution, resolution_type, expected_size, test_image in tests: - resized_image, _, _ = MultiaspectImage.prepare_image( - resolution=resolution, - image=test_image, - resolution_type=resolution_type, - id="test", - ) - - # Verify the size of the resized image - self.assertEqual(resized_image.size, expected_size) - - def test_image_size_consistency(self): - """ - Test that `prepare_image` returns consistent size for images with similar aspect ratios. - """ - # Generate random input aspect ratios and resolutions: - input_aspect_ratios = [random.uniform(0.5, 2.0) for _ in range(10)] - # Sizes should follow the list of resolutions, with between 2-4 images in each aspect - input_sizes = [] - for aspect_ratio in input_aspect_ratios: - count = 0 - for resolution in range(5, 50, 5): - count += 1 - width = resolution * 100 - height = int(width / aspect_ratio) - input_sizes.append((width, height)) - - # Sort into bucket dictionary using MultiaspectImage.calculate_image_aspect_ratio - input_sizes_dict = {} - for size in input_sizes: - aspect_ratio = size[0] / size[1] - if aspect_ratio not in input_sizes_dict: - input_sizes_dict[aspect_ratio] = [] - input_sizes_dict[aspect_ratio].append(size) - - resolutions = range( - 5, 20, 5 - ) # Using a simplified resolution from the logs for the test - with patch("helpers.training.state_tracker.StateTracker.get_args") as mock_args: - mock_args.return_value = Mock( - resolution_type="pixel", - resolution=self.resolution, - crop_style="random", - aspect_bucket_rounding=2, - aspect_bucket_alignment=64, - ) - for aspect_ratio in set(input_sizes_dict.keys()): - for resolution in resolutions: - resolution = resolution / 10 # Convert to megapixels - output_sizes = [] - new_aspect_ratios = [] - for size in input_sizes_dict[aspect_ratio]: - should_use_real_image = random.choice([True, False]) - image = ( - Image.new("RGB", size) if should_use_real_image else None - ) # Creating a dummy PIL image with the given size - image_metadata = ( - None if should_use_real_image else {"original_size": size} - ) - function_result, _, new_aspect_ratio = ( - MultiaspectImage.prepare_image( - image=image, - image_metadata=image_metadata, - resolution=resolution, - resolution_type="area", - ) - ) - if hasattr(function_result, "size"): - output_size = function_result.size - else: - output_size = function_result - output_sizes.append(output_size) - new_aspect_ratios.append(new_aspect_ratio) - - # Check if all output sizes are the same, indicating consistent resizing/cropping - self.assertTrue( - all(size == output_sizes[0] for size in output_sizes), - f"Output sizes are not consistent for {resolution} MP", - ) - self.assertTrue( - all(size == new_aspect_ratios[0] for size in new_aspect_ratios), - f"Output sizes are not 
consistent for {resolution} MP", - ) - - def test_crop_corner(self): - cropped_image, _ = MultiaspectImage._crop_corner( - self.test_image, self.resolution, self.resolution - ) - self.assertEqual(cropped_image.size, (self.resolution, self.resolution)) - - def test_crop_center(self): - cropped_image, _ = MultiaspectImage._crop_center( - self.test_image, self.resolution, self.resolution - ) - self.assertEqual(cropped_image.size, (self.resolution, self.resolution)) - - def test_crop_random(self): - cropped_image, _ = MultiaspectImage._crop_random( - self.test_image, self.resolution, self.resolution - ) - self.assertEqual(cropped_image.size, (self.resolution, self.resolution)) - - def test_prepare_image_valid(self): - with patch("helpers.training.state_tracker.StateTracker.get_args") as mock_args: - mock_args.return_value = Mock( - resolution_type="pixel", - resolution=self.resolution, - crop_style="random", - aspect_bucket_rounding=2, - aspect_bucket_alignment=64, - ) - prepared_img, crop_coordinates, aspect_ratio = ( - MultiaspectImage.prepare_image( - image=self.test_image, resolution=self.resolution - ) - ) - self.assertIsInstance(prepared_img, Image.Image) - - def test_prepare_image_invalid(self): - with self.assertRaises(Exception): - MultiaspectImage.prepare_image(None, self.resolution) - def test_resize_for_condition_image_valid(self): resized_img = MultiaspectImage._resize_image( self.test_image, self.resolution, self.resolution diff --git a/tests/test_training_sample.py b/tests/test_training_sample.py new file mode 100644 index 00000000..1d9f943f --- /dev/null +++ b/tests/test_training_sample.py @@ -0,0 +1,101 @@ +import unittest +from PIL import Image +import numpy as np +from helpers.image_manipulation.training_sample import TrainingSample +from helpers.training.state_tracker import StateTracker +from helpers.multiaspect.image import resize_helpers, crop_handlers +from unittest.mock import MagicMock + + +class TestTrainingSample(unittest.TestCase): + + def setUp(self): + # Create a simple image for testing + self.image = Image.new("RGB", (1024, 768), "white") + self.data_backend_id = "test_backend" + self.image_metadata = {"original_size": (1024, 768)} + + # Assume StateTracker and other helpers are correctly set up to return meaningful values + StateTracker.get_args = MagicMock() + StateTracker.get_args.return_value = MagicMock(aspect_bucket_alignment=8) + StateTracker.get_data_backend_config = MagicMock( + return_value={ + "crop": True, + "crop_style": "center", + "crop_aspect": "square", + "resolution": 512, + "resolution_type": "pixel", + "target_downsample_size": 256, + "maximum_image_size": 1024, + "aspect_bucket_alignment": 8, + } + ) + + def test_image_initialization(self): + """Test that the image is correctly initialized and converted.""" + sample = TrainingSample(self.image, self.data_backend_id, self.image_metadata) + self.assertEqual(sample.original_size, (1024, 768)) + + def test_image_downsample(self): + """Test that downsampling is correctly applied before cropping.""" + sample = TrainingSample(self.image, self.data_backend_id, self.image_metadata) + sample.prepare() + self.assertLessEqual( + sample.image.size[0], 512 + ) # Assuming downsample before crop applies + + def test_no_crop(self): + """Test handling when cropping is disabled.""" + StateTracker.get_data_backend_config = lambda x: { + "crop": False, + "crop_style": "random", + "resolution": 512, + "resolution_type": "pixel", + } + sample = TrainingSample(self.image, self.data_backend_id, self.image_metadata) 
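+ # Record the size before prepare(); with cropping disabled, the sample should still be resized to the configured resolution.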
+        original_size = sample.image.size
+        sample.prepare()
+        self.assertNotEqual(sample.image.size, original_size)  # Ensure resizing occurs
+
+    def test_crop_coordinates(self):
+        """Test that cropping returns correct coordinates."""
+        sample = TrainingSample(self.image, self.data_backend_id, self.image_metadata)
+        sample.prepare()
+        self.assertIsNotNone(sample.crop_coordinates)  # Crop coordinates should be set
+
+    def test_aspect_ratio_square_up(self):
+        """Test that a square crop yields a 1:1 aspect ratio after processing."""
+        sample = TrainingSample(self.image, self.data_backend_id, self.image_metadata)
+        original_aspect = sample.original_size[0] / sample.original_size[1]
+        sample.prepare()
+        processed_aspect = sample.image.size[0] / sample.image.size[1]
+        self.assertEqual(processed_aspect, 1.0)
+
+    def test_return_tensor(self):
+        """Test tensor conversion if requested."""
+        sample = TrainingSample(self.image, self.data_backend_id, self.image_metadata)
+        prepared_sample = sample.prepare(return_tensor=True)
+        # Check if returned object is a tensor (mock or check type if actual tensor transformation is applied)
+        self.assertTrue(
+            isinstance(prepared_sample.aspect_ratio, float)
+        )  # Placeholder check
+
+
+# Helper mock classes and functions
+class MockCropper:
+    def __init__(self, image, image_metadata):
+        self.image = image
+        self.image_metadata = image_metadata
+
+    def crop(self, width, height):
+        return self.image.crop((0, 0, width, height)), (0, 0, width, height)
+
+
+def mock_resize_helper(aspect_ratio, resolution):
+    # Simulates resizing logic
+    width, height = resolution, int(resolution / aspect_ratio)
+    return width, height, aspect_ratio
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/toolkit/captioning/caption_with_cogvlm.py b/toolkit/captioning/caption_with_cogvlm.py
index c979f2b1..eb92f2cb 100644
--- a/toolkit/captioning/caption_with_cogvlm.py
+++ b/toolkit/captioning/caption_with_cogvlm.py
@@ -114,7 +114,7 @@ def load_filter_list(filter_list_path):


 def eval_image(
-    image: Image,
+    image: Image.Image,
     model,
     tokenizer,
     torch_dtype,
diff --git a/toolkit/captioning/caption_with_cogvlm_remote.py b/toolkit/captioning/caption_with_cogvlm_remote.py
index 30c1d831..344cd4ac 100644
--- a/toolkit/captioning/caption_with_cogvlm_remote.py
+++ b/toolkit/captioning/caption_with_cogvlm_remote.py
@@ -118,7 +118,7 @@ def parse_args():


 def eval_image(
-    image: Image,
+    image: Image.Image,
     model,
     tokenizer,
     torch_dtype,
@@ -141,7 +141,7 @@ def eval_image(
     return tokenizer.decode(outputs[0])


-def eval_image_with_ooba(image: Image, query: str) -> str:
+def eval_image_with_ooba(image: Image.Image, query: str) -> str:
     CONTEXT = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"

     img_str = base64.b64encode(image).decode("utf-8")
@@ -192,7 +192,7 @@ def encode_images(accelerator, vae, images, image_transform):
 from math import sqrt


-def round_to_nearest_multiple(value, multiple):
+def round_to_nearest_multiple(value, multiple: int = 64):
     """Round a value to the nearest multiple."""
     rounded = round(value / multiple) * multiple
     return max(rounded, multiple)  # Ensure it's at least the value of 'multiple'
@@ -206,8 +206,8 @@ def calculate_new_size_by_pixel_area(W: int, H: int, megapixels: float = 1.0):
     H_new = int(round(sqrt(total_pixels / aspect_ratio)))

     # Ensure they are divisible by 64
-    W_new = round_to_nearest_multiple(W_new, 64)
-    H_new = round_to_nearest_multiple(H_new, 64)
+    W_new = round_to_nearest_multiple(W_new)
+    H_new = round_to_nearest_multiple(H_new)

     return W_new, H_new
diff --git a/toolkit/captioning/caption_with_llava.py b/toolkit/captioning/caption_with_llava.py
index 41f05bbe..8bf8d1fe 100644
--- a/toolkit/captioning/caption_with_llava.py
+++ b/toolkit/captioning/caption_with_llava.py
@@ -256,7 +256,7 @@ def process_and_evaluate_image(args, image_path: str, model, processor):
     else:
         image = Image.open(image_path)

-    def resize_for_condition_image(input_image: Image, resolution: int):
+    def resize_for_condition_image(input_image: Image.Image, resolution: int):
         if resolution == 0:
             return input_image
         input_image = input_image.convert("RGB")
diff --git a/toolkit/datasets/analyze_aspect_ratios_json.py b/toolkit/datasets/analyze_aspect_ratios_json.py
index 94183cc4..2695661a 100644
--- a/toolkit/datasets/analyze_aspect_ratios_json.py
+++ b/toolkit/datasets/analyze_aspect_ratios_json.py
@@ -39,7 +39,7 @@
 # for file in files_to_delete:
 #     import os
 #     os.remove(file)
-def _resize_for_condition_image(self, input_image: Image, resolution: int):
+def _resize_for_condition_image(self, input_image: Image.Image, resolution: int):
     input_image = input_image.convert("RGB")
     W, H = input_image.size
     k = float(resolution) / min(H, W)
diff --git a/toolkit/datasets/csv_to_s3.py b/toolkit/datasets/csv_to_s3.py
index a3aceccc..493f9bf3 100644
--- a/toolkit/datasets/csv_to_s3.py
+++ b/toolkit/datasets/csv_to_s3.py
@@ -74,7 +74,7 @@ def shuffle_words_in_filename(filename):
     return "_".join(words) + ext


-def resize_for_condition_image(input_image: Image, resolution: int):
+def resize_for_condition_image(input_image: Image.Image, resolution: int):
     if resolution == 0:
         return input_image
     input_image = input_image.convert("RGB")
@@ -469,7 +469,7 @@ def list_all_s3_objects(s3_client, bucket_name):
     return existing_files


-def upload_pil_to_s3(image: Image, filename, args, s3_client):
+def upload_pil_to_s3(image: Image.Image, filename, args, s3_client):
     """Upload a PIL Image directly to S3 bucket"""
     if object_exists_in_s3(s3_client, args.aws_bucket_name, filename):
         return
diff --git a/toolkit/datasets/dataset_from_laion.py b/toolkit/datasets/dataset_from_laion.py
index eb6ff0be..71d86065 100644
--- a/toolkit/datasets/dataset_from_laion.py
+++ b/toolkit/datasets/dataset_from_laion.py
@@ -113,7 +113,7 @@ def load_csv(file):
 timeouts = (conn_timeout, read_timeout)


-def _resize_for_condition_image(input_image: Image, resolution: int):
+def _resize_for_condition_image(input_image: Image.Image, resolution: int):
     input_image = input_image.convert("RGB")
     W, H = input_image.size
     aspect_ratio = round(W / H, 2)
diff --git a/toolkit/datasets/enhance_with_controlnet.py b/toolkit/datasets/enhance_with_controlnet.py
index 6e0b2216..4e487693 100644
--- a/toolkit/datasets/enhance_with_controlnet.py
+++ b/toolkit/datasets/enhance_with_controlnet.py
@@ -4,7 +4,7 @@
 from diffusers.utils import load_image


-def resize_for_condition_image(input_image: Image, resolution: int):
+def resize_for_condition_image(input_image: Image.Image, resolution: int):
     input_image = input_image.convert("RGB")
     W, H = input_image.size
     k = float(resolution) / min(H, W)
diff --git a/train_sd21.py b/train_sd21.py
index 20f29226..1e0bfb31 100644
--- a/train_sd21.py
+++ b/train_sd21.py
@@ -104,9 +104,8 @@
 tokenizer = None

-torch.autograd.set_detect_anomaly(True)
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.26.0.dev0")
+check_min_version("0.27.0.dev0")

 SCHEDULER_NAME_MAP = {
@@ -888,7 +887,7 @@ def main():
                f"Resuming from epoch {first_epoch}, which leaves us with {total_steps_remaining_at_start}."
            )
         current_epoch = first_epoch
-        if current_epoch >= args.num_train_epochs:
+        if current_epoch >= args.num_train_epochs + 1:
             logger.info(
                 f"Reached the end ({current_epoch} epochs) of our training run ({args.num_train_epochs} epochs). This run will do zero steps."
             )
@@ -916,7 +915,6 @@ def main():
                }
            },
        )
-    torch.autograd.set_detect_anomaly(True)
     logger.info("***** Running training *****")
     total_num_batches = sum(
         [
@@ -960,10 +958,10 @@ def main():
     current_epoch_step = None

     for epoch in range(first_epoch, args.num_train_epochs):
-        if current_epoch >= args.num_train_epochs:
+        if current_epoch >= args.num_train_epochs + 1:
             # This might immediately end training, but that's useful for simply exporting the model.
             logger.info(
-                f"Reached the end ({current_epoch} epochs) of our training run ({args.num_train_epochs} epochs)."
+                f"Training run is complete ({args.num_train_epochs}/{args.num_train_epochs} epochs, {global_step}/{args.max_train_steps} steps)."
             )
             break
         if first_epoch != epoch:
@@ -1033,12 +1031,8 @@ def main():
             # Add the current batch of training data's avg luminance to a list.
             training_luminance_values.append(batch["batch_luminance"])

-            with accelerator.accumulate(
-                training_models
-            ), torch.autograd.detect_anomaly():
-                training_logger.debug(
-                    f"Sending latent batch from pinned memory to device"
-                )
+            with accelerator.accumulate(training_models):
+                training_logger.debug(f"Sending latent batch to GPU")
                 latents = batch["latent_batch"].to(
                     accelerator.device, dtype=weight_dtype
                 )
@@ -1080,7 +1074,7 @@ def main():

                 # Prepare the data for the scatter plot
                 for timestep in timesteps.tolist():
-                    timesteps_buffer.append((step, timestep))
+                    timesteps_buffer.append((global_step, timestep))

                 # Add noise to the latents according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
@@ -1121,10 +1115,10 @@ def main():
                     f"\n -> Noise device: {noise.device}"
                     f"\n -> Timesteps device: {timesteps.device}"
                     f"\n -> Encoder hidden states device: {encoder_hidden_states.device}"
-                    f"\n -> Latents dtype: {latents.dtype}"
-                    f"\n -> Noise dtype: {noise.dtype}"
+                    f"\n -> Latents dtype: {latents.dtype}, shape: {latents.shape if hasattr(latents, 'shape') else 'None'}"
+                    f"\n -> Noise dtype: {noise.dtype}, shape: {noise.shape if hasattr(noise, 'shape') else 'None'}"
                     f"\n -> Timesteps dtype: {timesteps.dtype}"
-                    f"\n -> Encoder hidden states dtype: {encoder_hidden_states.dtype}"
+                    f"\n -> Encoder hidden states dtype: {encoder_hidden_states.dtype}, shape: {encoder_hidden_states.shape if hasattr(encoder_hidden_states, 'shape') else 'None'}"
                 )
                 if unwrap_model(accelerator, unet).config.in_channels == channels * 2:
                     # deepfloyd stage ii requires the inputs to be doubled. note that we're working in pixels, not latents.
@@ -1391,13 +1385,12 @@ def main():
                    SCHEDULER_NAME_MAP=SCHEDULER_NAME_MAP,
                )

-            if global_step >= args.max_train_steps or epoch > args.num_train_epochs:
+            if global_step >= args.max_train_steps or epoch > args.num_train_epochs + 1:
                 logger.info(
-                    f"Training has completed.",
-                    f"\n -> global_step = {global_step}, max_train_steps = {args.max_train_steps}, epoch = {epoch}, num_train_epochs = {args.num_train_epochs}",
+                    f"Training run is complete ({args.num_train_epochs}/{args.num_train_epochs} epochs, {global_step}/{args.max_train_steps} steps)."
                 )
                 break
-        if global_step >= args.max_train_steps or epoch > args.num_train_epochs:
+        if global_step >= args.max_train_steps or epoch > args.num_train_epochs + 1:
             logger.info(
                 f"Exiting training loop. Beginning model unwind at epoch {epoch}, step {global_step}"
             )
diff --git a/train_sd2x.sh b/train_sd2x.sh
index dab53b40..b2310403 100644
--- a/train_sd2x.sh
+++ b/train_sd2x.sh
@@ -299,10 +299,15 @@ if [ -n "$ASPECT_BUCKET_ROUNDING" ]; then
     export ASPECT_BUCKET_ROUNDING_ARGS="--aspect_bucket_rounding=${ASPECT_BUCKET_ROUNDING}"
 fi

+export MAX_TRAIN_STEPS_ARGS=""
+if [ -n "$MAX_NUM_STEPS" ] && [[ "$MAX_NUM_STEPS" != 0 ]]; then
+    export MAX_TRAIN_STEPS_ARGS="--max_train_steps=${MAX_NUM_STEPS}"
+fi
+
 accelerate launch ${ACCELERATE_EXTRA_ARGS} --mixed_precision="${MIXED_PRECISION}" --num_processes="${TRAINING_NUM_PROCESSES}" --num_machines="${TRAINING_NUM_MACHINES}" --dynamo_backend="${TRAINING_DYNAMO_BACKEND}" train_sd21.py \
   --model_type="${MODEL_TYPE}" ${DORA_ARGS} --pretrained_model_name_or_path="${MODEL_NAME}" ${XFORMERS_ARG} ${GRADIENT_ARG} --set_grads_to_none --gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS} \
   --resume_from_checkpoint="${RESUME_CHECKPOINT}" ${DELETE_ARGS} ${SNR_GAMMA_ARG} --data_backend_config="${DATALOADER_CONFIG}" ${OVERRIDE_DATALOADER_CONFIG_ARG} \
-  --num_train_epochs=${NUM_EPOCHS} --max_train_steps=${MAX_NUM_STEPS} --metadata_update_interval=${METADATA_UPDATE_INTERVAL} \
+  --num_train_epochs=${NUM_EPOCHS} ${MAX_TRAIN_STEPS_ARGS} --metadata_update_interval=${METADATA_UPDATE_INTERVAL} \
   --learning_rate="${LEARNING_RATE}" --lr_scheduler="${LR_SCHEDULE}" --seed "${TRAINING_SEED}" --lr_warmup_steps="${LR_WARMUP_STEPS}" --lr_end="${LEARNING_RATE_END}" \
   --output_dir="${OUTPUT_DIR}" \
   --inference_scheduler_timestep_spacing="${INFERENCE_SCHEDULER_TIMESTEP_SPACING}" --training_scheduler_timestep_spacing="${TRAINING_SCHEDULER_TIMESTEP_SPACING}" \
diff --git a/train_sdxl.py b/train_sdxl.py
index b348e1a1..21ecb700 100644
--- a/train_sdxl.py
+++ b/train_sdxl.py
@@ -86,9 +86,8 @@
 from diffusers.utils.import_utils import is_xformers_available
 from transformers.utils import ContextManagers

-torch.autograd.set_detect_anomaly(True)
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.26.0.dev0")
+check_min_version("0.27.0.dev0")

 SCHEDULER_NAME_MAP = {
     "euler": EulerDiscreteScheduler,
@@ -984,7 +983,7 @@ def main():
         current_epoch = first_epoch
         StateTracker.set_epoch(current_epoch)
-        if current_epoch > args.num_train_epochs:
+        if current_epoch > args.num_train_epochs + 1:
             logger.info(
                 f"Reached the end ({current_epoch} epochs) of our training run ({args.num_train_epochs} epochs). This run will do zero steps."
             )
@@ -1054,10 +1053,10 @@ def main():
     current_epoch_step = None

     for epoch in range(first_epoch, args.num_train_epochs + 1):
-        if current_epoch > args.num_train_epochs:
+        if current_epoch > args.num_train_epochs + 1:
             # This might immediately end training, but that's useful for simply exporting the model.
             logger.info(
-                f"Reached the end ({current_epoch} epochs) of our training run ({args.num_train_epochs} epochs)."
+                f"Training run is complete ({args.num_train_epochs}/{args.num_train_epochs} epochs, {global_step}/{args.max_train_steps} steps)."
             )
             break
         if first_epoch != epoch:
@@ -1130,9 +1129,7 @@ def main():
             training_luminance_values.append(batch["batch_luminance"])

             with accelerator.accumulate(training_models):
-                training_logger.debug(
-                    f"Sending latent batch from pinned memory to device."
-                )
+                training_logger.debug(f"Sending latent batch to GPU.")
                 latents = batch["latent_batch"].to(
                     accelerator.device, dtype=weight_dtype
                 )
@@ -1175,7 +1172,7 @@ def main():

                 # Prepare the data for the scatter plot
                 for timestep in timesteps.tolist():
-                    timesteps_buffer.append((step, timestep))
+                    timesteps_buffer.append((global_step, timestep))

                 # Add noise to the latents according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
diff --git a/train_sdxl.sh b/train_sdxl.sh
index 15c20d45..b7bd28f1 100644
--- a/train_sdxl.sh
+++ b/train_sdxl.sh
@@ -280,11 +280,16 @@ if [ -n "$ASPECT_BUCKET_ROUNDING" ]; then
     export ASPECT_BUCKET_ROUNDING_ARGS="--aspect_bucket_rounding=${ASPECT_BUCKET_ROUNDING}"
 fi

+export MAX_TRAIN_STEPS_ARGS=""
+if [ -n "$MAX_NUM_STEPS" ] && [[ "$MAX_NUM_STEPS" != 0 ]]; then
+    export MAX_TRAIN_STEPS_ARGS="--max_train_steps=${MAX_NUM_STEPS}"
+fi
+
 # Run the training script.
 accelerate launch ${ACCELERATE_EXTRA_ARGS} --mixed_precision="${MIXED_PRECISION}" --num_processes="${TRAINING_NUM_PROCESSES}" --num_machines="${TRAINING_NUM_MACHINES}" --dynamo_backend="${TRAINING_DYNAMO_BACKEND}" train_sdxl.py \
   --model_type="${MODEL_TYPE}" ${DORA_ARGS} --pretrained_model_name_or_path="${MODEL_NAME}" ${XFORMERS_ARG} ${GRADIENT_ARG} --set_grads_to_none --gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS} \
   --resume_from_checkpoint="${RESUME_CHECKPOINT}" ${DELETE_ARGS} ${SNR_GAMMA_ARG} --data_backend_config="${DATALOADER_CONFIG}" \
-  --num_train_epochs=${NUM_EPOCHS} --max_train_steps=${MAX_NUM_STEPS} --metadata_update_interval=${METADATA_UPDATE_INTERVAL} \
+  --num_train_epochs=${NUM_EPOCHS} ${MAX_TRAIN_STEPS_ARGS} --metadata_update_interval=${METADATA_UPDATE_INTERVAL} \
   ${OPTIMIZER_ARG} --learning_rate="${LEARNING_RATE}" --lr_scheduler="${LR_SCHEDULE}" --seed "${TRAINING_SEED}" --lr_warmup_steps="${LR_WARMUP_STEPS}" \
   --output_dir="${OUTPUT_DIR}" ${BITFIT_ARGS} ${ASPECT_BUCKET_ROUNDING_ARGS} \
   --inference_scheduler_timestep_spacing="${INFERENCE_SCHEDULER_TIMESTEP_SPACING}" --training_scheduler_timestep_spacing="${TRAINING_SCHEDULER_TIMESTEP_SPACING}" \