Skip to content

Commit

Permalink
Merge pull request #374 from bghira/refactor/training-sample-handler
Browse files Browse the repository at this point in the history
TrainingSample: refactor and encapsulate image handling, improving performance and reliability
  • Loading branch information
bghira authored May 2, 2024
2 parents fa494cd + cbf0f07 commit 2970298
Show file tree
Hide file tree
Showing 19 changed files with 446 additions and 397 deletions.
4 changes: 2 additions & 2 deletions helpers/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1332,7 +1332,7 @@ def parse_args(input_args=None):
logger.warning(
"MPS may benefit from the use of --unet_attention_slice for memory savings at the cost of speed."
)
if args.train_batch_size > 12:
if args.train_batch_size > 16:
logger.error(
"An M3 Max 128G will use 12 seconds per step at a batch size of 1 and 65 seconds per step at a batch size of 12."
" Any higher values will result in NDArray size errors or other unstable training results and crashes."
Expand Down Expand Up @@ -1372,7 +1372,7 @@ def parse_args(input_args=None):
deepfloyd_pixel_alignment = 8
if args.aspect_bucket_alignment != deepfloyd_pixel_alignment:
logger.warning(
f"Overriding aspect bucket alignment pixel interval to {deepfloyd_pixel_alignment}px instead of{args.aspect_bucket_alignment}px."
f"Overriding aspect bucket alignment pixel interval to {deepfloyd_pixel_alignment}px instead of {args.aspect_bucket_alignment}px."
)
args.aspect_bucket_alignment = deepfloyd_pixel_alignment

Expand Down
10 changes: 9 additions & 1 deletion helpers/image_manipulation/cropping.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def crop(self, target_width, target_height):
class FaceCropping(RandomCropping):
def crop(
self,
image: Image,
image: Image.Image,
target_width: int,
target_height: int,
):
Expand Down Expand Up @@ -86,3 +86,11 @@ def crop(
else:
# Crop the image from a random position
return super.crop(image, target_width, target_height)


crop_handlers = {
"corner": CornerCropping,
"centre": CenterCropping,
"center": CenterCropping,
"random": RandomCropping,
}
265 changes: 265 additions & 0 deletions helpers/image_manipulation/training_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,265 @@
from PIL import Image
from PIL.ImageOps import exif_transpose
from helpers.multiaspect.image import MultiaspectImage, resize_helpers
from helpers.multiaspect.image import crop_handlers
from helpers.training.state_tracker import StateTracker
import logging

logger = logging.getLogger(__name__)


class TrainingSample:
def __init__(
self, image: Image.Image, data_backend_id: str, image_metadata: dict = None
):
"""
Initializes a new TrainingSample instance with a provided PIL.Image object and a data backend identifier.
Args:
image (Image.Image): A PIL Image object.
data_backend_id (str): Identifier for the data backend used for additional operations.
metadata (dict): Optional metadata associated with the image.
"""
self.image = image
self.data_backend_id = data_backend_id
self.image_metadata = image_metadata if image_metadata else {}
if hasattr(image, "size"):
self.original_size = self.image.size
elif image_metadata is not None:
self.original_size = image_metadata.get("original_size")
logger.debug(
f"Metadata for training sample given instead of image? {image_metadata}"
)

if not self.original_size:
raise Exception("Original size not found in metadata.")

# Torchvision transforms turn the pixels into a Tensor and normalize them for the VAE.
self.transforms = MultiaspectImage.get_image_transforms()
# EXIT, RGB conversions.
self.correct_image()

# Backend config details
self.data_backend_config = StateTracker.get_data_backend_config(data_backend_id)
self.crop_enabled = self.data_backend_config.get("crop", False)
self.crop_style = self.data_backend_config.get("crop_style", "random")
self.crop_aspect = self.data_backend_config.get("crop_aspect", "square")
self.crop_coordinates = (0, 0)
crop_handler_cls = crop_handlers.get(self.crop_style)
if not crop_handler_cls:
raise ValueError(f"Unknown crop style: {self.crop_style}")
self.cropper = crop_handler_cls(image=self.image, image_metadata=image_metadata)
self.resolution = self.data_backend_config.get("resolution")
self.resolution_type = self.data_backend_config.get("resolution_type")
self.target_size_calculator = resize_helpers.get(self.resolution_type)
if self.target_size_calculator is None:
raise ValueError(f"Unknown resolution type: {self.resolution_type}")
if self.resolution_type == "pixel":
self.target_area = self.resolution
# Store the pixel value, eg. 1024
self.pixel_resolution = self.resolution
# Store the megapixel value, eg. 1.0
self.megapixel_resolution = self.resolution / 1e6
elif self.resolution_type == "area":
self.target_area = self.resolution * 1e6 # Convert megapixels to pixels
# Store the pixel value, eg. 1024
self.pixel_resolution = self.resolution * 1e6
# Store the megapixel value, eg. 1.0
self.megapixel_resolution = self.resolution
else:
raise Exception(f"Unknown resolution type: {self.resolution_type}")
self.target_downsample_size = self.data_backend_config.get(
"target_downsample_size", None
)
self.maximum_image_size = self.data_backend_config.get(
"maximum_image_size", None
)

def prepare(self, return_tensor: bool = False):
"""
Perform initial image preparations such as converting to RGB and applying EXIF transformations.
Args:
image (Image.Image): The image to prepare.
Returns:
(image, crop_coordinates, aspect_ratio)
"""
self.crop()
if not self.crop_enabled:
self.resize()

image = self.image
if return_tensor:
# Return normalised tensor.
image = self.transforms(image)
return PreparedSample(
image=image,
original_size=self.original_size,
crop_coordinates=self.crop_coordinates,
aspect_ratio=self.aspect_ratio,
image_metadata=self.image_metadata,
target_size=self.target_size,
)

def area(self) -> int:
"""
Calculate the area of the image.
Returns:
int: The area of the image.
"""
if self.image is not None:
return self.image.size[0] * self.image.size[1]
if self.original_size:
return self.original_size[0] * self.original_size[1]

def should_downsample_before_crop(self) -> bool:
"""
Returns:
bool: True if the image should be downsampled before cropping, False otherwise.
"""
if (
not self.crop_enabled
or not self.maximum_image_size
or not self.target_downsample_size
):
return False
if self.data_backend_config.get("resolution_type") == "pixel":
return (
self.image.size[0] > self.pixel_resolution
or self.image.size[1] > self.pixel_resolution
)
elif self.data_backend_config.get("resolution_type") == "area":
logger.debug(
f"Image is too large? {self.area() > self.target_area} (image area: {self.area()}, target area: {self.target_area})"
)
return self.area() > self.target_area
else:
raise ValueError(
f"Unknown resolution type: {self.data_backend_config.get('resolution_type')}"
)

def downsample_before_crop(self):
"""
Downsample the image before cropping, to preserve scene details.
"""
if self.image and self.should_downsample_before_crop():
width, height, _ = self.calculate_target_size(downsample_before_crop=True)
logger.debug(
f"Downsampling image from {self.image.size} to {width}x{height} before cropping."
)
self.resize((width, height))
return self

def calculate_target_size(self, downsample_before_crop: bool = False):
# Square crops are always {self.pixel_resolution}x{self.pixel_resolution}
if self.crop_aspect == "square" and not downsample_before_crop:
self.aspect_ratio = 1.0
self.target_size = (self.pixel_resolution, self.pixel_resolution)
return self.target_size[0], self.target_size[1], self.aspect_ratio
self.aspect_ratio = MultiaspectImage.calculate_image_aspect_ratio(
self.original_size
)
if downsample_before_crop and self.target_downsample_size is not None:
target_width, target_height, self.aspect_ratio = (
self.target_size_calculator(
self.aspect_ratio, self.target_downsample_size
)
)
else:
target_width, target_height, self.aspect_ratio = (
self.target_size_calculator(self.aspect_ratio, self.resolution)
)
self.target_size = (target_width, target_height)
self.aspect_ratio = MultiaspectImage.calculate_image_aspect_ratio(
self.target_size
)

return self.target_size[0], self.target_size[1], self.aspect_ratio

def correct_image(self):
"""
Apply a series of transformations to the image to "correct" it.
"""
if self.image:
# Convert image to RGB to remove any alpha channel and apply EXIF data transformations
self.image = self.image.convert("RGB")
self.image = exif_transpose(self.image)
self.original_size = self.image.size
return self

def crop(self):
"""
Crop the image using the detected crop handler class.
"""
if not self.crop_enabled:
return self
logger.debug(
f"Cropping image with {self.crop_style} style and {self.crop_aspect}."
)

# Too-big of an image, resize before we crop.
self.downsample_before_crop()
width, height, aspect_ratio = self.calculate_target_size(
downsample_before_crop=False
)
logger.debug(
f"Pre-crop size: {self.image.size if hasattr(self.image, 'size') else 'Unknown'}."
)
self.image, self.crop_coordinates = self.cropper.crop(width, height)
logger.debug(
f"Post-crop size: {self.image.size if hasattr(self.image, 'size') else 'Unknown'}."
)
return self

def resize(self, target_size: tuple = None):
"""
Resize the image to a new size.
Args:
target_size (tuple): The target size as (width, height).
"""
if target_size is None:
target_width, target_height, aspect_ratio = self.calculate_target_size()
target_size = (target_width, target_height)
if self.image:
self.image = self.image.resize(target_size, Image.Resampling.LANCZOS)
self.aspect_ratio = MultiaspectImage.calculate_image_aspect_ratio(
self.image.size
)
return self

def get_image(self):
"""
Returns the current state of the image.
Returns:
Image.Image: The current image.
"""
return self.image


class PreparedSample:
def __init__(
self,
image: Image.Image,
image_metadata: dict,
original_size: tuple,
target_size: tuple,
aspect_ratio: float,
crop_coordinates: tuple,
):
"""
Initializes a new PreparedSample instance with a provided PIL.Image object and optional metadata.
Args:
image (Image.Image): A PIL Image object.
metadata (dict): Optional metadata associated with the image.
"""
self.image = image
self.image_metadata = image_metadata if image_metadata else {}
self.original_size = original_size
self.target_size = target_size
self.aspect_ratio = aspect_ratio
self.crop_coordinates = crop_coordinates
8 changes: 2 additions & 6 deletions helpers/metadata/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,9 +740,7 @@ def handle_vae_cache_inconsistencies(self, vae_cache, vae_cache_behavior: str):
continue
if vae_cache_behavior == "sync":
# Sync aspect buckets with the cache
expected_bucket = MultiaspectImage.determine_bucket_for_aspect_ratio(
self._get_aspect_ratio_from_tensor(cache_content)
)
expected_bucket = str(self._get_aspect_ratio_from_tensor(cache_content))
self._modify_cache_entry_bucket(cache_file, expected_bucket)
elif vae_cache_behavior == "recreate":
# Delete the cache file if it doesn't match the aspect bucket indices
Expand Down Expand Up @@ -837,9 +835,7 @@ def is_cache_inconsistent(self, vae_cache, cache_file, cache_content):
)

actual_aspect_ratio = self._get_aspect_ratio_from_tensor(cache_content)
expected_bucket = MultiaspectImage.determine_bucket_for_aspect_ratio(
recalculated_aspect_ratio
)
expected_bucket = str(recalculated_aspect_ratio)
logger.debug(
f"Expected bucket for {cache_file}: {expected_bucket} vs actual {actual_aspect_ratio}"
)
Expand Down
25 changes: 12 additions & 13 deletions helpers/metadata/backends/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from helpers.multiaspect.image import MultiaspectImage
from helpers.data_backend.base import BaseDataBackend
from helpers.metadata.backends.base import MetadataBackend
from helpers.image_manipulation.training_sample import TrainingSample
from pathlib import Path
import json, logging, os, time, re
from multiprocessing import Manager
Expand Down Expand Up @@ -217,27 +218,25 @@ def _process_for_bucket(
statistics["skipped"]["too_small"] += 1
return aspect_ratio_bucket_indices
image_metadata["original_size"] = image.size
image, crop_coordinates, new_aspect_ratio = (
MultiaspectImage.prepare_image(
image=image,
resolution=self.resolution,
resolution_type=self.resolution_type,
id=self.data_backend.id,
)
training_sample = TrainingSample(
image=image, data_backend_id=self.id, image_metadata=image_metadata
)
image_metadata["crop_coordinates"] = crop_coordinates
prepared_sample = training_sample.prepare()
image_metadata["crop_coordinates"] = prepared_sample.crop_coordinates
image_metadata["target_size"] = image.size
# Round to avoid excessive unique buckets
image_metadata["aspect_ratio"] = new_aspect_ratio
image_metadata["aspect_ratio"] = prepared_sample.aspect_ratio
image_metadata["luminance"] = calculate_luminance(image)
logger.debug(
f"Image {image_path_str} has aspect ratio {new_aspect_ratio} and size {image.size}."
f"Image {image_path_str} has aspect ratio {prepared_sample.aspect_ratio} and size {image.size}."
)

# Create a new bucket if it doesn't exist
if str(new_aspect_ratio) not in aspect_ratio_bucket_indices:
aspect_ratio_bucket_indices[str(new_aspect_ratio)] = []
aspect_ratio_bucket_indices[str(new_aspect_ratio)].append(image_path_str)
if str(prepared_sample.aspect_ratio) not in aspect_ratio_bucket_indices:
aspect_ratio_bucket_indices[str(prepared_sample.aspect_ratio)] = []
aspect_ratio_bucket_indices[str(prepared_sample.aspect_ratio)].append(
image_path_str
)
# Instead of directly updating, just fill the provided dictionary
if metadata_updates is not None:
metadata_updates[image_path_str] = image_metadata
Expand Down
Loading

0 comments on commit 2970298

Please sign in to comment.