From 94cae4ac7a1c378ba2d3d0e5ee06dd7a2530f39f Mon Sep 17 00:00:00 2001
From: bghira
Date: Sat, 12 Oct 2024 22:02:45 -0600
Subject: [PATCH 1/3] fix controlnet training for sdxl and introduce masked
 loss preconditioning

---
 documentation/DATALOADER.md                   | 14 +++-
 documentation/DREAMBOOTH.md                   | 67 +++++++++++++++
 documentation/quickstart/FLUX.md              |  8 +-
 helpers/data_backend/factory.py               |  5 ++
 helpers/image_manipulation/training_sample.py |  8 ++
 helpers/multiaspect/sampler.py                | 26 +++++-
 helpers/training/collate.py                   | 29 +++++--
 .../training/default_settings/safety_check.py |  1 +
 helpers/training/state_tracker.py             |  2 +-
 helpers/training/trainer.py                   | 81 +++++++++++--------
 train.py                                      |  2 +-
 11 files changed, 195 insertions(+), 48 deletions(-)

diff --git a/documentation/DATALOADER.md b/documentation/DATALOADER.md
index eb239026..b9dfc4cd 100644
--- a/documentation/DATALOADER.md
+++ b/documentation/DATALOADER.md
@@ -47,8 +47,8 @@ Here is the most basic example of a dataloader configuration file, as `multidata
 
 ### `dataset_type`
 
-- **Values:** `image` | `text_embeds` | `image_embeds`
-- **Description:** `image` datasets contain your training data. `text_embeds` contain the outputs of the text encoder cache, and `image_embeds` contain the VAE outputs, if the model uses one.
+- **Values:** `image` | `text_embeds` | `image_embeds` | `conditioning`
+- **Description:** `image` datasets contain your training data. `text_embeds` contain the outputs of the text encoder cache, and `image_embeds` contain the VAE outputs, if the model uses one. When a dataset is marked as `conditioning`, it can be paired with your `image` dataset via [the conditioning_data option](#conditioning_data).
 - **Note:** Text and image embed datasets are defined differently than image datasets are. A text embed dataset stores ONLY the text embed objects. An image dataset stores the training data.
 
 ### `default`
@@ -71,6 +71,16 @@ Here is the most basic example of a dataloader configuration file, as `multidata
 - **Values:** `aws` | `local` | `csv`
 - **Description:** Determines the storage backend (local, csv or cloud) used for this dataset.
 
+### `conditioning_type`
+
+- **Values:** `controlnet` | `mask`
+- **Description:** A dataset may contain ControlNet conditioning inputs or masks to use during loss calculations. Only one of the two types may be used per dataset.
+
+### `conditioning_data`
+
+- **Values:** `id` value of the conditioning dataset
+- **Description:** As described in [the ControlNet guide](/documentation/CONTROLNET.md), an `image` dataset can be paired with its ControlNet or image mask data via this option.
+
 ### `instance_data_dir` / `aws_data_prefix`
 
 - **Local:** Path to the data on the filesystem.
diff --git a/documentation/DREAMBOOTH.md b/documentation/DREAMBOOTH.md
index c276090b..f42832f8 100644
--- a/documentation/DREAMBOOTH.md
+++ b/documentation/DREAMBOOTH.md
@@ -30,6 +30,73 @@ The model contains something called a "prior" which could, in theory, be preserved
 
 > 🟢 ([#1031](https://github.com/bghira/SimpleTuner/issues/1031)) Prior preservation loss is supported in SimpleTuner when training LyCORIS adapters by setting `is_regularisation_data` on that dataset.
 
+### Masked loss
+
+Image masks may be defined in pairs with your image data. Dark portions of the mask exclude those parts of the image from the loss calculation.
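+
+Under the hood, the trainer downsamples the mask to the latent resolution of the loss tensor and multiplies it into the unreduced loss before taking the mean. A simplified sketch of that weighting, mirroring the `trainer.py` hunk later in this patch (the helper name and tensor shapes are illustrative):
+
+```python
+import torch
+import torch.nn.functional as F
+
+def apply_loss_mask(loss: torch.Tensor, mask_pixels: torch.Tensor) -> torch.Tensor:
+    # loss: unreduced MSE, shape (B, C, H, W); mask_pixels: mask image scaled to [-1, 1]
+    mask = mask_pixels[:, 0].unsqueeze(1).to(dtype=loss.dtype, device=loss.device)
+    # downsample the mask to match the spatial size of the loss
+    mask = F.interpolate(mask, size=loss.shape[2:], mode="area")
+    # rescale [-1, 1] -> [0, 1] so black regions zero out their contribution
+    mask = mask / 2 + 0.5
+    return (loss * mask).mean()
+```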
+
+An example [script](/toolkit/datasets/masked_loss/generate_dataset_masks.py) exists to generate these masks, given an `--input_dir` and `--output_dir`:
+
+```bash
+python generate_dataset_masks.py --input_dir /images/input \
+    --output_dir /images/output \
+    --text_input "person"
+```
+
+However, this script does not offer advanced functionality such as mask padding or blurring.
+
+When defining your image mask dataset:
+
+- Every image must have a mask. Use an all-white image if you do not want to mask.
+- Set `dataset_type=conditioning` on your conditioning (mask) data folder
+- Set `conditioning_type=mask` on your mask dataset
+- On your `image` dataset, set `conditioning_data` to the `id` of your mask dataset
+
+```json
+[
+    {
+        "id": "dreambooth-data",
+        "type": "local",
+        "dataset_type": "image",
+        "conditioning_data": "dreambooth-conditioning",
+        "instance_data_dir": "/training/datasets/test_datasets/dreambooth",
+        "cache_dir_vae": "/training/cache/vae/sdxl/dreambooth-data",
+        "caption_strategy": "instanceprompt",
+        "instance_prompt": "an dreambooth",
+        "metadata_backend": "discovery",
+        "resolution": 1024,
+        "minimum_image_size": 1024,
+        "maximum_image_size": 1024,
+        "target_downsample_size": 1024,
+        "crop": true,
+        "crop_aspect": "square",
+        "crop_style": "center",
+        "resolution_type": "pixel_area"
+    },
+    {
+        "id": "dreambooth-conditioning",
+        "type": "local",
+        "dataset_type": "conditioning",
+        "instance_data_dir": "/training/datasets/test_datasets/dreambooth_mask",
+        "resolution": 1024,
+        "minimum_image_size": 1024,
+        "maximum_image_size": 1024,
+        "target_downsample_size": 1024,
+        "crop": true,
+        "crop_aspect": "square",
+        "crop_style": "center",
+        "resolution_type": "pixel_area",
+        "conditioning_type": "mask"
+    },
+    {
+        "id": "an example backend for text embeds.",
+        "dataset_type": "text_embeds",
+        "default": true,
+        "type": "local",
+        "cache_dir": "/training/cache/text/sdxl-base/masked_loss"
+    }
+]
+```
+
 ## Setup
 
 Following the [tutorial](/TUTORIAL.md) is required before you can continue into Dreambooth-specific configuration.
diff --git a/documentation/quickstart/FLUX.md b/documentation/quickstart/FLUX.md
index 52880950..d4a4cf33 100644
--- a/documentation/quickstart/FLUX.md
+++ b/documentation/quickstart/FLUX.md
@@ -428,6 +428,10 @@ NF4 does not work with torch.compile, so whatever you get for speed is what you
 
 If VRAM is not a concern (eg. 48G or greater) then int8 with torch.compile is your best, fastest option.
 
+### Masked loss
+
+If you are training a subject or style and would like to mask one or the other, see the [masked loss training](/documentation/DREAMBOOTH.md#masked-loss) section of the Dreambooth guide.
+
 ### Classifier-free guidance
 
 #### Problem
@@ -490,7 +494,9 @@ We can partially reintroduce distillation to a de-distilled model by continuing
 #### LoKr (--lora_type=lycoris)
 - Higher learning rates are better for LoKr (`1e-3` with AdamW, `2e-4` with Lion)
 - Other algorithms need more exploration.
-- Setting `is_regularisation_data` on such datasets may help preserve / prevent bleed.
+- Setting `is_regularisation_data` on such datasets may help preserve the base model / prevent bleed, and improve the final model's quality.
+  - This behaves differently from "prior loss preservation", which is known for doubling training batch sizes while doing little to improve results
+  - SimpleTuner's regularisation data implementation provides an efficient means of preserving the base model
 
 ### Image artifacts
 Flux will immediately absorb bad image artifacts. 
It's just how it is - a final training run on just high quality data may be required to fix it at the end. diff --git a/helpers/data_backend/factory.py b/helpers/data_backend/factory.py index f80c11c3..687a4d0a 100644 --- a/helpers/data_backend/factory.py +++ b/helpers/data_backend/factory.py @@ -873,6 +873,10 @@ def configure_multi_databackend(args: dict, accelerator, text_encoders, tokenize "Increasing resolution to 256, as is required for DF Stage II." ) + conditioning_type = None + if backend.get("dataset_type") == "conditioning": + conditioning_type = backend.get("conditioning_type", "controlnet") + init_backend["sampler"] = MultiAspectSampler( id=init_backend["id"], metadata_backend=init_backend["metadata_backend"], @@ -891,6 +895,7 @@ def configure_multi_databackend(args: dict, accelerator, text_encoders, tokenize "prepend_instance_prompt", args.prepend_instance_prompt ), instance_prompt=backend.get("instance_prompt", args.instance_prompt), + conditioning_type=conditioning_type, ) if init_backend["sampler"].caption_strategy == "parquet": configure_parquet_database(backend, args, init_backend["data_backend"]) diff --git a/helpers/image_manipulation/training_sample.py b/helpers/image_manipulation/training_sample.py index c95e1347..d5380e20 100644 --- a/helpers/image_manipulation/training_sample.py +++ b/helpers/image_manipulation/training_sample.py @@ -25,6 +25,7 @@ def __init__( data_backend_id: str, image_metadata: dict = None, image_path: str = None, + conditioning_type: str = None, ): """ Initializes a new TrainingSample instance with a provided PIL.Image object and a data backend identifier. @@ -38,6 +39,7 @@ def __init__( self.target_size = None self.intermediary_size = None self.original_size = None + self.conditioning_type = conditioning_type self.data_backend_id = data_backend_id self.image_metadata = ( image_metadata @@ -601,6 +603,12 @@ def get_image(self): """ return self.image + def is_conditioning_sample(self): + return self.conditioning_type is not None + + def get_conditioning_type(self): + return self.conditioning_type + def get_conditioning_image(self): """ Fetch a conditioning image, eg. a canny edge map for ControlNet training. diff --git a/helpers/multiaspect/sampler.py b/helpers/multiaspect/sampler.py index 8d7d25a7..90f51379 100644 --- a/helpers/multiaspect/sampler.py +++ b/helpers/multiaspect/sampler.py @@ -38,6 +38,7 @@ def __init__( use_captions=True, prepend_instance_prompt=False, instance_prompt: str = None, + conditioning_type: str = None, ): """ Initializes the sampler with provided settings. 
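
(The next hunks validate the conditioning type and fetch the paired conditioning sample. Pairing is by relative path: the conditioning image is looked up under the conditioning dataset's `instance_data_dir` using the same relative filename as the training image. A hypothetical illustration of that lookup, with both values standing in for what the dataloader config provides:)

```python
import os

# illustrative values; the real ones come from the dataloader configuration
conditioning_root = "/training/datasets/test_datasets/dreambooth_mask"
relative_sample_path = "subdir/foo.png"  # the training image's path within its own dataset

conditioning_path = os.path.join(conditioning_root, relative_sample_path)
```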
@@ -60,6 +61,13 @@ def __init__( f"MultiAspectSampler-{self.id}", os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO"), ) + if conditioning_type is not None: + if conditioning_type not in ["controlnet", "mask"]: + raise ValueError( + f"Unknown conditioning image type: {conditioning_type}" + ) + self.conditioning_type = conditioning_type + self.rank_info = rank_info() self.accelerator = accelerator self.metadata_backend = metadata_backend @@ -446,19 +454,30 @@ def get_conditioning_sample(self, original_sample_path: str) -> str: full_path = os.path.join( self.metadata_backend.instance_data_dir, original_sample_path ) + try: + conditioning_sample_data = self.data_backend.read_image(full_path) + except Exception as e: + self.logger.error(f"Could not fetch conditioning sample: {e}") + + return None + if not conditioning_sample_data: + self.debug_log(f"Could not fetch conditioning sample from {full_path}.") + return None + conditioning_sample = TrainingSample( - image=self.data_backend.read_image(full_path), + image=conditioning_sample_data, data_backend_id=self.id, image_metadata=self.metadata_backend.get_metadata_by_filepath(full_path), image_path=full_path, + conditioning_type=self.conditioning_type, ) return conditioning_sample def connect_conditioning_samples(self, samples: tuple): - if not StateTracker.get_args().controlnet: - return samples # Locate the conditioning data conditioning_dataset = StateTracker.get_conditioning_dataset(self.id) + if conditioning_dataset is None: + return samples sampler = conditioning_dataset["sampler"] outputs = list(samples) for sample in samples: @@ -540,6 +559,7 @@ def __iter__(self): [instance["image_path"] for instance in final_yield] ) self.accelerator.wait_for_everyone() + # if applicable, we'll append TrainingSample(s) to the end for conditioning inputs. final_yield = self.connect_conditioning_samples(final_yield) yield tuple(final_yield) # Change bucket after a full batch is yielded diff --git a/helpers/training/collate.py b/helpers/training/collate.py index 3ece5331..ac919245 100644 --- a/helpers/training/collate.py +++ b/helpers/training/collate.py @@ -416,17 +416,35 @@ def collate_fn(batch): ) conditioning_filepaths = [] - conditioning_latents = None + conditioning_pixel_values = None + conditioning_type = None if len(conditioning_examples) > 0: for example in conditioning_examples: # Building the list of conditioning image filepaths. + if conditioning_type is not None: + if example.get_conditioning_type() != conditioning_type: + raise ValueError( + f"Conditioning type mismatch: {conditioning_type} != {example.get_conditioning_type()}" + "\n-> Ensure all conditioning samples are of the same type." 
+ ) + else: + conditioning_type = example.get_conditioning_type() + if conditioning_type == "mask" and len(conditioning_examples) != len( + examples + ): + raise ValueError( + f"Masks seem to be missing for some of the following images: {examples}" + f"\n-> Ensure all images have a corresponding mask: {[example.image_path() for example in conditioning_examples]}" + ) conditioning_filepaths.append(example.image_path(basename_only=False)) # Use the poorly-named method to retrieve the image pixel values - conditioning_latents = deepfloyd_pixels(conditioning_filepaths, data_backend_id) - conditioning_latents = torch.stack( + conditioning_pixel_values = deepfloyd_pixels( + conditioning_filepaths, data_backend_id + ) + conditioning_pixel_values = torch.stack( [ latent.to(StateTracker.get_accelerator().device) - for latent in conditioning_latents + for latent in conditioning_pixel_values ] ) @@ -474,7 +492,8 @@ def collate_fn(batch): "add_text_embeds": add_text_embeds_all, "batch_time_ids": batch_time_ids, "batch_luminance": batch_luminance, - "conditioning_pixel_values": conditioning_latents, + "conditioning_pixel_values": conditioning_pixel_values, "encoder_attention_mask": attn_mask, "is_regularisation_data": is_regularisation_data, + "conditioning_type": conditioning_type, } diff --git a/helpers/training/default_settings/safety_check.py b/helpers/training/default_settings/safety_check.py index 15e3ae81..656c102b 100644 --- a/helpers/training/default_settings/safety_check.py +++ b/helpers/training/default_settings/safety_check.py @@ -99,6 +99,7 @@ def safety_check(args, accelerator): if ( args.model_type != "lora" + and not args.controlnet and args.base_model_precision != "no_change" and not args.i_know_what_i_am_doing ): diff --git a/helpers/training/state_tracker.py b/helpers/training/state_tracker.py index 2a72c3fb..d069a2fb 100644 --- a/helpers/training/state_tracker.py +++ b/helpers/training/state_tracker.py @@ -407,7 +407,7 @@ def set_conditioning_dataset( @classmethod def get_conditioning_dataset(cls, data_backend_id: str): - return cls.data_backends[data_backend_id]["conditioning_data"] + return cls.data_backends[data_backend_id].get("conditioning_data", None) @classmethod def get_data_backend_config(cls, data_backend_id: str): diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py index 5f452eb2..7f65c738 100644 --- a/helpers/training/trainer.py +++ b/helpers/training/trainer.py @@ -338,6 +338,24 @@ def _misc_init(self): self.config.use_deepspeed_optimizer, self.config.use_deepspeed_scheduler = ( prepare_model_for_deepspeed(self.accelerator, self.config) ) + self.config.base_weight_dtype = self.config.weight_dtype + self.config.is_quanto = False + self.config.is_torchao = False + self.config.is_bnb = False + if "quanto" in self.config.base_model_precision: + self.config.is_quanto = True + elif "torchao" in self.config.base_model_precision: + self.config.is_torchao = True + elif "bnb" in self.config.base_model_precision: + self.config.is_bnb = True + if self.config.is_quanto: + from helpers.training.quantisation import quantise_model + + self.quantise_model = quantise_model + elif self.config.is_torchao: + from helpers.training.quantisation import quantise_model + + self.quantise_model = quantise_model def set_model_family(self, model_family: str = None): model_family = getattr(self.config, "model_family", model_family) @@ -731,16 +749,6 @@ def init_precision(self): self.config.enable_adamw_bf16 = ( True if self.config.weight_dtype == torch.bfloat16 else False ) - 
self.config.base_weight_dtype = self.config.weight_dtype - self.config.is_quanto = False - self.config.is_torchao = False - self.config.is_bnb = False - if "quanto" in self.config.base_model_precision: - self.config.is_quanto = True - elif "torchao" in self.config.base_model_precision: - self.config.is_torchao = True - elif "bnb" in self.config.base_model_precision: - self.config.is_bnb = True quantization_device = ( "cpu" if self.config.quantize_via == "cpu" else self.accelerator.device ) @@ -770,11 +778,8 @@ def init_precision(self): ) if self.config.is_quanto: - from helpers.training.quantisation import quantise_model - - self.quantise_model = quantise_model with self.accelerator.local_main_process_first(): - quantise_model( + self.quantise_model( unet=self.unet, transformer=self.transformer, text_encoder_1=self.text_encoder_1, @@ -784,9 +789,6 @@ def init_precision(self): args=self.config, ) elif self.config.is_torchao: - from helpers.training.quantisation import quantise_model - - self.quantise_model = quantise_model with self.accelerator.local_main_process_first(): ( self.unet, @@ -795,7 +797,7 @@ def init_precision(self): self.text_encoder_2, self.text_encoder_3, self.controlnet, - ) = quantise_model( + ) = self.quantise_model( unet=self.unet, transformer=self.transformer, text_encoder_1=self.text_encoder_1, @@ -811,24 +813,13 @@ def init_controlnet_model(self): logger.info("Creating the controlnet..") if self.config.controlnet_model_name_or_path: logger.info("Loading existing controlnet weights") - controlnet = ControlNetModel.from_pretrained( + self.controlnet = ControlNetModel.from_pretrained( self.config.controlnet_model_name_or_path ) else: logger.info("Initializing controlnet weights from unet") - controlnet = ControlNetModel.from_unet(self.unet) - if "quanto" in self.config.base_model_precision: - # since controlnet training uses no adapter currently, we just quantise the base transformer here. - with self.accelerator.local_main_process_first(): - self.quantise_model( - unet=self.unet, - transformer=self.transformer, - text_encoder_1=self.text_encoder_1, - text_encoder_2=self.text_encoder_2, - text_encoder_3=self.text_encoder_3, - controlnet=None, - args=self.config, - ) + self.controlnet = ControlNetModel.from_unet(self.unet) + self.accelerator.wait_for_everyone() def init_trainable_peft_adapter(self): @@ -1626,6 +1617,9 @@ def move_models(self, destination: str = "accelerator"): if self.config.controlnet: self.controlnet.train() + logger.info( + f"Moving ControlNet to {target_device} in {self.config.weight_dtype} precision." + ) self.controlnet.to(device=target_device, dtype=self.config.weight_dtype) if self.config.train_text_encoder: logger.warning( @@ -2420,11 +2414,11 @@ def train(self): target.shape[0], -1 ), 1, - ).mean() + ) elif self.config.snr_gamma is None or self.config.snr_gamma == 0: training_logger.debug("Calculating loss") loss = self.config.snr_weight * F.mse_loss( - model_pred.float(), target.float(), reduction="mean" + model_pred.float(), target.float(), reduction="none" ) else: # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. 
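
(The weight computation referenced by the comment above falls between these hunks and is unchanged by this patch. For orientation, a hedged sketch of the standard Diffusers-style min-SNR weighting it refers to, assuming `noise_scheduler`, `timesteps`, and `snr_gamma` are in scope; this is illustrative rather than SimpleTuner's exact code:)

```python
import torch
from diffusers.training_utils import compute_snr

# min-SNR-gamma (arXiv:2303.09556): clamp the per-timestep SNR at gamma and
# normalise by SNR so low-noise timesteps do not dominate (epsilon-prediction case)
snr = compute_snr(noise_scheduler, timesteps)
mse_loss_weights = (
    torch.stack([snr, snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0]
    / snr
)
```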
@@ -2466,8 +2460,25 @@ def train(self): loss = ( loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights - ).mean() + ) + + # Mask the loss using any conditioning data + conditioning_type = batch.get("conditioning_type") + if conditioning_type == "mask": + # adapted from: https://github.com/kohya-ss/sd-scripts/blob/main/library/custom_train_functions.py#L482 + mask_image = ( + batch["conditioning_pixel_values"] + .to(dtype=loss.dtype, device=loss.device)[:, 0] + .unsqueeze(1) + ) + mask_image = torch.nn.functional.interpolate( + mask_image, size=loss.shape[2:], mode="area" + ) + mask_image = mask_image / 2 + 0.5 + loss = loss * mask_image + # reduce loss now + loss = loss.mean() if is_regularisation_data: parent_loss = loss diff --git a/train.py b/train.py index b3c72a18..7e00e46b 100644 --- a/train.py +++ b/train.py @@ -33,8 +33,8 @@ trainer.init_unload_vae() trainer.init_load_base_model() - trainer.init_precision() trainer.init_controlnet_model() + trainer.init_precision() trainer.init_freeze_models() trainer.init_trainable_peft_adapter() trainer.init_ema_model() From e0ca7345c7d2c2ad05b400dc7367c58827e5324a Mon Sep 17 00:00:00 2001 From: bghira Date: Sat, 12 Oct 2024 22:11:45 -0600 Subject: [PATCH 2/3] reduce logspam from student-teacher attach-detach --- helpers/training/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py index 7f65c738..52ebd6e1 100644 --- a/helpers/training/trainer.py +++ b/helpers/training/trainer.py @@ -2365,7 +2365,7 @@ def train(self): handled_regularisation = True with torch.no_grad(): if self.config.lora_type.lower() == "lycoris": - logger.info( + training_logger.debug( "Detaching LyCORIS adapter for parent prediction." ) self.accelerator._lycoris_wrapped_network.restore() @@ -2383,7 +2383,7 @@ def train(self): timesteps=timesteps, ) if self.config.lora_type.lower() == "lycoris": - logger.info( + training_logger.debug( "Attaching LyCORIS adapter for student prediction." ) self.accelerator._lycoris_wrapped_network.apply_to() From 35dd23c7fca8d49b8177f912ebfc681f398643b2 Mon Sep 17 00:00:00 2001 From: bghira Date: Sat, 12 Oct 2024 22:16:27 -0600 Subject: [PATCH 3/3] update feature list and add masked data generator script --- README.md | 2 + .../masked_loss/generate_dataset_masks.py | 81 +++++++++++++++++++ 2 files changed, 83 insertions(+) create mode 100644 toolkit/datasets/masked_loss/generate_dataset_masks.py diff --git a/README.md b/README.md index 5c770f56..b332a696 100644 --- a/README.md +++ b/README.md @@ -55,6 +55,8 @@ For multi-node distributed training, [this guide](/documentation/DISTRIBUTED.md) - Train directly from an S3-compatible storage provider, eliminating the requirement for expensive local storage. (Tested with Cloudflare R2 and Wasabi S3) - For only SDXL and SD 1.x/2.x, full [ControlNet model training](/documentation/CONTROLNET.md) (not ControlLoRA or ControlLite) - Training [Mixture of Experts](/documentation/MIXTURE_OF_EXPERTS.md) for lightweight, high-quality diffusion models +- [Masked loss training](/documentation/DREAMBOOTH.md#masked-loss) for superior convergence and reduced overfitting on any model +- Strong [prior regularisation](/documentation/DATALOADER.md#is_regularisation_data) training support for LyCORIS models - Webhook support for updating eg. 
Discord channels with your training progress, validations, and errors - Integration with the [Hugging Face Hub](https://huggingface.co) for seamless model upload and nice automatically-generated model cards. diff --git a/toolkit/datasets/masked_loss/generate_dataset_masks.py b/toolkit/datasets/masked_loss/generate_dataset_masks.py new file mode 100644 index 00000000..af3c6a3d --- /dev/null +++ b/toolkit/datasets/masked_loss/generate_dataset_masks.py @@ -0,0 +1,81 @@ +import argparse +import os +import shutil +from gradio_client import Client, handle_file + + +def main(): + # Set up argument parser + parser = argparse.ArgumentParser( + description="Mask images in a directory using Florence SAM Masking." + ) + parser.add_argument( + "--input_dir", + type=str, + required=True, + help="Path to the input directory containing images.", + ) + parser.add_argument( + "--output_dir", + type=str, + required=True, + help="Path to the output directory to save masked images.", + ) + parser.add_argument( + "--text_input", + type=str, + default="person", + help='Text prompt for masking (default: "person").', + ) + parser.add_argument( + "--model", + type=str, + default="SkalskiP/florence-sam-masking", + help='Model name to use (default: "SkalskiP/florence-sam-masking").', + ) + args = parser.parse_args() + + input_path = args.input_dir + output_path = args.output_dir + text_input = args.text_input + model_name = args.model + + # Create the output directory if it doesn't exist + os.makedirs(output_path, exist_ok=True) + + # Initialize the Gradio client + client = Client(model_name) + + # Get all files in the input directory + files = os.listdir(input_path) + + # Iterate over all files + for file in files: + # Construct the full file path + full_path = os.path.join(input_path, file) + # Check if the file is an image + if os.path.isfile(full_path) and full_path.lower().endswith( + (".jpg", ".jpeg", ".png", ".webp") + ): + # Define the path for the output mask + mask_path = os.path.join(output_path, file) + # Skip if the mask already exists + if os.path.exists(mask_path): + print(f"Mask already exists for {file}, skipping.") + continue + # Predict the mask + try: + mask_filename = client.predict( + image_input=handle_file(full_path), + text_input=text_input, + api_name="/process_image", + ) + # Move the generated mask to the output directory + shutil.move(mask_filename, mask_path) + print(f"Saved mask to {mask_path}") + except Exception as e: + print(f"Failed to process {file}: {e}") + + +if __name__ == "__main__": + main()
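
After generating masks, it can be worth confirming coverage before training, since the updated collate function raises a ValueError when any image in a batch lacks a corresponding mask. A minimal pre-flight scan along these lines (directory paths are illustrative) can catch gaps early:

```python
import os

IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".webp")

def find_unmasked_images(image_dir: str, mask_dir: str) -> list:
    """Return image filenames that lack a same-named mask file."""
    images = {f for f in os.listdir(image_dir) if f.lower().endswith(IMAGE_EXTS)}
    masks = {f for f in os.listdir(mask_dir) if f.lower().endswith(IMAGE_EXTS)}
    return sorted(images - masks)

if __name__ == "__main__":
    for name in find_unmasked_images("/images/input", "/images/output"):
        print(f"missing mask: {name}")
```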