Merge pull request #375 from bghira/main
- fix `--aspect_bucket_rounding` not being applied correctly
- rebuild image sample handling as structured, object-oriented logic
- fix early epoch exit problem
- reduce max-epochs vs. max-steps ambiguity by setting the default to 0 for one of them
- fixes for LoRA text encoder save/load hooks
- optimise trainer
- 300% performance gain by removing the torch anomaly detector
- fix dataset race condition where a single-image dataset was not being detected
- AMD documentation for install and dependencies, thanks to Beinsezii
- fix wandb timestep distribution chart values racing ahead of reality
bghira authored May 2, 2024
2 parents 2cd98b9 + 2970298 commit 528d8fe
Showing 32 changed files with 699 additions and 605 deletions.
14 changes: 12 additions & 2 deletions INSTALL.md
@@ -22,7 +22,7 @@ To install the Apple-specific requirements:
poetry install --no-root -C install/apple
```

### Linux
### Linux + Nvidia/CUDA

The first command you'll run will install most of the dependencies:

@@ -59,6 +59,16 @@ If the egg install for Xformers does not work, try including `xformers` on the f
pip3 install --pre xformers torch torchvision torchaudio torchtriton --extra-index-url https://download.pytorch.org/whl/nightly/cu118 --force
```

### Linux + AMD / ROCm
Due to `xformers` not supporting the ROCm platform, memory requirements for training will likely be higher than otherwise stated.

To install the ROCm-specific requirements:

```bash
poetry install --no-root -C install/rocm
```
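As a quick sanity check (an editorial addition here, not part of the project's documented steps), the ROCm build of PyTorch can be confirmed from the shell; `torch.version.hip` is `None` on non-ROCm builds:

```bash
python -c 'import torch; print(torch.version.hip, torch.cuda.is_available())'
# A ROCm build prints a HIP version string; True indicates a usable GPU.
```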


### All platforms

2. For SD2.1, copy `sd21-env.sh.example` to `env.sh` - be sure to fill out the details. Try to change as little as possible.
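For example, assuming both files live in the repository root:

```bash
cp sd21-env.sh.example env.sh
# Edit env.sh and fill in the details before training.
```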
@@ -89,4 +99,4 @@ For SDXL, run the `train_sdxl.sh` script, redirecting outputs to the log file:
bash train_sdxl.sh > /path/to/training-$(date +%s).log 2>&1
```

> ⚠️ At this point, the commands will work, but further configuration is required. See [the tutorial](/TUTORIAL.md) for more information.
2 changes: 1 addition & 1 deletion README.md
@@ -65,7 +65,7 @@ EMA (exponential moving average) weights are a memory-heavy affair, but provide
### GPU vendors

* NVIDIA - pretty much anything 3090 and up is a safe bet. YMMV.
* AMD - No one has reported anything, we don't know.
* AMD - SDXL LoRA and UNet are verified working on a 7900 XTX 24GB. Lacking `xformers`, it will likely use more memory than Nvidia equivalents.
* Apple - LoRA and full u-net tuning are tested to work on an M3 Max with 128G memory, taking about **12G** of "Wired" memory and **4G** of system memory for SDXL.
* You likely need a 24G or greater machine for machine learning with M-series hardware due to the lack of memory-efficient attention.

7 changes: 3 additions & 4 deletions TUTORIAL.md
@@ -225,9 +225,8 @@ Here's a breakdown of what each environment variable does:

#### Data Locations

- `BASE_DIR`, `INSTANCE_DIR`, `OUTPUT_DIR`: Directories for the training data, instance data, and output models.
- `BASE_DIR`, `OUTPUT_DIR`: Directories for the training data and output models.
- `BASE_DIR` - Used for populating other variables, mostly.
- `INSTANCE_DIR` - Where your actual training data is. This can be anywhere, it does not need to be underneath `BASE_DIR`.
- `OUTPUT_DIR` - Where the model pipeline results are stored during training, and after it completes.
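A minimal sketch of these variables in `env.sh`, mirroring the values used in the DeepFloyd example later in this commit (the `models/sdxl` path is illustrative):

```bash
export BASE_DIR="/training"                  # used mostly for populating other variables
export OUTPUT_DIR="${BASE_DIR}/models/sdxl"  # where pipeline results are stored
```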

#### Training Parameters
@@ -236,9 +235,9 @@ Here's a breakdown of what each environment variable does:
- If you use `MAX_NUM_STEPS`, it's recommended to set `NUM_EPOCHS` to `0`.
- Similarly, if you use `NUM_EPOCHS`, it is recommended to set `MAX_NUM_STEPS` to `0`.
- This simply signals to the trainer that you explicitly wish to use one or the other.
- If you supply `NUM_EPOCHS` and `MAX_NUM_STEPS` together, the training will stop running at whichever happens first.
- Don't supply `NUM_EPOCHS` and `MAX_NUM_STEPS` values together; the trainer will refuse to start, ensuring there is no ambiguity about which you expect to take priority (see the sketch after this list).
- `LR_SCHEDULE`, `LR_WARMUP_STEPS`: Learning rate schedule and warmup steps.
- `LR_SCHEDULE` - stick to `constant`, as it is most likely to be stable and less chaotic. However, `polynomial` and `constant_with_warmup` have potential of moving the model's local minima before settling in and reducing the loss. Experimentation can pay off here.
- `LR_SCHEDULE` - stick to `constant`, as it is most likely to be stable and less chaotic. However, `polynomial` and `constant_with_warmup` have the potential to move the model out of its local minima before settling in and reducing the loss. Experimentation can pay off here, especially with the `cosine` and `sine` schedulers, which offer a unique approach to learning rate scheduling.
- `TRAIN_BATCH_SIZE`: Batch size for training. You want this **as high as you can get it** without running out of VRAM or making your training unnecessarily **slow** (e.g. a 300-400% increase in training runtime - yikes! 💸)
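A minimal `env.sh` sketch of the steps-versus-epochs choice (the values here are illustrative):

```bash
# Train for a fixed number of steps; zero out epochs so there is no ambiguity.
export MAX_NUM_STEPS=30000
export NUM_EPOCHS=0

# ...or train by epochs instead, zeroing out the step limit:
# export NUM_EPOCHS=25
# export MAX_NUM_STEPS=0
```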

## Additional Notes
1 change: 0 additions & 1 deletion documentation/DEEPFLOYD.md
@@ -108,7 +108,6 @@ export LEARNING_RATE_END=4e-6 #@param {type:"number"}

# Where to store your results.
export BASE_DIR="/training"
export INSTANCE_DIR="${BASE_DIR}/data"
export OUTPUT_DIR="${BASE_DIR}/models/deepfloyd"
export DATALOADER_CONFIG="multidatabackend_deepfloyd.json"

10 changes: 8 additions & 2 deletions helpers/arguments.py
@@ -1332,7 +1332,7 @@ def parse_args(input_args=None):
logger.warning(
"MPS may benefit from the use of --unet_attention_slice for memory savings at the cost of speed."
)
if args.train_batch_size > 12:
if args.train_batch_size > 16:
logger.error(
"An M3 Max 128G will use 12 seconds per step at a batch size of 1 and 65 seconds per step at a batch size of 12."
" Any higher values will result in NDArray size errors or other unstable training results and crashes."
@@ -1346,6 +1346,12 @@
)
sys.exit(1)

if args.max_train_steps is not None and args.num_train_epochs > 0:
logger.error(
"When using --max_train_steps (MAX_NUM_STEPS), you must set --num_train_epochs (NUM_EPOCHS) to 0."
)
sys.exit(1)

if (
args.pretrained_vae_model_name_or_path is not None
and StateTracker.get_model_type() == "legacy"
@@ -1366,7 +1372,7 @@
deepfloyd_pixel_alignment = 8
if args.aspect_bucket_alignment != deepfloyd_pixel_alignment:
logger.warning(
f"Overriding aspect bucket alignment pixel interval to {deepfloyd_pixel_alignment}px instead of{args.aspect_bucket_alignment}px."
f"Overriding aspect bucket alignment pixel interval to {deepfloyd_pixel_alignment}px instead of {args.aspect_bucket_alignment}px."
)
args.aspect_bucket_alignment = deepfloyd_pixel_alignment

6 changes: 5 additions & 1 deletion helpers/caching/vae.py
Expand Up @@ -595,7 +595,11 @@ def _process_images_in_batch(
# retrieve image data from Generator, image_data:
filepath = image_paths.pop()
image = image_data.pop()
aspect_bucket = MultiaspectImage.calculate_image_aspect_ratio(image)
aspect_bucket = (
self.metadata_backend.get_metadata_attribute_by_filepath(
filepath=filepath, attribute="aspect_bucket"
)
)
else:
filepath, image, aspect_bucket = self.process_queue.get()
if self.minimum_image_size is not None:
96 changes: 96 additions & 0 deletions helpers/image_manipulation/cropping.py
@@ -0,0 +1,96 @@
from PIL import Image


class BaseCropping:
    def __init__(self, image: Image.Image = None, image_metadata: dict = None):
self.image = image
self.image_metadata = image_metadata
if self.image:
self.original_width, self.original_height = self.image.size
elif self.image_metadata:
self.original_width, self.original_height = self.image_metadata[
"original_size"
]

def crop(self, target_width, target_height):
raise NotImplementedError("Subclasses must implement this method")


class CornerCropping(BaseCropping):
def crop(self, target_width, target_height):
left = max(0, self.original_width - target_width)
top = max(0, self.original_height - target_height)
right = self.original_width
bottom = self.original_height
if self.image:
return self.image.crop((left, top, right, bottom)), (left, top)
elif self.image_metadata:
return self.image_metadata, (left, top)


class CenterCropping(BaseCropping):
def crop(self, target_width, target_height):
left = (self.original_width - target_width) / 2
top = (self.original_height - target_height) / 2
right = (self.original_width + target_width) / 2
bottom = (self.original_height + target_height) / 2
if self.image:
return self.image.crop((left, top, right, bottom)), (left, top)
elif self.image_metadata:
return self.image_metadata, (left, top)


class RandomCropping(BaseCropping):
def crop(self, target_width, target_height):
import random

left = random.randint(0, max(0, self.original_width - target_width))
top = random.randint(0, max(0, self.original_height - target_height))
right = left + target_width
bottom = top + target_height
if self.image:
return self.image.crop((left, top, right, bottom)), (left, top)
elif self.image_metadata:
return self.image_metadata, (left, top)


class FaceCropping(RandomCropping):
def crop(
self,
image: Image.Image,
target_width: int,
target_height: int,
):
# Import modules
import cv2
import numpy as np

# Detect a face in the image
face_cascade = cv2.CascadeClassifier(
cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
)
image = image.convert("RGB")
image = np.array(image)
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)  # the array is RGB after convert("RGB")
faces = face_cascade.detectMultiScale(gray, 1.1, 4)
if len(faces) > 0:
# Get the largest face
face = max(faces, key=lambda f: f[2] * f[3])
x, y, w, h = face
left = max(0, x - 0.5 * w)
top = max(0, y - 0.5 * h)
right = min(image.shape[1], x + 1.5 * w)
bottom = min(image.shape[0], y + 1.5 * h)
image = Image.fromarray(image)
return image.crop((left, top, right, bottom)), (left, top)
else:
# Crop the image from a random position
            return super().crop(target_width, target_height)  # RandomCropping falls back to self.image


crop_handlers = {
"corner": CornerCropping,
"centre": CenterCropping,
"center": CenterCropping,
"random": RandomCropping,
}
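A brief usage sketch for these classes (the image path is hypothetical; the handler keys come from the `crop_handlers` map above):

```python
from PIL import Image

# Pick a strategy by name and crop an image down to a 512x512 target.
img = Image.open("example.png")  # hypothetical input file
cropper = crop_handlers["center"](image=img)
cropped, (left, top) = cropper.crop(512, 512)
print(cropped.size, (left, top))  # the cropped image's size and its origin in the source
```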