Merge pull request #386 from bghira/main
bitfit restrictions / model freezing simplification | updates to Hugging Face Hub integration: automatically push model card and weights | webhooks: minor log-level fixes and other improvements; ability to debug image cropping by sending cropped images to Discord | resize and crop fixes for JSON and Parquet backend edge cases (VAE encode in-flight)
bghira authored May 9, 2024
2 parents a66dab5 + 5737588 commit 9cd535a
Showing 19 changed files with 261 additions and 178 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -49,6 +49,7 @@ For memory-constrained systems, see the [DeepSpeed document](/documentation/DEEP
- Train directly from an S3-compatible storage provider, eliminating the requirement for expensive local storage. (Tested with Cloudflare R2 and Wasabi S3)
- DeepFloyd stage I and II full u-net or parameter-efficient fine-tuning via LoRA using 22G VRAM
- Webhook support for updating e.g. Discord channels with your training progress, validations, and errors
- Integration with the [Hugging Face Hub](https://huggingface.co) for seamless model upload and automatically generated model cards.

### Stable Diffusion 2.0/2.1

17 changes: 17 additions & 0 deletions TUTORIAL.md
@@ -128,6 +128,23 @@ A value of 25% seems to provide some additional benefits such as reducing the nu

For users who are more familiar with model training and wish to tweak settings, e.g. `MIXED_PRECISION`, enabling offset noise, or setting up zero terminal SNR, detailed explanations can be found in [OPTIONS.md](/OPTIONS.md).
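
For instance, a typical override in `sdxl-env.sh` might look like the following (the value shown is illustrative; see OPTIONS.md for the settings your configuration supports):

```bash
# Hypothetical example: request bf16 mixed-precision training.
export MIXED_PRECISION="bf16"
```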

## Publishing checkpoints to Hugging Face Hub

Setting two values inside `sdxl-env.sh` or `sd2x-env.sh` will cause the trainer to automatically push your model up to the Hugging Face Hub upon training completion:

```bash
export PUSH_TO_HUB="true"
export HUB_MODEL_NAME="what-you-will-call-this"
```

Be sure to log in before you begin training by executing:

```bash
huggingface-cli login
```

A model card will be automatically generated, containing most of the relevant training session parameters.
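
Under the hood, the upload step is roughly equivalent to the following sketch using the `huggingface_hub` library; the repository name and output path are placeholders, and the trainer's actual calls may differ:

```python
# Hypothetical sketch of the push-to-hub step; names and paths are placeholders.
from huggingface_hub import create_repo, upload_folder

repo_id = create_repo("what-you-will-call-this", exist_ok=True).repo_id
upload_folder(
    repo_id=repo_id,
    folder_path="output/models",  # wherever your checkpoints were written
    commit_message="End of training",
)
```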

## Monitoring and Logging

If `--report_to=wandb` is passed to the trainer (the default), it will ask on startup whether you wish to register with Weights & Biases to monitor your training run there. While you can always select option **3** or remove `--report_to=...` to disable reporting, it's encouraged to give it a try and watch your loss value drop as your training runs!
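
If you would rather not be prompted at startup, one option (assuming you already have a Weights & Biases account) is to authenticate ahead of time:

```bash
wandb login
```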
6 changes: 3 additions & 3 deletions helpers/caching/vae.py
@@ -648,7 +648,7 @@ def _process_images_in_batch(
try:
result = (
future.result()
) # Returns (image, crop_coordinates, new_aspect_ratio)
) # Returns PreparedSample or tuple(image, crop_coordinates, aspect_ratio)
if result: # Ensure result is not None or invalid
processed_images.append(result)
if first_aspect_ratio is None:
@@ -666,10 +666,10 @@
type(result) is tuple
and result[2]
and first_aspect_ratio is not None
and result.aspect_ratio != first_aspect_ratio
and result[2] != first_aspect_ratio
):
raise ValueError(
f"Image {filepath} has a different aspect ratio ({result.aspect_ratio}) than the first image in the batch ({first_aspect_ratio})."
f"Image {filepath} has a different aspect ratio ({result[2]}) than the first image in the batch ({first_aspect_ratio})."
)

except Exception as e:
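
The two changed lines above guard against the tuple form of the result, where the aspect ratio lives at index 2 rather than in an attribute. A minimal illustration of the distinction, with hypothetical values:

```python
from PIL import Image

image = Image.new("RGB", (768, 512))  # stand-in image
result = (image, (0, 0), 1.5)         # tuple(image, crop_coordinates, aspect_ratio)

if type(result) is tuple:                # mirrors the check used in the diff
    aspect_ratio = result[2]             # tuple form: positional access
else:
    aspect_ratio = result.aspect_ratio   # object form, e.g. a PreparedSample
```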
66 changes: 33 additions & 33 deletions helpers/image_manipulation/cropping.py
@@ -7,8 +7,14 @@

class BaseCropping:
def __init__(self, image: Image = None, image_metadata: dict = None):
self.original_height = None
self.original_width = None
self.intermediary_height = None
self.intermediary_width = None
self.image = image
self.image_metadata = image_metadata
# When we've only got metadata, we can't crop the image.
self.meta_crop = False
if self.image:
self.original_width, self.original_height = self.image.size
elif self.image_metadata:
@@ -24,67 +30,61 @@ def __init__(self, image: Image = None, image_metadata: dict = None):
def crop(self, target_width, target_height):
raise NotImplementedError("Subclasses must implement this method")

def set_image(self, image: Image = None, image_metadata: dict = None):
logger.debug(
f"Cropper image being refreshed. Before size: {self.original_width} x {self.original_height}"
)
if image is not None:
self.image = image
self.original_width, self.original_height = self.image.size
if image_metadata:
self.original_width, self.original_height = image_metadata["original_size"]
def set_image(self, image: Image.Image):
if type(image) is not Image.Image:
raise TypeError("Image must be a PIL Image object")
else:
logger.debug(f"Cropper received updated image contents: {image}")
self.image = image

def set_image_metadata(self, image_metadata: dict):
logger.debug(
f"Cropper image metadata being refreshed. Before size: {self.original_width} x {self.original_height}"
)
self.image_metadata = image_metadata
self.original_width, self.original_height = self.image_metadata["original_size"]
if "current_size" in self.image_metadata:
self.original_width, self.original_height = self.image_metadata[
"current_size"
]
return self

def set_intermediary_size(self, width, height):
self.intermediary_width = width
self.intermediary_height = height

return self


class CornerCropping(BaseCropping):
def crop(self, target_width, target_height):
left = max(0, self.original_width - target_width)
top = max(0, self.original_height - target_height)
right = self.original_width
bottom = self.original_height
left = max(0, self.intermediary_width - target_width)
top = max(0, self.intermediary_height - target_height)
right = self.intermediary_width
bottom = self.intermediary_height
if self.image:
return self.image.crop((left, top, right, bottom)), (top, left)
elif self.image_metadata:
return self.image_metadata, (top, left)
return None, (top, left)


class CenterCropping(BaseCropping):
def crop(self, target_width, target_height):
left = (self.original_width - target_width) / 2
top = (self.original_height - target_height) / 2
right = (self.original_width + target_width) / 2
bottom = (self.original_height + target_height) / 2
left = (self.intermediary_width - target_width) / 2
top = (self.intermediary_height - target_height) / 2
right = (self.intermediary_width + target_width) / 2
bottom = (self.intermediary_height + target_height) / 2
if self.image:
return self.image.crop((left, top, right, bottom)), (top, left)
elif self.image_metadata:
return self.image_metadata, (top, left)
return None, (top, left)


class RandomCropping(BaseCropping):
def crop(self, target_width, target_height):
import random

left = random.randint(0, max(0, self.original_width - target_width))
top = random.randint(0, max(0, self.original_height - target_height))
left = random.randint(0, max(0, self.intermediary_width - target_width))
top = random.randint(0, max(0, self.intermediary_height - target_height))
logger.debug(
f"Random cropping from {left}, {top} - {self.original_width}x{self.original_height} to {target_width}x{target_height}"
f"Random cropping from {left}, {top} - {self.intermediary_width}x{self.intermediary_height} to {target_width}x{target_height}"
)
right = left + target_width
bottom = top + target_height
if self.image:
return self.image.crop((left, top, right, bottom)), (top, left)
elif self.image_metadata:
return self.image_metadata, (top, left)
return None, (top, left)


class FaceCropping(RandomCropping):
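
To make the new call order concrete, here is a minimal usage sketch implied by this diff; the sizes are illustrative, and the key point is that the intermediary (post-resize) size must now be set before cropping:

```python
# Hypothetical usage of the reworked cropper API; sizes are illustrative.
from PIL import Image

image = Image.new("RGB", (1024, 768))
cropper = CenterCropping(image=image)

# Croppers now crop against the intermediary (post-resize) size rather than
# the original size, so it must be supplied explicitly before cropping.
cropper.set_intermediary_size(1024, 768)
cropped_image, (top, left) = cropper.crop(512, 512)
```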
85 changes: 54 additions & 31 deletions helpers/image_manipulation/training_sample.py
@@ -100,7 +100,8 @@ def prepare(self, return_tensor: bool = False):
if return_tensor:
# Return normalised tensor.
image = self.transforms(image)
return PreparedSample(
webhook_handler = StateTracker.get_webhook_handler()
prepared_sample = PreparedSample(
image=image,
original_size=self.original_size,
crop_coordinates=self.crop_coordinates,
@@ -109,6 +110,13 @@
target_size=self.target_size,
intermediary_size=self.intermediary_size,
)
if webhook_handler:
webhook_handler.send(
message=f"Debug info for prepared sample, {str(prepared_sample)}",
images=[self.image] if self.image else None,
message_level="debug",
)
return prepared_sample

def area(self) -> int:
"""
@@ -256,14 +264,18 @@ def crop(self):
self._downsample_before_crop()
self.calculate_target_size(downsample_before_crop=False)
logger.debug(
f"Pre-crop size: {self.image.size if hasattr(self.image, 'size') else 'Unknown'}."
f"Pre-crop size: {self.image.size if hasattr(self.image, 'size') else self.target_size}."
)
if self.image is not None:
self.cropper.set_image(self.image)
self.cropper.set_intermediary_size(
self.intermediary_size[0], self.intermediary_size[1]
)
self.cropper.set_image(self.image)
self.image, self.crop_coordinates = self.cropper.crop(
self.target_size[0], self.target_size[1]
)
logger.debug(
f"Post-crop size: {self.image.size if hasattr(self.image, 'size') else 'Unknown'}."
f"Post-crop size: {self.image.size if hasattr(self.image, 'size') else self.target_size}."
)
return self

@@ -276,57 +288,55 @@ def resize(self, size: tuple = None):
"""
current_size = self.image.size if self.image is not None else self.original_size
if size is None:
target_size, intermediary_size, target_aspect_ratio = (
self.target_size, self.intermediary_size, self.target_aspect_ratio = (
self.calculate_target_size()
)
size = target_size
if target_size != intermediary_size:
size = self.target_size
if self.target_size != self.intermediary_size:
# Now we can resize the image to the intermediary size.
logger.debug(
f"Before resizing to {intermediary_size}, our image is {current_size} resolution."
f"Before resizing to {self.intermediary_size}, our image is {current_size} resolution."
)
if self.image is not None:
self.image = self.image.resize(
intermediary_size, Image.Resampling.LANCZOS
self.intermediary_size, Image.Resampling.LANCZOS
)
logger.debug(f"After resize, we are at {self.image.size}")
# Crop the image to its target size, so that we do not squish or stretch the image.
original_crop_coordinates = self.crop_coordinates
if self.image is not None:
logger.debug(
f"It's our lucky day, the cropper has an actual image to adjust: {self.image_metadata}"
)
self.cropper.set_image(self.image, self.image_metadata)
else:
logger.debug(
"We are adjusting based on a dream of the image metadata, we do not have a real image."
)
self.image_metadata["current_size"] = intermediary_size
self.cropper.set_image_metadata(self.image_metadata)
self.image, new_crop_coordinates = self.cropper.crop(
target_size[0], target_size[1]
logger.debug(
f"TrainingSample is updating Cropper with the latest image and intermediary size: {self.image} and {self.intermediary_size}"
)
# Adjust self.crop_coordinates to adjust the crop to the new size.
self.crop_coordinates = (
self.crop_coordinates[0] + new_crop_coordinates[0],
self.crop_coordinates[1] + new_crop_coordinates[1],
if self.image is not None and self.cropper:
self.cropper.set_image(self.image)
self.cropper.set_intermediary_size(
self.intermediary_size[0], self.intermediary_size[1]
)
logger.debug(
f"After crop-adjusting pixel alignment, our image is now {self.image_metadata['current_size'] if 'current_size' in self.image_metadata else self.image.size} resolution and its crop coordinates are now {self.crop_coordinates} adjusted from {original_crop_coordinates}"
f"Setting intermediary size to {self.intermediary_size} for image {self.image}"
)
self.image, self.crop_coordinates = self.cropper.crop(
self.target_size[0], self.target_size[1]
)
logger.debug(
f"Cropper returned image {self.image} and coords {self.crop_coordinates}"
)
logger.debug(
f"After crop-adjusting pixel alignment, our image is now {self.image.size if hasattr(self.image, 'size') else size} resolution and its crop coordinates are now {self.crop_coordinates}"
)

logger.debug(
f"Resizing image from {self.image.size if self.image is not None and type(self.image) is not dict else intermediary_size} to final target size: {size}"
f"Resizing image from {self.image.size if self.image is not None and type(self.image) is not dict else self.intermediary_size} to final target size: {size}"
)
else:
logger.debug(
f"Resizing image from {self.image.size if self.image is not None and type(self.image) is not dict else intermediary_size} to custom-provided size: {size}"
f"Resizing image from {self.image.size if self.image is not None and type(self.image) is not dict else self.intermediary_size} to custom-provided size: {size}"
)
if self.image and hasattr(self.image, "resize"):
logger.debug("Actually resizing image.")
self.image = self.image.resize(size, Image.Resampling.LANCZOS)
self.aspect_ratio = MultiaspectImage.calculate_image_aspect_ratio(
self.image.size
)
logger.debug("Completed resize operation.")
return self

def get_image(self):
@@ -364,3 +374,16 @@ def __init__(
self.target_size = target_size
self.aspect_ratio = aspect_ratio
self.crop_coordinates = crop_coordinates

def __str__(self):
return f"PreparedSample(image={self.image}, original_size={self.original_size}, intermediary_size={self.intermediary_size}, target_size={self.target_size}, aspect_ratio={self.aspect_ratio}, crop_coordinates={self.crop_coordinates})"

def to_dict(self):
return {
"image": self.image,
"original_size": self.original_size,
"intermediary_size": self.intermediary_size,
"target_size": self.target_size,
"aspect_ratio": self.aspect_ratio,
"crop_coordinates": self.crop_coordinates,
}
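
The new `__str__` and `to_dict` helpers make `PreparedSample` straightforward to log and serialize. A brief sketch of how a metadata backend consumes one, mirroring the json.py change below (assumes `training_sample` is an already-constructed TrainingSample):

```python
# Hypothetical consumer of PreparedSample, as in the metadata backends.
prepared_sample = training_sample.prepare()
image_metadata = {
    "crop_coordinates": prepared_sample.crop_coordinates,
    "target_size": prepared_sample.target_size,  # the fix below: not prepared_sample.image.size
    "intermediary_size": prepared_sample.intermediary_size,
    "aspect_ratio": prepared_sample.aspect_ratio,
}
print(str(prepared_sample))  # the new __str__ makes debug output readable
```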
4 changes: 2 additions & 2 deletions helpers/legacy/metadata.py
@@ -88,15 +88,15 @@ def save_model_card(
{'Xformers was used to train this model. As such, bf16 or fp32 inference may be required. Your mileage may vary.' if StateTracker.get_args().enable_xformers_memory_efficient_attention else 'This model was not trained with Xformers.'}
{StateTracker.get_args().mixed_precision} precision was used during training.
- Training epochs: {StateTracker.get_epoch()}
- Training epochs: {StateTracker.get_epoch() - 1}
- Training steps: {StateTracker.get_global_step()}
- Learning rate: {StateTracker.get_args().learning_rate}
- Effective batch size: {StateTracker.get_args().train_batch_size * StateTracker.get_args().gradient_accumulation_steps}
- Micro-batch size: {StateTracker.get_args().train_batch_size}
- Gradient accumulation steps: {StateTracker.get_args().gradient_accumulation_steps}
- Prediction type: {StateTracker.get_args().prediction_type}
- Rescaled betas zero SNR: {StateTracker.get_args().rescale_betas_zero_snr}
- Optimizer: {'AdamW, stochastic bf16' if StateTracker.get_args().adamw_bfloat16 else 'AdamW8Bit' if StateTracker.get_args().use_8bit_adam else 'Adafactor' if StateTracker.get_args().use_adafactor_optimizer else 'Prodigy' if StateTracker.get_args().use_prodigy_optimizer else 'AdamW'}
- Optimizer: {'AdamW, stochastic bf16' if StateTracker.get_args().adam_bfloat16 else 'AdamW8Bit' if StateTracker.get_args().use_8bit_adam else 'Adafactor' if StateTracker.get_args().use_adafactor_optimizer else 'Prodigy' if StateTracker.get_args().use_prodigy_optimizer else 'AdamW'}
## Datasets
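
For context on the card fields above, the effective batch size is simply the micro-batch size multiplied by the gradient accumulation steps; for example, with hypothetical values:

```python
train_batch_size = 4             # micro-batch size
gradient_accumulation_steps = 8
effective_batch_size = train_batch_size * gradient_accumulation_steps  # 32
```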
12 changes: 10 additions & 2 deletions helpers/legacy/validation.py
@@ -227,7 +227,7 @@ def log_validations(
pipeline: DiffusionPipeline = None,
):
global_step = StateTracker.get_global_step()
global_resume_step = StateTracker.get_global_resume_step() or global_step
global_resume_step = StateTracker.get_global_resume_step() or 1
should_do_intermediary_validation = (
validation_prompts
and global_step % args.validation_steps == 0
@@ -242,9 +242,17 @@
or args.num_validation_images is None
or args.num_validation_images <= 0
):
logger.debug(
f"Validations are disabled:"
f"\n -> validation_prompts: {validation_prompts}"
f"\n -> num_validation_images: {args.num_validation_images}"
)
return
if validation_type == "finish" and should_do_intermediary_validation:
# 382 - don't run final validation when we'd also have run the intermediary validation.
logger.debug(
"Skipping final validation, because training is completed. Avoiding 2x validation."
)
return
logger.debug(f"We have valid prompts to process.")
if StateTracker.get_webhook_handler() is not None:
@@ -515,7 +523,7 @@
validation_resolution_width, validation_resolution_height = (
val * 4 for val in extra_validation_kwargs["image"].size
)
logger.info(
logger.debug(
f"Processing width/height: {validation_resolution_width}x{validation_resolution_height}"
)
validation_images[validation_shortname].extend(
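
A distilled sketch of the double-validation guard added in this file; the helper name and signature are simplifications, not the verbatim implementation:

```python
# Hypothetical distillation of the guard above: when training finishes on a step
# that already triggered an intermediary validation, skip the final validation.
def should_skip_final_validation(validation_type: str, global_step: int,
                                 validation_steps: int, has_prompts: bool) -> bool:
    should_do_intermediary = has_prompts and global_step % validation_steps == 0
    return validation_type == "finish" and should_do_intermediary
```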
2 changes: 1 addition & 1 deletion helpers/metadata/backends/json.py
@@ -223,7 +223,7 @@ def _process_for_bucket(
)
prepared_sample = training_sample.prepare()
image_metadata["crop_coordinates"] = prepared_sample.crop_coordinates
image_metadata["target_size"] = prepared_sample.image.size
image_metadata["target_size"] = prepared_sample.target_size
image_metadata["intermediary_size"] = prepared_sample.intermediary_size
# Round to avoid excessive unique buckets
image_metadata["aspect_ratio"] = prepared_sample.aspect_ratio