Merge pull request #1054 from bghira/main
merge
bghira authored Oct 13, 2024
2 parents 47effc5 + 50f64ab commit dddaf4f
Showing 13 changed files with 280 additions and 50 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -55,6 +55,8 @@ For multi-node distributed training, [this guide](/documentation/DISTRIBUTED.md)
- Train directly from an S3-compatible storage provider, eliminating the requirement for expensive local storage. (Tested with Cloudflare R2 and Wasabi S3)
- For only SDXL and SD 1.x/2.x, full [ControlNet model training](/documentation/CONTROLNET.md) (not ControlLoRA or ControlLite)
- Training [Mixture of Experts](/documentation/MIXTURE_OF_EXPERTS.md) for lightweight, high-quality diffusion models
- [Masked loss training](/documentation/DREAMBOOTH.md#masked-loss) for superior convergence and reduced overfitting on any model
- Strong [prior regularisation](/documentation/DATALOADER.md#is_regularisation_data) training support for LyCORIS models
- Webhook support for updating e.g. Discord channels with your training progress, validations, and errors
- Integration with the [Hugging Face Hub](https://huggingface.co) for seamless model upload and nice automatically-generated model cards.

14 changes: 12 additions & 2 deletions documentation/DATALOADER.md
@@ -47,8 +47,8 @@ Here is the most basic example of a dataloader configuration file, as `multidata

### `dataset_type`

- **Values:** `image` | `text_embeds` | `image_embeds`
- **Description:** `image` datasets contain your training data. `text_embeds` contain the outputs of the text encoder cache, and `image_embeds` contain the VAE outputs, if the model uses one.
- **Values:** `image` | `text_embeds` | `image_embeds` | `conditioning`
- **Description:** `image` datasets contain your training data. `text_embeds` contain the outputs of the text encoder cache, and `image_embeds` contain the VAE outputs, if the model uses one. A dataset marked as `conditioning` can be paired with your `image` dataset via [the conditioning_data option](#conditioning_data).
- **Note:** Text and image embed datasets are defined differently than image datasets are. A text embed dataset stores ONLY the text embed objects. An image dataset stores the training data.

### `default`
@@ -71,6 +71,16 @@ Here is the most basic example of a dataloader configuration file, as `multidata
- **Values:** `aws` | `local` | `csv`
- **Description:** Determines the storage backend (local, csv or cloud) used for this dataset.

### `conditioning_type`

- **Values:** `controlnet` | `mask`
- **Description:** A dataset may contain ControlNet conditioning inputs or masks to use during loss calculations. Only one or the other may be used.

### `conditioning_data`

- **Values:** `id` value of conditioning dataset
- **Description:** As described in [the ControlNet guide](/documentation/CONTROLNET.md), an `image` dataset can be paired with its ControlNet or image mask data via this option.
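
The pairing rules for `conditioning_type` and `conditioning_data` can be checked before training starts. The following is a minimal sketch, not part of SimpleTuner: the function name `check_conditioning_pairs` is hypothetical, and the only behaviour taken from this commit is the set of accepted types and the `controlnet` default applied when `conditioning_type` is omitted.

```python
from typing import Any

# Accepted values, as listed for `conditioning_type`.
VALID_CONDITIONING_TYPES = {"controlnet", "mask"}


def check_conditioning_pairs(datasets: list[dict[str, Any]]) -> list[str]:
    """Return a list of problems; an empty list means the pairing looks consistent."""
    by_id = {d["id"]: d for d in datasets if "id" in d}
    problems = []
    for d in datasets:
        name = d.get("id", "<no id>")
        if d.get("dataset_type") == "conditioning":
            # The commit falls back to "controlnet" when the key is missing.
            ctype = d.get("conditioning_type", "controlnet")
            if ctype not in VALID_CONDITIONING_TYPES:
                problems.append(f"{name}: unknown conditioning_type {ctype!r}")
        target_id = d.get("conditioning_data")
        if target_id is None:
            continue
        target = by_id.get(target_id)
        if target is None:
            problems.append(f"{name}: conditioning_data points at unknown id {target_id!r}")
        elif target.get("dataset_type") != "conditioning":
            problems.append(f"{name}: {target_id!r} is not a conditioning dataset")
    return problems
```

Calling it on the parsed contents of your dataloader configuration file returns an empty list when every `conditioning_data` value refers to a dataset whose `dataset_type` is `conditioning`.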

### `instance_data_dir` / `aws_data_prefix`

- **Local:** Path to the data on the filesystem.
67 changes: 67 additions & 0 deletions documentation/DREAMBOOTH.md
@@ -30,6 +30,73 @@ The model contains something called a "prior" which could, in theory, be preserv

> 🟢 ([#1031](https://github.com/bghira/SimpleTuner/issues/1031)) Prior preservation loss is supported in SimpleTuner when training LyCORIS adapters by setting `is_regularisation_data` on that dataset.
### Masked loss

Image masks may be defined in pairs with image data. The dark portions of a mask cause the loss calculation to ignore the corresponding parts of the image.
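
To make the effect concrete, here is a minimal sketch of a masked mean-squared-error loss in plain Python. It assumes mask values in the range 0 to 1, where 0 is dark (ignored) and 1 is white (kept), and it normalises by the sum of the mask. The name `masked_mse` and the normalisation are illustrative assumptions, not SimpleTuner's actual implementation.

```python
def masked_mse(prediction, target, mask):
    """Mean squared error over flat pixel lists, weighted by a mask in [0, 1]."""
    if not (len(prediction) == len(target) == len(mask)):
        raise ValueError("prediction, target and mask must have the same length")
    # Dark (0.0) mask entries zero out the squared error of their pixel.
    weighted = sum(m * (p - t) ** 2 for p, t, m in zip(prediction, target, mask))
    kept = sum(mask)
    # An all-dark mask keeps nothing, so there is no loss to report.
    return weighted / kept if kept > 0 else 0.0


prediction = [0.0, 1.0, 4.0, 2.0]
target = [0.0, 0.0, 0.0, 0.0]
all_white = [1.0, 1.0, 1.0, 1.0]
half_dark = [1.0, 1.0, 0.0, 0.0]

print(masked_mse(prediction, target, all_white))  # 5.25
print(masked_mse(prediction, target, half_dark))  # 0.5
```

With the all-white mask every pixel contributes, which matches ordinary unmasked training. With the half-dark mask, the two large errors fall in the ignored region and no longer influence the loss.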

An example [script](/toolkit/datasets/masked_loss/generate_dataset_masks.py) exists to generate these masks, given an `input_dir` and `output_dir`:

```bash
python generate_dataset_masks.py --input_dir /images/input \
--output_dir /images/output \
--text_input "person"
```

However, this script does not have any advanced functionality such as mask padding or blurring.

When defining your image mask dataset:

- Every image must have a mask. Use an all-white image if you do not want to mask.
- Set `dataset_type=conditioning` on your conditioning (mask) data folder
- Set `conditioning_type=mask` on your mask dataset
- Set `conditioning_data=` to your conditioning dataset `id` on your image dataset

```json
[
    {
        "id": "dreambooth-data",
        "type": "local",
        "dataset_type": "image",
        "conditioning_data": "dreambooth-conditioning",
        "instance_data_dir": "/training/datasets/test_datasets/dreambooth",
        "cache_dir_vae": "/training/cache/vae/sdxl/dreambooth-data",
        "caption_strategy": "instanceprompt",
        "instance_prompt": "an dreambooth",
        "metadata_backend": "discovery",
        "resolution": 1024,
        "minimum_image_size": 1024,
        "maximum_image_size": 1024,
        "target_downsample_size": 1024,
        "crop": true,
        "crop_aspect": "square",
        "crop_style": "center",
        "resolution_type": "pixel_area"
    },
    {
        "id": "dreambooth-conditioning",
        "type": "local",
        "dataset_type": "conditioning",
        "instance_data_dir": "/training/datasets/test_datasets/dreambooth_mask",
        "resolution": 1024,
        "minimum_image_size": 1024,
        "maximum_image_size": 1024,
        "target_downsample_size": 1024,
        "crop": true,
        "crop_aspect": "square",
        "crop_style": "center",
        "resolution_type": "pixel_area",
        "conditioning_type": "mask"
    },
    {
        "id": "an example backend for text embeds.",
        "dataset_type": "text_embeds",
        "default": true,
        "type": "local",
        "cache_dir": "/training/cache/text/sdxl-base/masked_loss"
    }
]
```
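
Because every image must have a mask, it can be useful to check the two folders before launching a run. The sketch below uses only the standard library. It assumes a mask shares its image's relative path and filename inside the mask folder; the helper name `find_images_without_masks` and the extension list are hypothetical.

```python
from pathlib import Path

# Assumed set of image extensions; adjust to match your dataset.
IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".webp"}


def find_images_without_masks(image_dir, mask_dir):
    """Return sorted relative paths of images with no same-named file under mask_dir."""
    image_root, mask_root = Path(image_dir), Path(mask_dir)
    missing = []
    for path in image_root.rglob("*"):
        if path.is_file() and path.suffix.lower() in IMAGE_SUFFIXES:
            relative = path.relative_to(image_root)
            # Assumption: the mask lives at the same relative path in mask_dir.
            if not (mask_root / relative).is_file():
                missing.append(relative.as_posix())
    return sorted(missing)
```

With the example configuration above, `find_images_without_masks("/training/datasets/test_datasets/dreambooth", "/training/datasets/test_datasets/dreambooth_mask")` would list the images that still need a mask, for which an all-white mask can then be supplied.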

## Setup

Following the [tutorial](/TUTORIAL.md) is required before you can continue into Dreambooth-specific configuration.
8 changes: 7 additions & 1 deletion documentation/quickstart/FLUX.md
@@ -428,6 +428,10 @@ NF4 does not work with torch.compile, so whatever you get for speed is what you

If VRAM is not a concern (e.g. 48G or greater) then int8 with torch.compile is your best, fastest option.

### Masked loss

If you are training a subject or style and would like to mask one or the other, see the [masked loss training](/documentation/DREAMBOOTH.md#masked-loss) section of the Dreambooth guide.

### Classifier-free guidance

#### Problem
@@ -490,7 +494,9 @@ We can partially reintroduce distillation to a de-distilled model by continuing
#### LoKr (--lora_type=lycoris)
- Higher learning rates are better for LoKr (`1e-3` with AdamW, `2e-4` with Lion)
- Other algorithms need more exploration.
- Setting `is_regularisation_data` on such datasets may help preserve / prevent bleed.
- Setting `is_regularisation_data` on such datasets may help preserve / prevent bleed and improve the quality of the final model.
- This behaves differently from "prior loss preservation", which is known for doubling training batch sizes without improving the result much.
- SimpleTuner's regularisation data implementation provides an efficient way to preserve the base model.

### Image artifacts
Flux will immediately absorb bad image artifacts. It's just how it is - a final training run on just high quality data may be required to fix it at the end.
5 changes: 5 additions & 0 deletions helpers/data_backend/factory.py
@@ -873,6 +873,10 @@ def configure_multi_databackend(args: dict, accelerator, text_encoders, tokenize
"Increasing resolution to 256, as is required for DF Stage II."
)

conditioning_type = None
if backend.get("dataset_type") == "conditioning":
conditioning_type = backend.get("conditioning_type", "controlnet")

init_backend["sampler"] = MultiAspectSampler(
id=init_backend["id"],
metadata_backend=init_backend["metadata_backend"],
@@ -891,6 +895,7 @@ def configure_multi_databackend(args: dict, accelerator, text_encoders, tokenize
"prepend_instance_prompt", args.prepend_instance_prompt
),
instance_prompt=backend.get("instance_prompt", args.instance_prompt),
conditioning_type=conditioning_type,
)
if init_backend["sampler"].caption_strategy == "parquet":
configure_parquet_database(backend, args, init_backend["data_backend"])
8 changes: 8 additions & 0 deletions helpers/image_manipulation/training_sample.py
@@ -25,6 +25,7 @@ def __init__(
data_backend_id: str,
image_metadata: dict = None,
image_path: str = None,
conditioning_type: str = None,
):
"""
Initializes a new TrainingSample instance with a provided PIL.Image object and a data backend identifier.
@@ -38,6 +39,7 @@ def __init__(
self.target_size = None
self.intermediary_size = None
self.original_size = None
self.conditioning_type = conditioning_type
self.data_backend_id = data_backend_id
self.image_metadata = (
image_metadata
@@ -601,6 +603,12 @@ def get_image(self):
"""
return self.image

def is_conditioning_sample(self):
return self.conditioning_type is not None

def get_conditioning_type(self):
return self.conditioning_type

def get_conditioning_image(self):
"""
Fetch a conditioning image, eg. a canny edge map for ControlNet training.
26 changes: 23 additions & 3 deletions helpers/multiaspect/sampler.py
@@ -38,6 +38,7 @@ def __init__(
use_captions=True,
prepend_instance_prompt=False,
instance_prompt: str = None,
conditioning_type: str = None,
):
"""
Initializes the sampler with provided settings.
@@ -60,6 +61,13 @@ def __init__(
f"MultiAspectSampler-{self.id}",
os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO"),
)
if conditioning_type is not None:
if conditioning_type not in ["controlnet", "mask"]:
raise ValueError(
f"Unknown conditioning image type: {conditioning_type}"
)
self.conditioning_type = conditioning_type

self.rank_info = rank_info()
self.accelerator = accelerator
self.metadata_backend = metadata_backend
@@ -446,19 +454,30 @@ def get_conditioning_sample(self, original_sample_path: str) -> str:
full_path = os.path.join(
self.metadata_backend.instance_data_dir, original_sample_path
)
try:
conditioning_sample_data = self.data_backend.read_image(full_path)
except Exception as e:
self.logger.error(f"Could not fetch conditioning sample: {e}")

return None
if not conditioning_sample_data:
self.debug_log(f"Could not fetch conditioning sample from {full_path}.")
return None

conditioning_sample = TrainingSample(
image=self.data_backend.read_image(full_path),
image=conditioning_sample_data,
data_backend_id=self.id,
image_metadata=self.metadata_backend.get_metadata_by_filepath(full_path),
image_path=full_path,
conditioning_type=self.conditioning_type,
)
return conditioning_sample

def connect_conditioning_samples(self, samples: tuple):
if not StateTracker.get_args().controlnet:
return samples
# Locate the conditioning data
conditioning_dataset = StateTracker.get_conditioning_dataset(self.id)
if conditioning_dataset is None:
return samples
sampler = conditioning_dataset["sampler"]
outputs = list(samples)
for sample in samples:
@@ -540,6 +559,7 @@ def __iter__(self):
[instance["image_path"] for instance in final_yield]
)
self.accelerator.wait_for_everyone()
# if applicable, we'll append TrainingSample(s) to the end for conditioning inputs.
final_yield = self.connect_conditioning_samples(final_yield)
yield tuple(final_yield)
# Change bucket after a full batch is yielded
29 changes: 24 additions & 5 deletions helpers/training/collate.py
@@ -416,17 +416,35 @@ def collate_fn(batch):
)

conditioning_filepaths = []
conditioning_latents = None
conditioning_pixel_values = None
conditioning_type = None
if len(conditioning_examples) > 0:
for example in conditioning_examples:
# Building the list of conditioning image filepaths.
if conditioning_type is not None:
if example.get_conditioning_type() != conditioning_type:
raise ValueError(
f"Conditioning type mismatch: {conditioning_type} != {example.get_conditioning_type()}"
"\n-> Ensure all conditioning samples are of the same type."
)
else:
conditioning_type = example.get_conditioning_type()
if conditioning_type == "mask" and len(conditioning_examples) != len(
examples
):
raise ValueError(
f"Masks seem to be missing for some of the following images: {examples}"
f"\n-> Ensure all images have a corresponding mask: {[example.image_path() for example in conditioning_examples]}"
)
conditioning_filepaths.append(example.image_path(basename_only=False))
# Use the poorly-named method to retrieve the image pixel values
conditioning_latents = deepfloyd_pixels(conditioning_filepaths, data_backend_id)
conditioning_latents = torch.stack(
conditioning_pixel_values = deepfloyd_pixels(
conditioning_filepaths, data_backend_id
)
conditioning_pixel_values = torch.stack(
[
latent.to(StateTracker.get_accelerator().device)
for latent in conditioning_latents
for latent in conditioning_pixel_values
]
)

@@ -474,7 +492,8 @@ def collate_fn(batch):
"add_text_embeds": add_text_embeds_all,
"batch_time_ids": batch_time_ids,
"batch_luminance": batch_luminance,
"conditioning_pixel_values": conditioning_latents,
"conditioning_pixel_values": conditioning_pixel_values,
"encoder_attention_mask": attn_mask,
"is_regularisation_data": is_regularisation_data,
"conditioning_type": conditioning_type,
}
1 change: 1 addition & 0 deletions helpers/training/default_settings/safety_check.py
@@ -99,6 +99,7 @@ def safety_check(args, accelerator):

if (
args.model_type != "lora"
and not args.controlnet
and args.base_model_precision != "no_change"
and not args.i_know_what_i_am_doing
):
2 changes: 1 addition & 1 deletion helpers/training/state_tracker.py
@@ -407,7 +407,7 @@ def set_conditioning_dataset(

@classmethod
def get_conditioning_dataset(cls, data_backend_id: str):
return cls.data_backends[data_backend_id]["conditioning_data"]
return cls.data_backends[data_backend_id].get("conditioning_data", None)

@classmethod
def get_data_backend_config(cls, data_backend_id: str):