diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..e537fc16 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,82 @@ +# SimpleTuner needs CU118 +FROM nvidia/cuda:11.8.0-runtime-ubuntu22.04 + +# /workspace is the default volume for Runpod & other hosts +WORKDIR /workspace + +# Update apt-get +RUN apt-get update -y + +# Prevents different commands from being stuck by waiting +# on user input during build +ENV DEBIAN_FRONTEND noninteractive + +# Install openssh & git +RUN apt-get install -y --no-install-recommends openssh-server \ + openssh-client \ + git \ + git-lfs + +# Installl misc unix libraries +RUN apt-get install -y wget \ + curl \ + tmux \ + tldr \ + nvtop \ + vim \ + rsync \ + net-tools \ + less \ + iputils-ping \ + 7zip \ + zip \ + unzip \ + htop \ + inotify-tools + +# Set up git to support LFS, and to store credentials; useful for Huggingface Hub +RUN git config --global credential.helper store && \ + git lfs install + +# Install Python VENV +RUN apt-get install -y python3.10-venv + +# Ensure SSH access. Not needed for Runpod but is required on Vast and other Docker hosts +EXPOSE 22/tcp + +# Install misc Python & CUDA Libraries +RUN apt-get update -y && apt-get install -y python3 python3-pip libcudnn8 libcudnn8-dev +RUN python3 -m pip install pip --upgrade + +# HF +ARG HUGGING_FACE_HUB_TOKEN +ENV HUGGING_FACE_HUB_TOKEN=$HUGGING_FACE_HUB_TOKEN +ENV HF_HOME=/workspace/huggingface + +RUN pip3 install "huggingface_hub[cli]" + +RUN huggingface-cli login --token "$HUGGING_FACE_HUB_TOKEN" --add-to-git-credential + +# WanDB +ARG WANDB_TOKEN +ENV WANDB_TOKEN=$WANDB_TOKEN + +RUN pip3 install wandb + +RUN wandb login "$WANDB_TOKEN" + +# Clone SimpleTuner +RUN git clone https://github.com/bghira/SimpleTuner --branch release +# RUN git clone https://github.com/bghira/SimpleTuner --branch main # Uncomment to use latest (possibly unstable) version + +# Install SimpleTuner +RUN pip3 install poetry +RUN cd SimpleTuner && python3 -m venv .venv && poetry install --no-root +RUN chmod +x SimpleTuner/train_sdxl.sh +RUN chmod +x SimpleTuner/train_sd2x.sh + +# Copy start script with exec permissions +COPY --chmod=755 docker-start.sh /start.sh + +# Dummy entrypoint +ENTRYPOINT [ "/start.sh" ] diff --git a/INSTALL.md b/INSTALL.md index 0dc8814e..067fbe30 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -1,5 +1,7 @@ ## Setup +For users that wish to make use of Docker or another container orchestration platform, see [this document](/documentation/DOCKER.md) first. + 1. Clone the repository and install the dependencies: ```bash diff --git a/OPTIONS.md b/OPTIONS.md index 48f84468..6a67b1fe 100644 --- a/OPTIONS.md +++ b/OPTIONS.md @@ -166,9 +166,13 @@ This guide provides a user-friendly breakdown of the command-line options availa This is a basic overview meant to help you get started. 
For a complete list of options and more detailed explanations, please refer to the full specification: ``` -usage: train_sdxl.py [-h] [--snr_gamma SNR_GAMMA] [--model_type {full,lora}] - [--lora_type {Standard}] [--lora_rank LORA_RANK] - [--lora_alpha LORA_ALPHA] [--lora_dropout LORA_DROPOUT] +usage: train_sdxl.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr] + [--soft_min_snr_sigma_data SOFT_MIN_SNR_SIGMA_DATA] + [--model_type {full,lora,deepfloyd-full,deepfloyd-lora,deepfloyd-stage2,deepfloyd-stage2-lora}] + [--lora_type {Standard}] + [--lora_init_type {default,gaussian,loftq}] + [--lora_rank LORA_RANK] [--lora_alpha LORA_ALPHA] + [--lora_dropout LORA_DROPOUT] --pretrained_model_name_or_path PRETRAINED_MODEL_NAME_OR_PATH [--pretrained_vae_model_name_or_path PRETRAINED_VAE_MODEL_NAME_OR_PATH] @@ -204,10 +208,12 @@ usage: train_sdxl.py [-h] [--snr_gamma SNR_GAMMA] [--model_type {full,lora}] [--seed_for_each_device SEED_FOR_EACH_DEVICE] [--resolution RESOLUTION] [--resolution_type {pixel,area}] + [--aspect_bucket_rounding {1,2,3,4,5,6,7,8,9}] [--minimum_image_size MINIMUM_IMAGE_SIZE] [--maximum_image_size MAXIMUM_IMAGE_SIZE] [--target_downsample_size TARGET_DOWNSAMPLE_SIZE] [--train_text_encoder] + [--tokenizer_max_length TOKENIZER_MAX_LENGTH] [--train_batch_size TRAIN_BATCH_SIZE] [--num_train_epochs NUM_TRAIN_EPOCHS] [--max_train_steps MAX_TRAIN_STEPS] @@ -252,6 +258,8 @@ usage: train_sdxl.py [-h] [--snr_gamma SNR_GAMMA] [--model_type {full,lora}] [--validation_negative_prompt VALIDATION_NEGATIVE_PROMPT] [--num_validation_images NUM_VALIDATION_IMAGES] [--validation_steps VALIDATION_STEPS] + [--num_eval_images NUM_EVAL_IMAGES] + [--eval_dataset_id EVAL_DATASET_ID] [--validation_num_inference_steps VALIDATION_NUM_INFERENCE_STEPS] [--validation_resolution VALIDATION_RESOLUTION] [--validation_noise_scheduler {ddim,ddpm,euler,euler-a,unipc}] @@ -291,7 +299,14 @@ options: SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. More details here: https://arxiv.org/abs/2303.09556. - --model_type {full,lora} + --use_soft_min_snr If set, will use the soft min SNR calculation method. + This method uses the sigma_data parameter. If not + provided, the method will raise an error. + --soft_min_snr_sigma_data SOFT_MIN_SNR_SIGMA_DATA + The standard deviation of the data used in the soft + min weighting method. This is required when using the + soft min SNR calculation method. + --model_type {full,lora,deepfloyd-full,deepfloyd-lora,deepfloyd-stage2,deepfloyd-stage2-lora} The training type to use. 'full' will train the full model, while 'lora' will train the LoRA model. LoRA is a smaller model that can be used for faster training. @@ -300,6 +315,16 @@ options: a different type of LoRA to train here. Currently, only 'Standard' type is supported. This option exists for compatibility with Kohya configuration files. + --lora_init_type {default,gaussian,loftq} + The initialization type for the LoRA model. 'default' + will use Microsoft's initialization method, 'gaussian' + will use a Gaussian scaled distribution, and 'loftq' + will use LoftQ initialization. In short experiments, + 'default' produced accurate results earlier in + training, 'gaussian' had slightly more creative + outputs, and LoftQ produces an entirely different + result with worse quality at first, taking potentially + longer to converge than the other methods. --lora_rank LORA_RANK The dimension of the LoRA update matrices. 
--lora_alpha LORA_ALPHA @@ -518,6 +543,13 @@ options: resized to the resolution by pixel edge. If 'area', the images will be resized so the pixel area is this many megapixels. + --aspect_bucket_rounding {1,2,3,4,5,6,7,8,9} + The number of decimal places to round the aspect ratio + to. This is used to create buckets for aspect ratios. + For higher precision, ensure the image sizes remain + compatible. Higher precision levels result in a + greater number of buckets, which may not be a + desirable outcome. --minimum_image_size MINIMUM_IMAGE_SIZE The minimum resolution for both sides of input images. If --delete_unwanted_images is set, images smaller @@ -545,6 +577,9 @@ options: cropping to 1 megapixel. --train_text_encoder (SD 2.x only) Whether to train the text encoder. If set, the text encoder should be float32 precision. + --tokenizer_max_length TOKENIZER_MAX_LENGTH + The maximum length of the tokenizer. If not set, will + default to the tokenizer's max length. --train_batch_size TRAIN_BATCH_SIZE Batch size (per device) for the training dataloader. --num_train_epochs NUM_TRAIN_EPOCHS @@ -658,7 +693,10 @@ options: --adam_bfloat16 Whether or not to use stochastic bf16 in Adam. Currently the only supported optimizer. --max_grad_norm MAX_GRAD_NORM - Max gradient norm. + Clipping the max gradient norm can help prevent + exploding gradients, but may also harm training by + introducing artifacts or making it hard to train + artifacts away. --push_to_hub Whether or not to push the model to the Hub. --hub_token HUB_TOKEN The token to use to push to the Model Hub. Do not use @@ -719,6 +757,15 @@ options: running the prompt `args.validation_prompt` multiple times: `args.num_validation_images` and logging the images. + --num_eval_images NUM_EVAL_IMAGES + If possible, this many eval images will be selected + from each dataset. This is used when training super- + resolution models such as DeepFloyd Stage II, which + will upscale input images from the training set. + --eval_dataset_id EVAL_DATASET_ID + When provided, only this dataset's images will be used + as the eval set, to keep the training and eval images + split. --validation_num_inference_steps VALIDATION_NUM_INFERENCE_STEPS The default scheduler, DDIM, benefits from more steps. UniPC can do well with just 10-15. For more speed diff --git a/README.md b/README.md index 17ebccb9..d4c91d8b 100644 --- a/README.md +++ b/README.md @@ -47,6 +47,7 @@ For memory-constrained systems, see the [DeepSpeed document](/documentation/DEEP - Optional EMA (Exponential moving average) weight network to counteract model overfitting and improve training stability. **Note:** This does not apply to LoRA. - Support for a variety of image sizes and aspect ratios, enabling widescreen and portrait training on SDXL and SD 2.x. - Train directly from an S3-compatible storage provider, eliminating the requirement for expensive local storage. (Tested with Cloudflare R2 and Wasabi S3) +- DeepFloyd stage I and II full u-net or parameter-efficient fine-tuning via LoRA using 22G VRAM ### Stable Diffusion 2.0/2.1 diff --git a/TUTORIAL.md b/TUTORIAL.md index 015b4671..63fc3e61 100644 --- a/TUTORIAL.md +++ b/TUTORIAL.md @@ -221,6 +221,8 @@ Here's a breakdown of what each environment variable does: - Optionally, a user prompt library or the built-in prompt library may be used to generate more than 84 images on each checkpoint across a large number of concepts. - See `--user_prompt_library` for more information. + For DeepFloyd, a page is maintained with specific options to set. 
Visit [this document](/documentation/DEEPFLOYD.md) for a head start. + #### Data Locations - `BASE_DIR`, `INSTANCE_DIR`, `OUTPUT_DIR`: Directories for the training data, instance data, and output models. diff --git a/docker-start.sh b/docker-start.sh new file mode 100644 index 00000000..328dae53 --- /dev/null +++ b/docker-start.sh @@ -0,0 +1,30 @@ +#!/bin/bash + +# Export useful ENV variables, including all Runpod specific vars, to /etc/rp_environment +# This file can then later be sourced in a login shell +echo "Exporting environment variables..." +printenv | + grep -E '^RUNPOD_|^PATH=|^HF_HOME=|^HUGGING_FACE_HUB_TOKEN=|^_=' | + sed 's/^\(.*\)=\(.*\)$/export \1="\2"/' >>/etc/rp_environment + +# Add it to Bash login script +echo 'source /etc/rp_environment' >>~/.bashrc + +# Vast.ai uses $SSH_PUBLIC_KEY +if [[ $SSH_PUBLIC_KEY ]]; then + PUBLIC_KEY="${SSH_PUBLIC_KEY}" +fi + +# Runpod uses $PUBLIC_KEY +if [[ $PUBLIC_KEY ]]; then + mkdir -p ~/.ssh + chmod 700 ~/.ssh + echo "${PUBLIC_KEY}" >>~/.ssh/authorized_keys + chmod 700 -R ~/.ssh +fi + +# Start SSH server +service ssh start + +# 🫡 +sleep infinity diff --git a/documentation/DEEPFLOYD.md b/documentation/DEEPFLOYD.md new file mode 100644 index 00000000..95dd4521 --- /dev/null +++ b/documentation/DEEPFLOYD.md @@ -0,0 +1,237 @@ +# DeepFloyd IF + +> ⚠️ Support for tuning DeepFloyd-IF is in its initial support stages. Not all SimpleTuner features are currently available for this model. + +> 🤷🏽‍♂️ Training DeepFloyd requires at least 24G VRAM for a LoRA. This guide focuses on the 400M parameter base model, though the 4.3B XL flavour can be trained using the same guidelines. + +## Background + +In spring of 2023, StabilityAI released a cascaded pixel diffusion model called DeepFloyd. +![](https://tripleback.net/public/deepfloyd.png) + +Comparing briefly to Stable Diffusion XL: +- Text encoder + - SDXL uses two CLIP encoders, "OpenCLIP G/14" and "OpenAI CLIP-L/14" + - DeepFloyd uses a single self-supervised transformer model, Google's T5 XXL +- Parameter count + - DeepFloyd comes in multiple flavours of density: 400M, 900M, and 4.3B parameters. Each larger unit is successively more expensive to train. + - SDXL has just one, ~3B parameters. + - DeepFloyd's text encoder has 11B parameters in it alone, making the fattest configuration roughly 15.3B parameters. +- Model count + - DeepFloyd runs in **three** stages: 64px -> 256px -> 1024px + - Each stage fully completes its denoising objective + - SDXL runs in **two** stages, including its refiner, from 1024px -> 1024px + - Each stage only partly completes its denoising objective +- Design + - DeepFloyd's three models increase resolution and fine details + - SDXL's two models manage fine details and composition + +For both models, the first stage defines most of the image's composition (where large items / shadows appear). + +## Model assessment + +Here's what you can expect when using DeepFloyd for training or inference. + +### Aesthetics + +When compared to SDXL or Stable Diffusion 1.x/2.x, DeepFloyd's aesthetics lie somewhere between Stable Diffusion 2.x and SDXL. 
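+For reference, here is how the three-stage cascade described above fits together at inference time. This is a rough sketch based on the Diffusers DeepFloyd examples rather than part of SimpleTuner itself; the `IF-I-M`/`IF-II-M` checkpoints, fp16 settings, and prompt are illustrative and can be swapped for the larger flavours:
+
+```py
+import torch
+from diffusers import DiffusionPipeline
+
+# Stage I (~64px) holds the T5 XXL text encoder; stages II and III only upscale.
+stage_1 = DiffusionPipeline.from_pretrained(
+    "DeepFloyd/IF-I-M-v1.0", variant="fp16", torch_dtype=torch.float16
+)
+stage_2 = DiffusionPipeline.from_pretrained(
+    "DeepFloyd/IF-II-M-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16
+)
+stage_3 = DiffusionPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16
+)
+
+prompt = "an ethnographic photograph of a teddy bear at a picnic"
+# The prompt is encoded once by stage I's T5 encoder and reused by stage II.
+prompt_embeds, negative_embeds = stage_1.encode_prompt(prompt)
+
+image = stage_1(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds,
+                output_type="pt").images                          # ~64px
+image = stage_2(image=image, prompt_embeds=prompt_embeds,
+                negative_prompt_embeds=negative_embeds, output_type="pt").images  # ~256px
+image = stage_3(prompt=prompt, image=image, noise_level=100).images[0]            # ~1024px
+image.save("deepfloyd_cascade.png")
+```
+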
+ + +### Disadvantages + +This is not a popular model, for various reasons: + +- Inference-time compute VRAM requirement is heavier than other models +- Training-time compute VRAM requirements dwarf other models + - A full u-net tune needing more than 48G VRAM + - LoRA at rank-32, batch-4 needs ~24G VRAM + - The text embed cache objects are ENORMOUS (multiple Megabytes each, vs hundreds of Kilobytes for SDXL's dual CLIP embeds) + - The text embed cache objects are SLOW TO CREATE, about 9-10 per second currently on an A6000 non-Ada. +- The default aesthetic is worse than other models (like trying to train vanilla SD 1.5) +- There's **three** models to finetune or load onto your system during inference (four if you count the text encoder) +- The promises from StabilityAI did not meet the reality of what it felt like to use the model (over-hyped) +- The DeepFloyd-IF license is restrictive against commercial use. + - This didn't impact the NovelAI weights, which were in fact leaked illicitly. The commercial license nature seems like a convenient excuse, considering the other, bigger issues. + +### Advantages + +However, DeepFloyd really has its upsides that often go overlooked: + +- At inference time, the T5 text encoder demonstrates a strong understanding of the world +- Can be natively trained on very-long captions +- The first stage is ~64x64 pixel area, and can be trained on multi-aspect resolutions + - The low-resolution nature of the training data means DeepFloyd was _the only model_ capable of training on _ALL_ of LAION-A (few images are under 64x64 in LAION) +- Each stage can be tuned independently, focusing on different objectives + - The first stage can be tuned focusing on compositional qualities, and the later stages are tuned for better upscaled details +- It trains very quickly despite its larger training memory footprint + - Trains quicker in terms of throughput - a high samples per hour rate is observed on stage 1 tuning + - Learns more quickly than a CLIP equivalent model, perhaps to the detriment of people used to training CLIP models + - In other words, you will have to adjust your expectations of learning rates and training schedules +- There is no VAE, the training samples are directly downscaled into their target size and the pixels are consumed by the U-net +- It supports ControlNet LoRAs and many other tricks that work on typical linear CLIP u-nets. + +## Fine-tuning a LoRA + +> ⚠️ Due to the compute requirements of full u-net backpropagation in even DeepFloyd's smallest 400M model, it has not been tested. LoRA will be used for this document, though full u-net tuning should also work. + +Training DeepFloyd makes use of the "legacy" SD 1.x/2.x trainer in SimpleTuner to reduce code duplication by keeping similar models together. + +As such, we'll be making use of the `sd2x-env.sh` configuration file for tuning DeepFloyd: + +### sd2x-env.sh + +```bash +# Possible values: +# - deepfloyd-full +# - deepfloyd-lora +# - deepfloyd-stage2 +# - deepfloyd-stage2-lora +export MODEL_TYPE="deepfloyd-lora" + +# DoRA isn't tested a whole lot yet. It's still new and experimental. +export USE_DORA=false +# Bitfit hasn't been tested for efficacy on DeepFloyd. +# It will probably work, but no idea what the outcome is. +export USE_BITFIT=false + +# Highest learning rate to use. +export LEARNING_RATE=4e-5 #@param {type:"number"} +# For schedules that decay or oscillate, this will be the end LR or the bottom of the valley. 
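+# (Illustrative, an assumption about scheduler behaviour:) with a decaying schedule,
+# the LR sweeps from LEARNING_RATE above down to LEARNING_RATE_END below over training.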
+export LEARNING_RATE_END=4e-6 #@param {type:"number"} + +## Using a Huggingface Hub model for Stage 1 tuning: +#export MODEL_NAME="DeepFloyd/IF-I-M-v1.0" +## Using a Huggingface Hub model for Stage 2 tuning: +#export MODEL_NAME="DeepFloyd/IF-II-M-v1.0" +# Using a local path to a huggingface hub model or saved checkpoint: +#export MODEL_NAME="/notebooks/datasets/models/pipeline" + +# Where to store your results. +export BASE_DIR="/training" +export INSTANCE_DIR="${BASE_DIR}/data" +export OUTPUT_DIR="${BASE_DIR}/models/deepfloyd" +export DATALOADER_CONFIG="multidatabackend_deepfloyd.json" + +# Max number of steps OR epochs can be used. But we default to Epochs. +export MAX_NUM_STEPS=50000 +export NUM_EPOCHS=0 + +# Adjust this for your GPU memory size. +export TRAIN_BATCH_SIZE=1 + +# "pixel" is using pixel edge length on the smaller or square side of the image. +# this is how DeepFloyd was originally trained. +export RESOLUTION_TYPE="pixel" +export RESOLUTION=64 # 1.0 Megapixel training sizes + +# Validation is when the model is used during training to make test outputs. +export VALIDATION_RESOLUTION=96x64 # The resolution of the validation images. Default: 64x64 +export VALIDATION_STEPS=250 # How long between each validation run. Default: 250 +export VALIDATION_NUM_INFERENCE_STEPS=25 # How many inference steps to do. Default: 25 +export VALIDATION_PROMPT="an ethnographic photograph of a teddy bear at a picnic" # What to make for the first/only test image. +export VALIDATION_NEGATIVE_PROMPT="blurry, ugly, cropped, amputated" # What to avoid in the first/only test image. + +# These can be left alone. +export VALIDATION_GUIDANCE=7.5 +export VALIDATION_GUIDANCE_RESCALE=0.0 +export VALIDATION_SEED=42 + +export GRADIENT_ACCUMULATION_STEPS=1 # Accumulate over many steps. Default: 1 +export MIXED_PRECISION="bf16" # SimpleTuner requires bf16. +export PURE_BF16=true # Will not use mixed precision, but rather pure bf16 (bf16 requires pytorch 2.3 on MPS.) +export OPTIMIZER="adamw_bf16" +export USE_XFORMERS=true +``` + +A keen eye will have observed the following: + +- The `MODEL_TYPE` is specified as deepfloyd-compatible +- The `MODEL_NAME` is pointing to Stage I or II +- `RESOLUTION` is now `64` and `RESOLUTION_TYPE` is `pixel` +- `USE_XFORMERS` is set to `true`, but AMD and Apple users won't be able to set this, requiring more VRAM. + - **Note** Apple MPS currently has a bug preventing DeepFloyd tuning from working at all. + +For more thorough validations, the value for `VALIDATION_RESOLUTION` can be set as: + +- `VALIDATION_RESOLUTION=64` will result in a 64x64 square image. +- `VALIDATION_RESOLUTION=96x64` will result in a 3:2 widescreen image. +- `VALIDATION_RESOLUTION=64,96,64x96,96x64` will result in four images being generated for each validation: + - 64x64 + - 96x96 + - 64x96 + - 96x64 + +### multidatabackend_deepfloyd.json + +Now let's move onto configuring the dataloader for DeepFloyd training. This will be nearly identical to configuration of SDXL or legacy model datasets, with a focus on resolution parameters. 
+ +```json +[ + { + "id": "primary-dataset", + "type": "local", + "instance_data_dir": "/training/data/primary-dataset", + "crop": true, + "crop_aspect": "square", + "crop_style": "random", + "resolution": 64, + "resolution_type": "pixel", + "minimum_image_size": 64, + "maximum_image_size": 256, + "target_downsample_size": 128, + "prepend_instance_prompt": false, + "instance_prompt": "Your Subject Trigger Phrase or Word", + "caption_strategy": "instanceprompt", + "repeats": 1 + }, + { + "id": "an example backend for text embeds.", + "dataset_type": "text_embeds", + "default": true, + "disable": false, + "type": "local", + "cache_dir": "/training/cache/deepfloyd/text/dreambooth" + } +] +``` + +Provided above is a basic Dreambooth configuration for DeepFloyd: + +- The values for `resolution` and `resolution_type` are set to `64` and `pixel`, respectively +- The value for `minimum_image_size` is reduced to 64 pixels to ensure we don't accidentally upsample any smaller images +- The value for `maximum_image_size` is set to 256 pixels to ensure that any large images do not become cropped at a ratio of more than 4:1, which may result in catastrophic scene context loss +- The value for `target_downsample_size` is set to 128 pixels so that any images larger than `maximum_image_size` of 256 pixels are first resized to 128 pixels before cropping + +Note: images are downsampled 25% at a time so to avoid extreme leaps in image size causing an incorrect averaging of the scene's details. + +## Running inference + +Currently, DeepFloyd does not have any dedicated inference scripts in the SimpleTuner toolkit. + +Other than the built-in validations process, you may want to reference [this document from Hugging Face](https://huggingface.co/docs/diffusers/v0.23.1/en/training/dreambooth#if) which contains a small example for running inference afterward: + +```py +from diffusers import DiffusionPipeline + +pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-M-v1.0", use_safetensors=True) +pipe.load_lora_weights("") +pipe.scheduler = pipe.scheduler.__class__.from_config(pipe.scheduler.config, variance_type="fixed_small") +``` + +> ⚠️ Note that the first value for `DiffusionPipeline.from_pretrained(...)` is set to `IF-I-M-v1.0`, but you must update this to use the base model path that you trained your LoRA on. + +> ⚠️ Note that not all of the recommendations from Hugging Face apply to SimpleTuner. For example, we can tune DeepFloyd stage I LoRA in just 22G of VRAM vs 28G for Diffusers' example dreambooth scripts thanks to efficient pre-caching and pure-bf16 optimiser states. 8Bit AdamW isn't currently supported by SimpleTuner. + +## Fine-tuning the super-resolution stage II model + +DeepFloyd's stage II model takes inputs around 64x64 (or 96x64) images, and returns the resulting upscaled image using the `VALIDATION_RESOLUTION` setting. + +The eval images are automatically collected from your datasets, such that `--num_eval_images` will specify how many upscale images to select from each dataset. The images are currently selected at random - but they'll remain the same on each session. + +A few more checks are in place to ensure you don't accidentally run with the incorrect sizes set. 
+ +To train stage II, you just need to follow the steps above, using `deepfloyd-stage2-lora` in place of `deepfloyd-lora` for `MODEL_TYPE`: + +```bash +export MODEL_TYPE="deepfloyd-stage2-lora" +``` \ No newline at end of file diff --git a/documentation/DOCKER.md b/documentation/DOCKER.md new file mode 100644 index 00000000..0b39c7c3 --- /dev/null +++ b/documentation/DOCKER.md @@ -0,0 +1,115 @@ +# Docker for SimpleTuner + +This Docker configuration provides a comprehensive environment for running the SimpleTuner application on various platforms including Runpod, Vast.ai, and other Docker-compatible hosts. It is optimized for ease of use and robustness, integrating tools and libraries essential for machine learning projects. + +## Container Features + +- **CUDA-enabled Base Image**: Built from `nvidia/cuda:11.8.0-runtime-ubuntu22.04` to support GPU-accelerated applications. +- **Development Tools**: Includes Git, SSH, and various utilities like `tmux`, `vim`, `htop`. +- **Python and Libraries**: Comes with Python 3.10 and essential libraries like `poetry` for Python package management. +- **Huggingface and WandB Integration**: Pre-configured for seamless integration with Huggingface Hub and WandB, facilitating model sharing and experiment tracking. + +## Getting Started + +### 1. Building the Container + +Clone the repository and navigate to the directory containing the Dockerfile. Build the Docker image using: + +```bash +docker build -t simpletuner . +``` + +### 2. Running the Container + +To run the container with GPU support, execute: + +```bash +docker run --gpus all -it -p 22:22 simpletuner +``` + +This command sets up the container with GPU access and maps the SSH port for external connectivity. + +### 3. Environment Variables + +To facilitate integration with external tools, the container supports environment variables for Huggingface and WandB tokens. Pass these at runtime as follows: + +```bash +docker run --gpus all -e HUGGING_FACE_HUB_TOKEN='your_token' -e WANDB_TOKEN='your_token' -it -p 22:22 simpletuner +``` + +### 4. Data Volumes + +For persistent storage and data sharing between the host and the container, mount a data volume: + +```bash +docker run --gpus all -v /path/on/host:/workspace -it -p 22:22 simpletuner +``` + +### 5. SSH Access + +SSH into the container is configured by default. Ensure you provide your SSH public key through the appropriate environment variable (`SSH_PUBLIC_KEY` for Vast.ai or `PUBLIC_KEY` for Runpod). + +### 6. Using SimpleTuner + +Navigate to the SimpleTuner directory, activate the Python virtual environment, and start using or developing the application: + +```bash +cd SimpleTuner +source .venv/bin/activate +``` + +Run training scripts or other provided utilities directly within this environment. + +## Additional Configuration + +### Custom Scripts and Configurations + +If you want to add custom startup scripts or modify configurations, extend the entry script (`docker-start.sh`) to fit your specific needs. + +If any capabilities cannot be achieved through this setup, please open a new issue. + +--- + +## Troubleshooting + +### CUDA Version Mismatch + +**Symptom**: The application fails to utilize the GPU, or errors related to CUDA libraries appear when attempting to run GPU-accelerated tasks. + +**Cause**: This issue may occur if the CUDA version installed within the Docker container does not match the CUDA driver version available on the host machine. + +**Solution**: +1. 
**Check CUDA Driver Version on Host**: Determine the version of the CUDA driver installed on the host machine by running: + ```bash + nvidia-smi + ``` + This command will display the CUDA version at the top right of the output. + +2. **Match Container CUDA Version**: Ensure that the version of the CUDA toolkit in your Docker image is compatible with the host's CUDA driver. NVIDIA generally allows forward compatibility but check the specific compatibility matrix on the NVIDIA website. + +3. **Rebuild the Image**: If necessary, modify the base image in the Dockerfile to match the host’s CUDA driver. For example, if your host runs CUDA 11.2 and the container is set up for CUDA 11.8, you might need to switch to an appropriate base image: + ```Dockerfile + FROM nvidia/cuda:11.2.0-runtime-ubuntu22.04 + ``` + After modifying the Dockerfile, rebuild the Docker image. + +### SSH Connection Issues + +**Symptom**: Unable to connect to the container via SSH. + +**Cause**: Misconfiguration of SSH keys or the SSH service not starting correctly. + +**Solution**: +1. **Check SSH Configuration**: Ensure that the public SSH key is correctly added to `~/.ssh/authorized_keys` in the container. Also, verify that the SSH service is up and running by entering the container and executing: + ```bash + service ssh status + ``` +2. **Exposed Ports**: Confirm that the SSH port (22) is properly exposed and mapped when starting the container, as shown in the running instructions: + ```bash + docker run --gpus all -it -p 22:22 simpletuner + ``` + +### General Advice + +- **Logs and Output**: Review the container logs and output for any error messages or warnings that can provide more context on the issue. +- **Documentation and Forums**: Consult the Docker and NVIDIA CUDA documentation for more detailed troubleshooting advice. Community forums and issue trackers related to the specific software or dependencies you are using can also be valuable resources. \ No newline at end of file diff --git a/documentation/DREAMBOOTH.md b/documentation/DREAMBOOTH.md index 6cf6d419..baa4e008 100644 --- a/documentation/DREAMBOOTH.md +++ b/documentation/DREAMBOOTH.md @@ -34,8 +34,11 @@ The model contains something called a "prior" which could, in theory, be preserv Following the [tutorial](/TUTORIAL.md) is required before you can continue into Dreambooth-specific configuration. -Recommended configuration values for `sdxl-env.sh` or `sd2x-env.sh`: +For DeepFloyd tuning, it's recommended to visit [this page](/documentation/DEEPFLOYD.md) for specific tips related to that model's setup. +For Stable Diffusion 1.x/2.x/XL, here are recommended configuration values. + +Located in `sdxl-env.sh` or `sd2x-env.sh`: ```bash TRAIN_BATCH_SIZE=1 @@ -47,7 +50,7 @@ LR_WARMUP_STEPS=100 OPTIMIZER=adamw_bf16 MAX_NUM_STEPS=1000 -NUM_EPOCHS=25 +NUM_EPOCHS=0 VALIDATION_STEPS=100 VALIDATION_PROMPT="a photograph of subjectname" diff --git a/helpers/arguments.py b/helpers/arguments.py index 4b5db480..5ab45bd9 100644 --- a/helpers/arguments.py +++ b/helpers/arguments.py @@ -29,10 +29,34 @@ def parse_args(input_args=None): " More details here: https://arxiv.org/abs/2303.09556." ), ) + parser.add_argument( + "--use_soft_min_snr", + action="store_true", + help=( + "If set, will use the soft min SNR calculation method. This method uses the sigma_data parameter." + " If not provided, the method will raise an error." 
+ ), + ) + parser.add_argument( + "--soft_min_snr_sigma_data", + default=None, + type=float, + help=( + "The standard deviation of the data used in the soft min weighting method." + " This is required when using the soft min SNR calculation method." + ), + ) parser.add_argument( "--model_type", type=str, - choices=["full", "lora"], + choices=[ + "full", + "lora", + "deepfloyd-full", + "deepfloyd-lora", + "deepfloyd-stage2", + "deepfloyd-stage2-lora", + ], default="full", help=( "The training type to use. 'full' will train the full model, while 'lora' will train the LoRA model." @@ -492,6 +516,15 @@ def parse_args(input_args=None): action="store_true", help="(SD 2.x only) Whether to train the text encoder. If set, the text encoder should be float32 precision.", ) + # DeepFloyd + parser.add_argument( + "--tokenizer_max_length", + type=int, + default=None, + required=False, + help="The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.", + ) + # End DeepFloyd-specific settings parser.add_argument( "--train_batch_size", type=int, @@ -868,6 +901,25 @@ def parse_args(input_args=None): " and logging the images." ), ) + parser.add_argument( + "--num_eval_images", + type=int, + default=4, + help=( + "If possible, this many eval images will be selected from each dataset." + " This is used when training super-resolution models such as DeepFloyd Stage II," + " which will upscale input images from the training set." + ), + ) + parser.add_argument( + "--eval_dataset_id", + type=str, + default=None, + help=( + "When provided, only this dataset's images will be used as the eval set, to keep" + " the training and eval images split." + ), + ) parser.add_argument( "--validation_num_inference_steps", type=int, @@ -880,7 +932,7 @@ def parse_args(input_args=None): ) parser.add_argument( "--validation_resolution", - type=float, + type=str, default=256, help="Square resolution images will be output at this resolution (256x256).", ) @@ -1281,42 +1333,44 @@ def parse_args(input_args=None): ) sys.exit(1) - if args.cache_dir_vae is None or args.cache_dir_vae == "": - args.cache_dir_vae = os.path.join(args.output_dir, "cache_vae") - if args.cache_dir_text is None or args.cache_dir_text == "": - args.cache_dir_text = os.path.join(args.output_dir, "cache_text") - for target_dir in [ - Path(args.cache_dir), - Path(args.cache_dir_vae), - Path(args.cache_dir_text), - ]: - os.makedirs(target_dir, exist_ok=True) - if ( args.pretrained_vae_model_name_or_path is not None and StateTracker.get_model_type() == "legacy" and "sdxl" in args.pretrained_vae_model_name_or_path + and "deepfloyd" not in args.model_type ): logger.error( f"The VAE model {args.pretrained_vae_model_name_or_path} is not compatible with SD 2.x. Please use a 2.x VAE to eliminate this error." 
) args.pretrained_vae_model_name_or_path = None - logger.info( - f"VAE Model: {args.pretrained_vae_model_name_or_path or args.pretrained_model_name_or_path}" - ) - logger.info(f"Default VAE Cache location: {args.cache_dir_vae}") - logger.info(f"Text Cache location: {args.cache_dir_text}") + if "deepfloyd" not in args.model_type: + logger.info( + f"VAE Model: {args.pretrained_vae_model_name_or_path or args.pretrained_model_name_or_path}" + ) + logger.info(f"Default VAE Cache location: {args.cache_dir_vae}") + logger.info(f"Text Cache location: {args.cache_dir_text}") - if args.validation_resolution < 128: + if "deepfloyd-stage2" in args.model_type and args.resolution < 256: + logger.warning( + "DeepFloyd Stage II requires a resolution of at least 256. Setting to 256." + ) + args.resolution = 256 + args.resolution_type = "pixel" + + if ( + args.validation_resolution.isdigit() + and int(args.validation_resolution) < 128 + and "deepfloyd" not in args.model_type + ): # Convert from megapixels to pixels: log_msg = f"It seems that --validation_resolution was given in megapixels ({args.validation_resolution}). Converting to pixel measurement:" - if args.validation_resolution == 1: + if int(args.validation_resolution) == 1: args.validation_resolution = 1024 else: - args.validation_resolution = int(args.validation_resolution * 1e3) + args.validation_resolution = int(int(args.validation_resolution) * 1e3) # Make it divisible by 8: - args.validation_resolution = int(args.validation_resolution / 8) * 8 - logger.info(f"{log_msg} {args.validation_resolution}px") + args.validation_resolution = int(int(args.validation_resolution) / 8) * 8 + logger.info(f"{log_msg} {int(args.validation_resolution)}px") if args.timestep_bias_portion < 0.0 or args.timestep_bias_portion > 1.0: raise ValueError("Timestep bias portion must be between 0.0 and 1.0.") diff --git a/helpers/caching/sdxl_embeds.py b/helpers/caching/sdxl_embeds.py index e5620a9d..373acd62 100644 --- a/helpers/caching/sdxl_embeds.py +++ b/helpers/caching/sdxl_embeds.py @@ -262,6 +262,47 @@ def encode_prompt(self, prompt: str, is_validation: bool = False): self.text_encoders[0], self.tokenizers[0], prompt ) + def tokenize_deepfloyd_prompt(self, prompt, tokenizer_max_length=None): + if tokenizer_max_length is not None: + max_length = tokenizer_max_length + else: + max_length = self.tokenizers[0].model_max_length + + text_inputs = self.tokenizers[0]( + prompt, + truncation=True, + padding="max_length", + max_length=max_length, + return_tensors="pt", + ) + + return text_inputs + + def encode_deepfloyd_prompt(self, input_ids, attention_mask): + text_input_ids = input_ids.to(self.text_encoders[0].device) + attention_mask = attention_mask.to(self.text_encoders[0].device) + prompt_embeds = self.text_encoders[0]( + text_input_ids, + attention_mask=attention_mask, + return_dict=False, + ) + prompt_embeds = prompt_embeds[0].to("cpu") + + return prompt_embeds + + def compute_deepfloyd_prompt(self, prompt: str): + logger.debug(f"Computing deepfloyd prompt for: {prompt}") + text_inputs = self.tokenize_deepfloyd_prompt( + prompt, tokenizer_max_length=StateTracker.get_args().tokenizer_max_length + ) + result = self.encode_deepfloyd_prompt( + text_inputs.input_ids, + text_inputs.attention_mask, + ) + del text_inputs + + return result + def compute_embeddings_for_prompts( self, all_prompts, @@ -515,10 +556,15 @@ def compute_embeddings_for_legacy_prompts( while self.write_queue.qsize() > 100: logger.debug(f"Waiting for write thread to catch up.") time.sleep(5) - 
prompt_embeds = self.encode_legacy_prompt( - self.text_encoders[0], self.tokenizers[0], [prompt] - ) - prompt_embeds = prompt_embeds.to(self.accelerator.device) + if "deepfloyd" in StateTracker.get_args().model_type: + # TODO: Batch this + prompt_embeds = self.compute_deepfloyd_prompt(prompt) + else: + prompt_embeds = self.encode_legacy_prompt( + self.text_encoders[0], self.tokenizers[0], [prompt] + ) + if return_concat: + prompt_embeds = prompt_embeds.to(self.accelerator.device) self.save_to_cache(filename, prompt_embeds) prompt_embeds_all.append(prompt_embeds) diff --git a/helpers/data_backend/factory.py b/helpers/data_backend/factory.py index 87285360..22c1d04d 100644 --- a/helpers/data_backend/factory.py +++ b/helpers/data_backend/factory.py @@ -98,6 +98,7 @@ def init_backend_config(backend: dict, args: dict, accelerator) -> dict: maximum_image_size and output["config"]["resolution_type"] == "pixel" and maximum_image_size < 512 + and "deepfloyd" not in args.model_type ): raise ValueError( f"When a data backend is configured to use `'resolution_type':pixel`, `maximum_image_size` must be at least 512 pixels. You may have accidentally entered {maximum_image_size} megapixels, instead of pixels." @@ -115,6 +116,7 @@ def init_backend_config(backend: dict, args: dict, accelerator) -> dict: target_downsample_size and output["config"]["resolution_type"] == "pixel" and target_downsample_size < 512 + and "deepfloyd" not in args.model_type ): raise ValueError( f"When a data backend is configured to use `'resolution_type':pixel`, `target_downsample_size` must be at least 512 pixels. You may have accidentally entered {target_downsample_size} megapixels, instead of pixels." @@ -434,13 +436,14 @@ def configure_multi_databackend( ), cache_file=os.path.join( init_backend["instance_data_root"], - "aspect_ratio_bucket_indices.json", + "aspect_ratio_bucket_indices", ), metadata_file=os.path.join( init_backend["instance_data_root"], - "aspect_ratio_bucket_metadata.json", + "aspect_ratio_bucket_metadata", ), delete_problematic_images=args.delete_problematic_images or False, + cache_file_suffix=backend.get("cache_file_suffix", None), **metadata_backend_args, ) @@ -524,6 +527,31 @@ def configure_multi_databackend( datasets=[init_backend["metadata_backend"]], ) + if "deepfloyd" in args.model_type: + if init_backend["metadata_backend"].resolution_type == "area": + logger.warning( + "Resolution type is 'area', but should be 'pixel' for DeepFloyd. Unexpected results may occur." + ) + if init_backend["metadata_backend"].resolution > 0.25: + logger.warning( + "Resolution is greater than 0.25 megapixels. This may lead to unconstrained memory requirements." + ) + if init_backend["metadata_backend"].resolution_type == "pixel": + if ( + "stage2" not in args.model_type + and init_backend["metadata_backend"].resolution > 64 + ): + logger.warning( + "Resolution is greater than 64 pixels, which will possibly lead to poor quality results." + ) + + if "deepfloyd-stage2" in args.model_type: + # Resolution must be at least 256 for Stage II. + if init_backend["metadata_backend"].resolution < 256: + logger.warning( + "Increasing resolution to 256, as is required for DF Stage II." + ) + init_backend["sampler"] = MultiAspectSampler( id=init_backend["id"], metadata_backend=init_backend["metadata_backend"], @@ -594,40 +622,41 @@ def configure_multi_databackend( f"(id={init_backend['id']}) Completed processing {len(captions)} captions." 
) - logger.info(f"(id={init_backend['id']}) Creating VAE latent cache.") - init_backend["vaecache"] = VAECache( - id=init_backend["id"], - vae=StateTracker.get_vae(), - accelerator=accelerator, - metadata_backend=init_backend["metadata_backend"], - data_backend=init_backend["data_backend"], - instance_data_root=init_backend["instance_data_root"], - delete_problematic_images=backend.get( - "delete_problematic_images", args.delete_problematic_images - ), - resolution=backend.get("resolution", args.resolution), - resolution_type=backend.get("resolution_type", args.resolution_type), - maximum_image_size=backend.get( - "maximum_image_size", args.maximum_image_size - ), - target_downsample_size=backend.get( - "target_downsample_size", args.target_downsample_size - ), - minimum_image_size=backend.get( - "minimum_image_size", args.minimum_image_size - ), - vae_batch_size=args.vae_batch_size, - write_batch_size=args.write_batch_size, - cache_dir=backend.get("cache_dir_vae", args.cache_dir_vae), - max_workers=backend.get("max_workers", 32), - vae_cache_preprocess=args.vae_cache_preprocess, - ) + if "deepfloyd" not in StateTracker.get_args().model_type: + logger.info(f"(id={init_backend['id']}) Creating VAE latent cache.") + init_backend["vaecache"] = VAECache( + id=init_backend["id"], + vae=StateTracker.get_vae(), + accelerator=accelerator, + metadata_backend=init_backend["metadata_backend"], + data_backend=init_backend["data_backend"], + instance_data_root=init_backend["instance_data_root"], + delete_problematic_images=backend.get( + "delete_problematic_images", args.delete_problematic_images + ), + resolution=backend.get("resolution", args.resolution), + resolution_type=backend.get("resolution_type", args.resolution_type), + maximum_image_size=backend.get( + "maximum_image_size", args.maximum_image_size + ), + target_downsample_size=backend.get( + "target_downsample_size", args.target_downsample_size + ), + minimum_image_size=backend.get( + "minimum_image_size", args.minimum_image_size + ), + vae_batch_size=args.vae_batch_size, + write_batch_size=args.write_batch_size, + cache_dir=backend.get("cache_dir_vae", args.cache_dir_vae), + max_workers=backend.get("max_workers", 32), + vae_cache_preprocess=args.vae_cache_preprocess, + ) - if args.vae_cache_preprocess: - logger.info(f"(id={init_backend['id']}) Discovering cache objects..") - if accelerator.is_local_main_process: - init_backend["vaecache"].discover_all_files() - accelerator.wait_for_everyone() + if args.vae_cache_preprocess: + logger.info(f"(id={init_backend['id']}) Discovering cache objects..") + if accelerator.is_local_main_process: + init_backend["vaecache"].discover_all_files() + accelerator.wait_for_everyone() if ( ( @@ -636,6 +665,7 @@ def configure_multi_databackend( ) and accelerator.is_main_process and backend.get("scan_for_errors", False) + and "deepfloyd" not in StateTracker.get_args().model_type ): logger.info( f"Beginning error scan for dataset {init_backend['id']}. Set 'scan_for_errors' to False in the dataset config to disable this." 
diff --git a/helpers/legacy/sd_files.py b/helpers/legacy/sd_files.py index 0b6a12f3..3ada352a 100644 --- a/helpers/legacy/sd_files.py +++ b/helpers/legacy/sd_files.py @@ -37,6 +37,10 @@ def import_model_class_from_model_name_or_path( ) return RobertaSeriesModelWithTransformation + elif model_class == "T5EncoderModel": + from transformers import T5EncoderModel + + return T5EncoderModel else: raise ValueError(f"{model_class} is not supported.") @@ -55,7 +59,7 @@ def save_model_hook(models, weights, output_dir): StateTracker.save_training_state( os.path.join(output_dir, "training_state.json") ) - if StateTracker.get_args().model_type == "lora": + if "lora" in StateTracker.get_args().model_type: # there are only two options here. Either are just the unet attn processor layers # or there are the unet and text encoder atten layers unet_lora_layers_to_save = None @@ -141,7 +145,7 @@ def load_model_hook(models, input_dir): else: raise ValueError(f"unexpected save model: {model.__class__}") - if args.model_type == "lora": + if "lora" in args.model_type: logger.info(f"Loading LoRA weights from Path: {input_dir}") lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir) diff --git a/helpers/legacy/validation.py b/helpers/legacy/validation.py index 8c6141c0..b1f31263 100644 --- a/helpers/legacy/validation.py +++ b/helpers/legacy/validation.py @@ -2,6 +2,7 @@ from tqdm import tqdm from diffusers.utils import is_wandb_available from diffusers.utils.torch_utils import is_compiled_module +from helpers.multiaspect.image import MultiaspectImage from helpers.image_manipulation.brightness import calculate_luminance from helpers.training.state_tracker import StateTracker from helpers.training.wrappers import unwrap_model @@ -23,6 +24,42 @@ logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL") or "INFO") +def deepfloyd_validation_images(): + """ + From each data backend, collect the top 5 images for validation, such that + we select the same images on each startup, unless the dataset changes. + + Returns: + dict: A dictionary of shortname to image paths. + """ + data_backends = StateTracker.get_data_backends() + validation_data_backend_id = StateTracker.get_args().eval_dataset_id + validation_set = [] + logger.info("Collecting DF-II validation images") + for _data_backend in data_backends: + data_backend = StateTracker.get_data_backend(_data_backend) + if "id" not in data_backend: + continue + logger.info(f"Checking data backend: {data_backend['id']}") + if ( + validation_data_backend_id is not None + and data_backend["id"] != validation_data_backend_id + ): + logger.warning(f"Not collecting images from {data_backend['id']}") + continue + if "sampler" in data_backend: + validation_set.extend( + data_backend["sampler"].retrieve_validation_set( + batch_size=StateTracker.get_args().num_eval_images + ) + ) + else: + logger.warning( + f"Data backend {data_backend['id']} does not have a sampler. Skipping." + ) + return validation_set + + def prepare_validation_prompt_list(args, embed_cache): validation_negative_prompt_embeds = None validation_negative_pooled_embeds = None @@ -33,6 +70,22 @@ def prepare_validation_prompt_list(args, embed_cache): f"Embed cache engine did not contain a model_type. Cannot continue." ) model_type = embed_cache.model_type + validation_sample_images = None + if "deepfloyd-stage2" in StateTracker.get_args().model_type: + # Now, we prepare the DeepFloyd upscaler image inputs so that we can calculate their prompts. 
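+        # The eval images are drawn from the training datasets via deepfloyd_validation_images()
+        # below, at random but deterministically across restarts, and their prompt embeds are
+        # pre-computed so they are already cached when validation runs.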
+ # If we don't do it here, they won't be available at inference time. + validation_sample_images = deepfloyd_validation_images() + if len(validation_sample_images) > 0: + StateTracker.set_validation_sample_images(validation_sample_images) + # Collect the prompts for the validation images. + for _validation_sample in tqdm( + validation_sample_images, + ncols=100, + desc="Precomputing DeepFloyd stage 2 eval prompt embeds", + ): + _, validation_prompt, _ = _validation_sample + embed_cache.compute_embeddings_for_prompts([validation_prompt]) + if args.validation_prompt_library: # Use the SimpleTuner prompts library for validation prompts. from helpers.prompts import prompts as prompt_library @@ -104,6 +157,52 @@ def prepare_validation_prompt_list(args, embed_cache): ) +def parse_validation_resolution(input_str: str) -> tuple: + """ + If the args.validation_resolution: + - is an int, we'll treat it as height and width square aspect + - if it has an x in it, we will split and treat as WIDTHxHEIGHT + - if it has comma, we will split and treat each value as above + """ + if isinstance(input_str, int) or input_str.isdigit(): + if ( + "deepfloyd-stage2" in StateTracker.get_args().model_type + and int(input_str) < 256 + ): + raise ValueError( + "Cannot use less than 256 resolution for DeepFloyd stage 2." + ) + return (input_str, input_str) + if "x" in input_str: + pieces = input_str.split("x") + if "deepfloyd-stage2" in StateTracker.get_args().model_type and ( + int(pieces[0]) < 256 or int(pieces[1]) < 256 + ): + raise ValueError( + "Cannot use less than 256 resolution for DeepFloyd stage 2." + ) + return (int(pieces[0]), int(pieces[1])) + + +def get_validation_resolutions(): + """ + If the args.validation_resolution: + - is an int, we'll treat it as height and width square aspect + - if it has an x in it, we will split and treat as WIDTHxHEIGHT + - if it has comma, we will split and treat each value as above + """ + validation_resolution_parameter = StateTracker.get_args().validation_resolution + if ( + type(validation_resolution_parameter) is str + and "," in validation_resolution_parameter + ): + return [ + parse_validation_resolution(res) + for res in validation_resolution_parameter.split(",") + ] + return [parse_validation_resolution(validation_resolution_parameter)] + + def log_validations( accelerator, prompt_handler, @@ -172,7 +271,7 @@ def log_validations( vae_subfolder_path = "vae" if args.pretrained_vae_model_name_or_path is not None: vae_subfolder_path = None - if vae is None: + if vae is None and "deepfloyd" not in args.model_type: vae = AutoencoderKL.from_pretrained( vae_path, subfolder=vae_subfolder_path, @@ -197,6 +296,10 @@ def log_validations( ) elif StateTracker.get_model_type() == "legacy": pipeline_cls = DiffusionPipeline + if "deepfloyd-stage2" in args.model_type: + from diffusers.pipelines import IFSuperResolutionPipeline + + pipeline_cls = IFSuperResolutionPipeline pipeline = pipeline_cls.from_pretrained( args.pretrained_model_name_or_path, unet=unwrap_model(accelerator, unet), @@ -204,6 +307,7 @@ def log_validations( tokenizer=None, vae=vae, revision=args.revision, + safety_checker=None, torch_dtype=( torch.bfloat16 if torch.backends.mps.is_available() @@ -211,6 +315,18 @@ def log_validations( else torch.bfloat16 ), ) + scheduler_args = {} + + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = 
variance_type + if "deepfloyd" in args.model_type: + args.validation_noise_scheduler = "ddpm" + pipeline.scheduler = SCHEDULER_NAME_MAP[ args.validation_noise_scheduler ].from_pretrained( @@ -219,6 +335,7 @@ def log_validations( prediction_type=args.prediction_type, timestep_spacing=args.inference_scheduler_timestep_spacing, rescale_betas_zero_snr=args.rescale_betas_zero_snr, + **scheduler_args, ) if args.validation_torch_compile and not is_compiled_module(pipeline.unet): logger.warning( @@ -238,19 +355,33 @@ def log_validations( if not os.path.exists(val_save_dir): os.makedirs(val_save_dir) - validation_images = [] + validation_images = {} pipeline = pipeline.to(accelerator.device) extra_validation_kwargs = {} if not args.validation_randomize: extra_validation_kwargs["generator"] = torch.Generator( device=accelerator.device ).manual_seed(args.validation_seed or args.seed or 0) - for validation_prompt in tqdm( + _content = zip( + validation_shortnames, validation_prompts, + [None] * len(validation_prompts), + ) + if "deepfloyd-stage2" in args.model_type: + _content = StateTracker.get_validation_sample_images() + logger.info( + f"Processing {len(_content)} DeepFloyd stage 2 validation images." + ) + + for _validation_prompt in tqdm( + _content, leave=False, ncols=125, desc="Generating validation images", ): + validation_shortname, validation_prompt, validation_sample = ( + _validation_prompt + ) logger.debug(f"Validation image: {validation_prompt}") # Each validation prompt needs its own embed. if StateTracker.get_model_type() == "sdxl": @@ -291,7 +422,10 @@ def log_validations( logger.debug( f"Validations received the prompt embed: positive={current_validation_prompt_embeds.shape}, negative={validation_negative_prompt_embeds.shape}" ) - if prompt_handler is not None: + if ( + prompt_handler is not None + and "deepfloyd" not in args.model_type + ): for text_encoder in prompt_handler.text_encoders: if text_encoder: text_encoder = text_encoder.to(accelerator.device) @@ -316,29 +450,29 @@ def log_validations( ) ) - logger.debug( - f"Generating validation image: {validation_prompt}" - "\n Device allocations:" - f"\n -> unet on {pipeline.unet.device}" - f"\n -> text_encoder on {pipeline.text_encoder.device if pipeline.text_encoder is not None else None}" - f"\n -> vae on {pipeline.vae.device}" - f"\n -> current_validation_prompt_embeds on {current_validation_prompt_embeds.device}" - f"\n -> current_validation_pooled_embeds on {current_validation_pooled_embeds.device}" - f"\n -> validation_negative_prompt_embeds on {validation_negative_prompt_embeds.device}" - f"\n -> validation_negative_pooled_embeds on {validation_negative_pooled_embeds.device}" - ) + # logger.debug( + # f"Generating validation image: {validation_prompt}" + # "\n Device allocations:" + # f"\n -> unet on {pipeline.unet.device}" + # f"\n -> text_encoder on {pipeline.text_encoder.device if pipeline.text_encoder is not None else None}" + # f"\n -> vae on {pipeline.vae.device if hasattr(pipeline, 'vae') else None}" + # f"\n -> current_validation_prompt_embeds on {current_validation_prompt_embeds.device}" + # f"\n -> current_validation_pooled_embeds on {current_validation_pooled_embeds.device if current_validation_pooled_embeds is not None else None}" + # f"\n -> validation_negative_prompt_embeds on {validation_negative_prompt_embeds.device}" + # f"\n -> validation_negative_pooled_embeds on {validation_negative_pooled_embeds.device if validation_negative_pooled_embeds is not None else None}" + # ) - logger.debug( - f"Generating 
validation image: {validation_prompt}" - f"\n Weight dtypes:" - f"\n -> unet: {pipeline.unet.dtype}" - f"\n -> text_encoder: {pipeline.text_encoder.dtype if pipeline.text_encoder is not None else None}" - f"\n -> vae: {pipeline.vae.dtype}" - f"\n -> current_validation_prompt_embeds: {current_validation_prompt_embeds.dtype}" - f"\n -> current_validation_pooled_embeds: {current_validation_pooled_embeds.dtype}" - f"\n -> validation_negative_prompt_embeds: {validation_negative_prompt_embeds.dtype}" - f"\n -> validation_negative_pooled_embeds: {validation_negative_pooled_embeds.dtype}" - ) + # logger.debug( + # f"Generating validation image: {validation_prompt}" + # f"\n Weight dtypes:" + # f"\n -> unet: {pipeline.unet.dtype}" + # f"\n -> text_encoder: {pipeline.text_encoder.dtype if pipeline.text_encoder is not None else None}" + # f"\n -> vae: {pipeline.vae.dtype}" + # f"\n -> current_validation_prompt_embeds: {current_validation_prompt_embeds.dtype}" + # f"\n -> current_validation_pooled_embeds: {current_validation_pooled_embeds.dtype}" + # f"\n -> validation_negative_prompt_embeds: {validation_negative_prompt_embeds.dtype}" + # f"\n -> validation_negative_pooled_embeds: {validation_negative_pooled_embeds.dtype}" + # ) # logger.debug( # f"Generating validation image: {validation_prompt}" # f"\n -> Number of images: {args.num_validation_images}" @@ -348,50 +482,88 @@ def log_validations( # f"\n -> Resolution: {args.validation_resolution}" # f"\n -> Extra validation kwargs: {extra_validation_kwargs}" # ) - validation_images.extend( - pipeline( - prompt_embeds=current_validation_prompt_embeds, - pooled_prompt_embeds=current_validation_pooled_embeds, - negative_prompt_embeds=validation_negative_prompt_embeds, - negative_pooled_prompt_embeds=validation_negative_pooled_embeds, - num_images_per_prompt=args.num_validation_images, - num_inference_steps=args.validation_num_inference_steps, - guidance_scale=args.validation_guidance, - guidance_rescale=args.validation_guidance_rescale, - height=int(args.validation_resolution), - width=int(args.validation_resolution), - **extra_validation_kwargs, - ).images - ) - validation_images[-1].save( - os.path.join( - val_save_dir, - f"step_{global_step}_val_img_{len(validation_images)}.png", + if "deepfloyd" not in args.model_type: + extra_validation_kwargs["pooled_prompt_embeds"] = ( + current_validation_pooled_embeds + ) + extra_validation_kwargs["negative_pooled_prompt_embeds"] = ( + validation_negative_pooled_embeds + ) + extra_validation_kwargs["guidance_rescale"] = ( + args.validation_guidance_rescale, + ) + + if validation_sample is not None: + # Resize the input sample so that we can validate the model's upscaling performance. 
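+                            # calculate_new_size_by_pixel_edge() snaps the smaller edge to 64px
+                            # (stage I's output size) while approximately preserving the aspect
+                            # ratio, so the conditioning image matches the ~64px inputs stage II expects.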
+ width, height, _ = ( + MultiaspectImage.calculate_new_size_by_pixel_edge( + validation_sample.size[0], validation_sample.size[1], 64 + ) + ) + extra_validation_kwargs["image"] = validation_sample.resize( + (width, height) + ) + + validation_resolutions = get_validation_resolutions() + logger.debug(f"Resolutions for validation: {validation_resolutions}") + if validation_shortname not in validation_images: + validation_images[validation_shortname] = [] + for resolution in validation_resolutions: + validation_resolution_width, validation_resolution_height = ( + resolution + ) + validation_images[validation_shortname].extend( + pipeline( + prompt_embeds=current_validation_prompt_embeds, + negative_prompt_embeds=validation_negative_prompt_embeds, + num_images_per_prompt=args.num_validation_images, + num_inference_steps=args.validation_num_inference_steps, + guidance_scale=args.validation_guidance, + height=int(validation_resolution_height), + width=int(validation_resolution_width), + **extra_validation_kwargs, + ).images + ) + validation_img_idx = 0 + for validation_image in validation_images[validation_shortname]: + validation_image.save( + os.path.join( + val_save_dir, + f"step_{global_step}_{validation_shortname}_{str(validation_resolutions[validation_img_idx])}.png", + ) ) - ) logger.debug(f"Completed generating image: {validation_prompt}") for tracker in accelerator.trackers: if tracker.name == "wandb": - validation_document = {} - validation_luminance = [] - for idx, validation_image in enumerate(validation_images): - # Create a WandB entry containing each image. - validation_document[validation_shortnames[idx]] = wandb.Image( - validation_image - ) - # Compute the luminance of each image. - validation_luminance.append( - calculate_luminance(validation_image) - ) - # Compute the mean luminance across all samples: - validation_luminance = torch.tensor(validation_luminance) - validation_document["validation_luminance"] = ( - validation_luminance.mean() - ) - del validation_luminance - tracker.log(validation_document, step=global_step) + resolution_list = [ + f"{res[0]}x{res[1]}" for res in get_validation_resolutions() + ] + columns = [ + "Prompt", + *resolution_list, + "Mean Luminance", + ] + table = wandb.Table(columns=columns) + + # Process each prompt and its associated images + for prompt_shortname, image_list in validation_images.items(): + wandb_images = [] + luminance_values = [] + for image in image_list: + wandb_image = wandb.Image(image) + wandb_images.append(wandb_image) + luminance = calculate_luminance(image) + luminance_values.append(luminance) + mean_luminance = torch.tensor(luminance_values).mean().item() + while len(wandb_images) < len(resolution_list): + # any missing images will crash it. use None so they are indexed. + wandb_images.append(None) + table.add_data(prompt_shortname, *wandb_images, mean_luminance) + + # Log the table to Weights & Biases + tracker.log({"Validation Gallery": table}, step=global_step) if validation_type == "validation" and args.use_ema: # Switch back to the original UNet parameters. 
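+            # (Assumption based on the surrounding code:) the validation images above were
+            # generated with the EMA weights swapped in, so the live training weights are
+            # restored here before training resumes.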
diff --git a/helpers/metadata/backends/base.py b/helpers/metadata/backends/base.py index 3cb233a1..f948f394 100644 --- a/helpers/metadata/backends/base.py +++ b/helpers/metadata/backends/base.py @@ -32,6 +32,7 @@ def __init__( delete_problematic_images: bool = False, metadata_update_interval: int = 3600, minimum_image_size: int = None, + cache_file_suffix: str = None, ): self.id = id if self.id != data_backend.id: @@ -42,8 +43,11 @@ def __init__( self.data_backend = data_backend self.batch_size = batch_size self.instance_data_root = instance_data_root - self.cache_file = Path(cache_file) - self.metadata_file = Path(metadata_file) + if cache_file_suffix is not None: + cache_file = f"{cache_file}_{cache_file_suffix}" + metadata_file = f"{metadata_file}_{cache_file_suffix}" + self.cache_file = Path(f"{cache_file}.json") + self.metadata_file = Path(f"{metadata_file}.json") self.aspect_ratio_bucket_indices = {} self.image_metadata = {} # Store image metadata self.seen_images = {} @@ -717,6 +721,8 @@ def handle_vae_cache_inconsistencies(self, vae_cache, vae_cache_behavior: str): vae_cache: The VAECache object. vae_cache_behavior (str): Behavior for handling inconsistencies ('sync' or 'recreate'). """ + if "deepfloyd" in StateTracker.get_args().model_type: + return if vae_cache_behavior not in ["sync", "recreate"]: raise ValueError("Invalid VAE cache behavior specified.") logger.info(f"Scanning VAE cache for inconsistencies with aspect buckets...") diff --git a/helpers/metadata/backends/json.py b/helpers/metadata/backends/json.py index 65b2a0dd..f24ad149 100644 --- a/helpers/metadata/backends/json.py +++ b/helpers/metadata/backends/json.py @@ -33,6 +33,7 @@ def __init__( delete_problematic_images: bool = False, metadata_update_interval: int = 3600, minimum_image_size: int = None, + cache_file_suffix: str = None, ): super().__init__( id=id, @@ -47,6 +48,7 @@ def __init__( delete_problematic_images=delete_problematic_images, metadata_update_interval=metadata_update_interval, minimum_image_size=minimum_image_size, + cache_file_suffix=cache_file_suffix, ) def __len__(self): diff --git a/helpers/metadata/backends/parquet.py b/helpers/metadata/backends/parquet.py index 85eec56f..07c6275a 100644 --- a/helpers/metadata/backends/parquet.py +++ b/helpers/metadata/backends/parquet.py @@ -33,6 +33,7 @@ def __init__( delete_problematic_images: bool = False, metadata_update_interval: int = 3600, minimum_image_size: int = None, + cache_file_suffix: str = None, ): self.parquet_config = parquet_config self.parquet_path = parquet_config.get("path", None) @@ -49,6 +50,7 @@ def __init__( delete_problematic_images=delete_problematic_images, metadata_update_interval=metadata_update_interval, minimum_image_size=minimum_image_size, + cache_file_suffix=cache_file_suffix, ) self.load_parquet_database() diff --git a/helpers/multiaspect/dataset.py b/helpers/multiaspect/dataset.py index 938470ae..93345fe9 100644 --- a/helpers/multiaspect/dataset.py +++ b/helpers/multiaspect/dataset.py @@ -1,4 +1,5 @@ from torch.utils.data import Dataset +from helpers.training.state_tracker import StateTracker from pathlib import Path import logging, os @@ -32,15 +33,19 @@ def __getitem__(self, image_tuple): first_aspect_ratio = None for sample in image_tuple: image_metadata = sample - if first_aspect_ratio is None: - first_aspect_ratio = image_metadata["aspect_ratio"] - elif first_aspect_ratio != image_metadata["aspect_ratio"]: - raise ValueError( - f"Aspect ratios must be the same for all images in a batch. 
Expected: {first_aspect_ratio}, got: {image_metadata['aspect_ratio']}" - ) + if 'aspect_ratio' in image_metadata: + if first_aspect_ratio is None: + first_aspect_ratio = image_metadata["aspect_ratio"] + elif first_aspect_ratio != image_metadata["aspect_ratio"]: + raise ValueError( + f"Aspect ratios must be the same for all images in a batch. Expected: {first_aspect_ratio}, got: {image_metadata['aspect_ratio']}" + ) if ( - image_metadata["original_size"] is None - or image_metadata["target_size"] is None + "deepfloyd" not in StateTracker.get_args().model_type + and ( + image_metadata["original_size"] is None + or image_metadata["target_size"] is None + ) ): raise Exception( f"Metadata was unavailable for image: {image_metadata['image_path']}. Ensure --skip_file_discovery=metadata is not set." diff --git a/helpers/multiaspect/image.py b/helpers/multiaspect/image.py index c34d80dc..4516dd23 100644 --- a/helpers/multiaspect/image.py +++ b/helpers/multiaspect/image.py @@ -50,7 +50,8 @@ def prepare_image( original_width, original_height = image_size original_resolution = resolution # Convert 'resolution' from eg. "1 megapixel" to "1024 pixels" - original_resolution = original_resolution * 1e3 + if resolution_type == "area": + original_resolution = original_resolution * 1e3 # Make resolution a multiple of 64 original_resolution = MultiaspectImage._round_to_nearest_multiple( original_resolution, 64 @@ -398,15 +399,15 @@ def calculate_new_size_by_pixel_edge(W: int, H: int, resolution: int): if W < H: W = resolution H = MultiaspectImage._round_to_nearest_multiple( - resolution / aspect_ratio, 8 + resolution / aspect_ratio, 64 ) elif H < W: H = resolution W = MultiaspectImage._round_to_nearest_multiple( - resolution * aspect_ratio, 8 + resolution * aspect_ratio, 64 ) else: - W = H = MultiaspectImage._round_to_nearest_multiple(resolution, 8) + W = H = MultiaspectImage._round_to_nearest_multiple(resolution, 64) new_aspect_ratio = MultiaspectImage.calculate_image_aspect_ratio((W, H)) return int(W), int(H), new_aspect_ratio diff --git a/helpers/multiaspect/sampler.py b/helpers/multiaspect/sampler.py index 626eb953..d2b26f18 100644 --- a/helpers/multiaspect/sampler.py +++ b/helpers/multiaspect/sampler.py @@ -125,6 +125,36 @@ def load_buckets(self): self.metadata_backend.aspect_ratio_bucket_indices.keys() ) # These keys are a float value, eg. 1.78. + def retrieve_validation_set(self, batch_size: int): + """ + Return random images from the set. They should be paired with their caption. + + Args: + batch_size (int): Number of images to return. 
+ Returns: + list: a list of tuples(validation_shortname, validation_prompt, validation_sample) + """ + results = ( + [] + ) # [tuple(validation_shortname, validation_prompt, validation_sample)] + for _ in range(batch_size): + image_path = self._yield_random_image() + image_data = self.data_backend.read_image(image_path) + image_metadata = self.metadata_backend.get_metadata_by_filepath(image_path) + validation_shortname = os.path.basename(image_path)[:10] + validation_prompt = PromptHandler.magic_prompt( + sampler_backend_id=self.id, + data_backend=self.data_backend, + image_path=image_path, + caption_strategy=self.caption_strategy, + use_captions=self.use_captions, + prepend_instance_prompt=self.prepend_instance_prompt, + instance_prompt=self.instance_prompt, + ) + results.append((validation_shortname, validation_prompt, image_data)) + + return results + def _yield_random_image(self): bucket = random.choice(self.buckets) image_path = random.choice( @@ -286,10 +316,22 @@ def _validate_and_yield_images_from_samples(self, samples, bucket): to_yield = [] for image_path in samples: image_metadata = self.metadata_backend.get_metadata_by_filepath(image_path) - if "crop_coordinates" not in image_metadata: + if ( + StateTracker.get_args().model_type + not in [ + "legacy", + "deepfloyd-full", + "deepfloyd-lora", + "deepfloyd-stage2", + "deepfloyd-stage2-lora", + ] + and "crop_coordinates" not in image_metadata + ): raise Exception( f"An image was discovered ({image_path}) that did not have its metadata: {self.metadata_backend.get_metadata_by_filepath(image_path)}" ) + if image_metadata is None: + image_metadata = {} image_metadata["data_backend_id"] = self.id image_metadata["image_path"] = image_path @@ -349,9 +391,10 @@ def __iter__(self): self.batch_accumulator ) # Now we'll add only remaining_entries_needed amount to the accumulator: - self.debug_log( - f"Current bucket: {self.current_bucket}. Adding samples with aspect ratios: {[i['aspect_ratio'] for i in to_yield[:remaining_entries_needed]]}" - ) + if "aspect_ratio" in to_yield[0]: + self.debug_log( + f"Current bucket: {self.current_bucket}. Adding samples with aspect ratios: {[i['aspect_ratio'] for i in to_yield[:remaining_entries_needed]]}" + ) self.batch_accumulator.extend(to_yield[:remaining_entries_needed]) # If the batch is full, yield it if len(self.batch_accumulator) >= self.batch_size: diff --git a/helpers/sdxl/save_hooks.py b/helpers/sdxl/save_hooks.py index 13d37232..114a86bb 100644 --- a/helpers/sdxl/save_hooks.py +++ b/helpers/sdxl/save_hooks.py @@ -41,7 +41,7 @@ def save_model_hook(self, models, weights, output_dir): StateTracker.save_training_state( os.path.join(output_dir, "training_state.json") ) - if self.args.model_type == "lora": + if "lora" in self.args.model_type: # there are only two options here. 
Either are just the unet attn processor layers # or there are the unet and text encoder atten layers unet_lora_layers_to_save = None @@ -118,7 +118,7 @@ def load_model_hook(self, models, input_dir): f"Could not find training_state.json in checkpoint dir {input_dir}" ) - if self.args.model_type == "lora": + if "lora" in self.args.model_type: logger.info(f"Loading LoRA weights from Path: {input_dir}") unet_ = None text_encoder_one_ = None diff --git a/helpers/training/collate.py b/helpers/training/collate.py index 6b168cdc..e3126987 100644 --- a/helpers/training/collate.py +++ b/helpers/training/collate.py @@ -2,6 +2,7 @@ from os import environ from helpers.training.state_tracker import StateTracker from helpers.training.multi_process import rank_info +from helpers.multiaspect.image import MultiaspectImage from helpers.image_manipulation.brightness import calculate_batch_luminance from accelerate.logging import get_logger from concurrent.futures import ThreadPoolExecutor @@ -69,6 +70,37 @@ def extract_filepaths(examples): return filepaths +def fetch_pixel_values(fp, data_backend_id: str): + """Worker method to fetch pixel values for a single image.""" + debug_log( + f" -> pull pixels for fp {fp} from cache via data backend {data_backend_id}" + ) + pixels = StateTracker.get_data_backend(data_backend_id)["data_backend"].read_image( + fp + ) + """ + def prepare_image( + resolution: float, + image: Image = None, + image_metadata: dict = None, + resolution_type: str = "pixel", + id: str = "foo", + ): + + """ + backend_config = StateTracker.get_data_backend_config(data_backend_id) + reformed_image, _, _ = MultiaspectImage.prepare_image( + resolution=backend_config["resolution"], + image=pixels, + image_metadata=None, + resolution_type=backend_config["resolution_type"], + id=data_backend_id, + ) + image_transform = MultiaspectImage.get_image_transforms()(reformed_image) + + return image_transform + + def fetch_latent(fp, data_backend_id: str): """Worker method to fetch latent for a single image.""" debug_log( @@ -83,9 +115,32 @@ def fetch_latent(fp, data_backend_id: str): return latent +def deepfloyd_pixels(filepaths, data_backend_id: str): + """DeepFloyd doesn't use the VAE. 
We retrieve, normalise, and stack the pixel tensors directly.""" + # Use a thread pool to fetch latents concurrently + try: + with concurrent.futures.ThreadPoolExecutor() as executor: + pixels = list( + executor.map( + fetch_pixel_values, filepaths, [data_backend_id] * len(filepaths) + ) + ) + except Exception as e: + logger.error(f"(id={data_backend_id}) Error while computing pixels: {e}") + raise + pixels = torch.stack(pixels) + pixels = pixels.to(memory_format=torch.contiguous_format).float() + + return pixels + + def compute_latents(filepaths, data_backend_id: str): # Use a thread pool to fetch latents concurrently try: + if "deepfloyd" in StateTracker.get_args().model_type: + latents = deepfloyd_pixels(filepaths, data_backend_id) + + return latents if not StateTracker.get_args().vae_cache_preprocess: latents = StateTracker.get_vaecache(id=data_backend_id).encode_images( [None] * len(filepaths), filepaths @@ -247,15 +302,21 @@ def collate_fn(batch): example["drop_conditioning"] = False debug_log("Collect luminance values") - batch_luminance = [example["luminance"] for example in examples] + if "luminance" in examples[0]: + batch_luminance = [example["luminance"] for example in examples] + else: + batch_luminance = [0] * len(examples) # average it batch_luminance = sum(batch_luminance) / len(batch_luminance) debug_log("Extract filepaths") filepaths = extract_filepaths(examples) debug_log("Compute latents") latent_batch = compute_latents(filepaths, data_backend_id) - debug_log("Check latents") - latent_batch = check_latent_shapes(latent_batch, filepaths, data_backend_id, batch) + if "deepfloyd" not in StateTracker.get_args().model_type: + debug_log("Check latents") + latent_batch = check_latent_shapes( + latent_batch, filepaths, data_backend_id, batch + ) # Compute embeddings and handle dropped conditionings debug_log("Extract captions") diff --git a/helpers/training/min_snr_gamma.py b/helpers/training/min_snr_gamma.py index f62cc36b..857a5dd3 100644 --- a/helpers/training/min_snr_gamma.py +++ b/helpers/training/min_snr_gamma.py @@ -2,15 +2,24 @@ import logging, torch -def compute_snr(timesteps, noise_scheduler): +def compute_snr(timesteps, noise_scheduler, use_soft_min: bool = False, sigma_data=1.0): """ - Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + Computes SNR using two different methods based on the `use_soft_min` flag. + + Args: + timesteps (torch.Tensor): The timesteps at which SNR is computed. + noise_scheduler (NoiseScheduler): An object that contains the alpha_cumprod values. + use_soft_min (bool): If True, use the _weighting_soft_min_snr method to compute SNR. + sigma_data (torch.Tensor or None): The standard deviation of the data used in the soft min weighting method. + + Returns: + torch.Tensor: The computed SNR values. """ alphas_cumprod = noise_scheduler.alphas_cumprod sqrt_alphas_cumprod = alphas_cumprod**0.5 sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + # Expand the tensors. 
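+    # (one alpha/sigma value is gathered per sampled timestep and trailing singleton
+    # dimensions are added so they broadcast against the batch in the SNR computation below)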
- # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[ timesteps ].float() @@ -25,6 +34,15 @@ def compute_snr(timesteps, noise_scheduler): sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) - # Compute SNR - snr = (alpha / sigma) ** 2 + # Choose the method to compute SNR + if use_soft_min: + if sigma_data is None: + raise ValueError( + "sigma_data must be provided when using soft min SNR calculation." + ) + snr = (sigma * sigma_data) ** 2 / (sigma**2 + sigma_data**2) ** 2 + else: + # Default SNR computation + snr = (alpha / sigma) ** 2 + return snr diff --git a/helpers/training/model_freeze.py b/helpers/training/model_freeze.py index 05039563..6ffbc5eb 100644 --- a/helpers/training/model_freeze.py +++ b/helpers/training/model_freeze.py @@ -24,7 +24,9 @@ def freeze_entire_component(component): def freeze_text_encoder(args, component): - if not args.freeze_encoder: + from transformers import T5EncoderModel + + if not args.freeze_encoder or type(component) is T5EncoderModel: logger.info(f"Not freezing text encoder. Live dangerously and prosper!") return component method = args.freeze_encoder_strategy diff --git a/helpers/training/state_tracker.py b/helpers/training/state_tracker.py index 4af27fbb..9c50e99f 100644 --- a/helpers/training/state_tracker.py +++ b/helpers/training/state_tracker.py @@ -37,6 +37,8 @@ class StateTracker: exhausted_backends = [] # A dict of backend IDs to the number of times they have been repeated. repeats = {} + # The images we'll use for upscaling at validation time. Stored at startup. 
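+    # (presumably populated via set_validation_sample_images(), e.g. with the
+    # (shortname, prompt, image) tuples returned by the sampler's retrieve_validation_set())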
+ validation_sample_images = [] vae = None vae_dtype = None weight_dtype = None @@ -300,6 +302,14 @@ def get_caption_files(cls): cls.all_caption_files = cls._load_from_disk("all_caption_files") return cls.all_caption_files + @classmethod + def get_validation_sample_images(cls): + return cls.validation_sample_images + + @classmethod + def set_validation_sample_images(cls, validation_sample_images): + cls.validation_sample_images = validation_sample_images + @classmethod def register_data_backend(cls, data_backend): cls.data_backends[data_backend["id"]] = data_backend diff --git a/inference.py b/inference.py index d4fd01e3..2b7e5dd2 100644 --- a/inference.py +++ b/inference.py @@ -15,8 +15,8 @@ logger.setLevel(logging.INFO) # Load the pipeline with the same arguments (model, revision) that were used for training model_id = "stabilityai/stable-diffusion-2" -model_id = "ptx0/pseudo-flex-base" -base_dir = "/notebooks/datasets" +model_id = "ptx0/terminus-xl-gamma-v2-1" +base_dir = "/Volumes/models/training" model_path = os.path.join(base_dir, "models") # output_test_dir = os.path.join(base_dir, 'test_results') output_test_dir = os.path.join(base_dir, "encoder_test") @@ -110,7 +110,11 @@ rescale_betas_zero_snr=True, timestep_spacing="trailing", ) - pipeline.to("cuda") + pipeline.to( + "cuda" + if torch.cuda.is_available() + else "mps" if torch.mps.is_available() else "cpu" + ) except Exception as e: logging.info( f"Could not generate pipeline for checkpoint {checkpoint}: {e}" @@ -130,6 +134,7 @@ logging.info(f"Negative: {negative}") conditioning = compel.build_conditioning_tensor(prompt) generator = torch.Generator(device="cuda").manual_seed(torch_seed) + pipeline.do_guidance_rescale_before = 20 output = pipeline( generator=generator, negative_prompt_embeds=negative_embed, diff --git a/sdxl-env.sh.example b/sdxl-env.sh.example index 8f44772c..1649512c 100644 --- a/sdxl-env.sh.example +++ b/sdxl-env.sh.example @@ -18,6 +18,8 @@ if [[ "$MODEL_TYPE" == "full" ]]; then elif [[ "$MODEL_TYPE" == "lora" ]]; then # As of v0.9.2 of SimpleTuner, LoRA can not use BitFit. export USE_BITFIT=false +elif [[ "$MODEL_TYPE" == "deepfloyd-full" ]]; then + export USE_BITFIT=true fi # Restart where we left off. Change this to "checkpoint-1234" to start from a specific checkpoint. @@ -134,7 +136,7 @@ export ALLOW_TF32=true # AdamW 8Bit is a robust and lightweight choice. Adafactor might reduce memory consumption, and Dadaptation is slow and experimental. # AdamW is the default optimizer, but it uses a lot of memory and is slower than AdamW8Bit or Adafactor. # Choices: adamw, adamw8bit, adafactor, dadaptation -export OPTIMIZER="adamw8bit" +export OPTIMIZER="adamw_bf16" # EMA is a strong regularisation method that uses a lot of extra VRAM to hold two copies of the weights. @@ -188,4 +190,4 @@ export ACCELERATE_EXTRA_ARGS="" # --multi_gpu or other # With Pytorch 2.1, you might have pretty good luck here. # If you're using aspect bucketing however, each resolution change will recompile. Seriously, just don't do it. # Well, then again... Pytorch 2.2 has support for dynamic shapes. Why not? -export TRAINING_DYNAMO_BACKEND='inductor' # or 'no' if you want to disable torch compile in case of performance issues or lack of support (eg. AMD) \ No newline at end of file +export TRAINING_DYNAMO_BACKEND='no' # or 'no' if you want to disable torch compile in case of performance issues or lack of support (eg. 
AMD) \ No newline at end of file diff --git a/tests/test_image.py b/tests/test_image.py index c86bb549..01cf094d 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -49,7 +49,7 @@ def test_image_resize(self): # Define target resolutions and expected output sizes tests = [ - (1024, "pixel", (1824, 1024)), + (1024, "pixel", (1792, 1024)), (1.0, "area", (1344, 768)), # Assuming target is 1 megapixel ] diff --git a/toolkit/captioning/caption_with_llava.py b/toolkit/captioning/caption_with_llava.py index 2b641b2d..41f05bbe 100644 --- a/toolkit/captioning/caption_with_llava.py +++ b/toolkit/captioning/caption_with_llava.py @@ -189,6 +189,7 @@ def load_llava_model( bnb_config = None torch_dtype = torch.float16 if "1.6" in model_path: + logger.info("Using LLaVA 1.6+ model.") model = LlavaNextForConditionalGeneration.from_pretrained( model_path, quantization_config=bnb_config, @@ -196,6 +197,7 @@ def load_llava_model( device_map="auto", ) else: + logger.info("Using LLaVA 1.5 model.") model = LlavaForConditionalGeneration.from_pretrained( model_path, quantization_config=bnb_config, @@ -203,8 +205,10 @@ def load_llava_model( device_map="auto", ) if "1.6" in model_path: + logger.info("Using LLaVA 1.6+ model processor.") autoprocessor_cls = LlavaNextProcessor else: + logger.info("Using LLaVA 1.5 model processor.") autoprocessor_cls = AutoProcessor processor = autoprocessor_cls.from_pretrained(model_path) @@ -222,6 +226,7 @@ def eval_model(args, image_file, model, processor): if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" ) + logging.info(f"Inputs: {inputs}") else: prompt = f"\nUSER: {args.query_str}\nASSISTANT:" images = [image_file] diff --git a/train_sd21.py b/train_sd21.py index 86313c56..20f29226 100644 --- a/train_sd21.py +++ b/train_sd21.py @@ -183,7 +183,7 @@ def main(): hasattr(accelerator.state, "deepspeed_plugin") and accelerator.state.deepspeed_plugin is not None ): - if args.model_type == "lora": + if "lora" in args.model_type: logger.error( "LoRA can not be trained with DeepSpeed. Please disable DeepSpeed via 'accelerate config' before reattempting." ) @@ -242,6 +242,9 @@ def main(): # For mixed precision training we cast the text_encoder and vae weights to half-precision # as these models are only used for inference, keeping weights in full precision is not required. weight_dtype = torch.bfloat16 + if torch.backends.mps.is_available() and "deepfloyd" in args.model_type: + weight_dtype = torch.float32 + args.adam_bfloat16 = False StateTracker.set_weight_dtype(weight_dtype) # Load the scheduler, tokenizer and models. 
@@ -288,9 +291,17 @@ def main(): revision=args.revision, ), ) - vae = AutoencoderKL.from_pretrained( - args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision - ) + from transformers import T5EncoderModel + + if "deepfloyd" not in args.model_type: + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + ) + vae.requires_grad_(False) + else: + vae = None unet = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision ).to(weight_dtype) @@ -300,7 +311,6 @@ def main(): logger.info(f"Applying BitFit freezing strategy to the U-net.") unet = apply_bitfit_freezing(unet) - vae.requires_grad_(False) if not args.train_text_encoder: text_encoder.requires_grad_(False) @@ -315,7 +325,7 @@ def main(): ) unet.enable_xformers_memory_efficient_attention() - if args.model_type == "lora": + if "lora" in args.model_type: logger.info("Using LoRA training mode.") # now we will add new LoRA weights to the attention layers # Set correct lora layers @@ -501,13 +511,21 @@ def main(): extra_optimizer_args["lr"] = args.learning_rate # Optimizer creation - if args.model_type == "full": + if ( + args.model_type == "full" + or args.model_type == "deepfloyd-full" + or args.model_type == "deepfloyd-stage2" + ): params_to_optimize = ( itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters() ) - elif args.model_type == "lora": + elif ( + args.model_type == "lora" + or args.model_type == "deepfloyd-lora" + or args.model_type == "deepfloyd-stage2-lora" + ): params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters())) if args.train_text_encoder: params_to_optimize = params_to_optimize + list( @@ -534,34 +552,36 @@ def main(): params_to_optimize, **extra_optimizer_args, ) - # For mixed precision training we cast the text_encoder and vae weights to half-precision - # as these models are only used for inference, keeping weights in full precision is not required. - weight_dtype = torch.bfloat16 + from helpers.legacy.validation import get_validation_resolutions + + # Kick out an early error for DF II trainers that used the wrong resolutions. + get_validation_resolutions() # Move text_encoder to device and cast to weight_dtype) logging.info("Moving text encoder to GPU..") text_encoder.to(accelerator.device, dtype=weight_dtype) - # Move vae, unet and text_encoder to device and cast to weight_dtype - # The VAE is in float32 to avoid NaN losses. - vae_dtype = torch.bfloat16 - if hasattr(args, "vae_dtype"): - logger.info( - f"Initialising VAE in {args.vae_dtype} precision, you may specify a different value if preferred: bf16, fp16, fp32, default" - ) - # Let's use a case-switch for convenience: bf16, fp16, fp32, none/default - if args.vae_dtype == "bf16" or args.mixed_precision == "bf16": - vae_dtype = torch.bfloat16 - elif args.vae_dtype == "fp16" or args.mixed_precision == "fp16": - vae_dtype = torch.float16 - elif args.vae_dtype == "fp32": - vae_dtype = torch.float32 - elif args.vae_dtype == "none" or args.vae_dtype == "default": - vae_dtype = torch.bfloat16 - logger.debug(f"Moving VAE to GPU with {vae_dtype} precision level.") - vae.to(accelerator.device, dtype=vae_dtype) - logger.info(f"Loaded VAE into VRAM.") - StateTracker.set_vae_dtype(vae_dtype) - StateTracker.set_vae(vae) + if vae is not None: + # Move vae, unet and text_encoder to device and cast to weight_dtype + # The VAE is in float32 to avoid NaN losses. 
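+        # (in practice the default below is bf16; args.vae_dtype or args.mixed_precision
+        # can select fp16 / fp32 instead)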
+ vae_dtype = torch.bfloat16 + if hasattr(args, "vae_dtype"): + logger.info( + f"Initialising VAE in {args.vae_dtype} precision, you may specify a different value if preferred: bf16, fp16, fp32, default" + ) + # Let's use a case-switch for convenience: bf16, fp16, fp32, none/default + if args.vae_dtype == "bf16" or args.mixed_precision == "bf16": + vae_dtype = torch.bfloat16 + elif args.vae_dtype == "fp16" or args.mixed_precision == "fp16": + vae_dtype = torch.float16 + elif args.vae_dtype == "fp32": + vae_dtype = torch.float32 + elif args.vae_dtype == "none" or args.vae_dtype == "default": + vae_dtype = torch.bfloat16 + logger.debug(f"Moving VAE to GPU with {vae_dtype} precision level.") + vae.to(accelerator.device, dtype=vae_dtype) + logger.info(f"Loaded VAE into VRAM.") + StateTracker.set_vae_dtype(vae_dtype) + StateTracker.set_vae(vae) # Create a DataBackend, so that we can access our dataset. prompt_handler = None @@ -757,6 +777,9 @@ def main(): # Conditionally prepare the text_encoder if required if args.train_text_encoder: text_encoder = accelerator.prepare(text_encoder) + elif args.fully_unload_text_encoder: + del text_encoder + text_encoder = None # Conditionally prepare the EMA model if required if args.use_ema: @@ -790,7 +813,7 @@ def main(): f" {args.num_train_epochs} epochs and {num_update_steps_per_epoch} steps per epoch." ) - if not args.keep_vae_loaded and args.vae_cache_preprocess: + if vae is not None and not args.keep_vae_loaded and args.vae_cache_preprocess: memory_before_unload = torch.cuda.memory_allocated() / 1024**3 import gc @@ -893,7 +916,7 @@ def main(): } }, ) - + torch.autograd.set_detect_anomaly(True) logger.info("***** Running training *****") total_num_batches = sum( [ @@ -951,7 +974,8 @@ def main(): backend_config = StateTracker.get_data_backend_config(backend_id) logger.debug(f"Backend config: {backend_config}") if ( - "vae_cache_clear_each_epoch" in backend_config + "deepfloyd" not in args.model_type + and "vae_cache_clear_each_epoch" in backend_config and backend_config["vae_cache_clear_each_epoch"] ): # We will clear the cache and then rebuild it. This is useful for random crops. @@ -1009,7 +1033,9 @@ def main(): # Add the current batch of training data's avg luminance to a list. training_luminance_values.append(batch["batch_luminance"]) - with accelerator.accumulate(training_models): + with accelerator.accumulate( + training_models + ), torch.autograd.detect_anomaly(): training_logger.debug( f"Sending latent batch from pinned memory to device" ) @@ -1018,7 +1044,6 @@ def main(): ) # Sample noise that we'll add to the latents - args.noise_offset might need to be set to 0.1 by default. - noise = torch.randn_like(latents) if args.offset_noise: if ( args.noise_offset_probability == 1.0 @@ -1043,8 +1068,8 @@ def main(): noise = noise + args.input_perturbation * torch.randn_like( noise ) + bsz, channels, height, width = latents.shape - bsz = latents.shape[0] logger.debug(f"Working on batch size: {bsz}") # Sample a random timestep for each image, potentially biased by the timestep weights. # Biasing the timestep weights allows us to spend less time training irrelevant timesteps. @@ -1101,17 +1126,34 @@ def main(): f"\n -> Timesteps dtype: {timesteps.dtype}" f"\n -> Encoder hidden states dtype: {encoder_hidden_states.dtype}" ) + if unwrap_model(accelerator, unet).config.in_channels == channels * 2: + # deepfloyd stage ii requires the inputs to be doubled. note that we're working in pixels, not latents. 
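+                        # (i.e. the noisy pixels are concatenated with themselves along the channel
+                        # dimension, matching the doubled in_channels checked above)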
+ noisy_latents = torch.cat([noisy_latents, noisy_latents], dim=1) + + if "deepfloyd-stage2" in args.model_type: + class_labels = timesteps + else: + class_labels = None model_pred = unet( - noisy_latents, timesteps, encoder_hidden_states + noisy_latents, + timesteps, + encoder_hidden_states, + class_labels=class_labels, ).sample + if model_pred.shape[1] == 6: + # Chop the variance off of DeepFloyd models. + model_pred, _ = torch.chunk(model_pred, 2, dim=1) + # x-prediction requires that we now subtract the noise residual from the prediction to get the target sample. if noise_scheduler.config.prediction_type == "sample": model_pred = model_pred - noise if args.snr_gamma is None: - training_logger.debug(f"Calculating loss") + training_logger.debug( + f"Calculating loss for {model_pred.shape} vs {target.shape}" + ) loss = args.snr_weight * F.mse_loss( model_pred.float(), target.float(), reduction="mean" ) @@ -1120,7 +1162,21 @@ def main(): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. training_logger.debug(f"Using min-SNR loss") - snr = compute_snr(timesteps, noise_scheduler) + snr = compute_snr( + timesteps=timesteps, + noise_scheduler=noise_scheduler, + use_soft_min=( + True + if "deepfloyd" in args.model_type + or args.use_soft_min_snr is True + else False + ), + sigma_data=( + 1.0 + if args.soft_min_snr_sigma_data is None + else args.soft_min_snr_sigma_data + ), + ) snr_divisor = snr if noise_scheduler.config.prediction_type == "v_prediction": snr_divisor = snr + 1 @@ -1158,13 +1214,16 @@ def main(): logger.debug(f"Backwards pass.") accelerator.backward(loss) + grad_norm = None if ( accelerator.sync_gradients and not args.use_adafactor_optimizer and args.max_grad_norm > 0 ): # Adafactor shouldn't have gradient clipping applied. 
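+            # clip_grad_norm_ returns the total gradient norm prior to clipping; capture it
+            # so it can be reported alongside the learning rate in the training logs.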
- accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm) + grad_norm = accelerator.clip_grad_norm_( + params_to_optimize, args.max_grad_norm + ) training_logger.debug(f"Stepping components forward.") optimizer.step() lr_scheduler.step(**scheduler_kwargs) @@ -1187,6 +1246,8 @@ def main(): "learning_rate": lr, "epoch": epoch, } + if grad_norm is not None: + logs["grad_norm"] = grad_norm progress_bar.update(1) global_step += 1 current_epoch_step += 1 @@ -1348,7 +1409,7 @@ def main(): unet = accelerator.unwrap_model(unet) if args.model_type == "full" and args.train_text_encoder: text_encoder = accelerator.unwrap_model(text_encoder) - elif args.model_type == "lora": + elif "lora" in args.model_type: unet_lora_layers = convert_state_dict_to_diffusers( get_peft_model_state_dict(unet) ) @@ -1373,7 +1434,7 @@ def main(): if args.use_ema: ema_unet.copy_to(unet.parameters()) - if StateTracker.get_vae() is None: + if StateTracker.get_vae() is None and "deepfloyd" not in args.model_type: StateTracker.set_vae( AutoencoderKL.from_pretrained( args.pretrained_vae_model_name_or_path, @@ -1386,13 +1447,24 @@ def main(): force_upcast=False, ) ) - pipeline = StableDiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, - text_encoder=text_encoder, - vae=StateTracker.get_vae(), - unet=unet, - revision=args.revision, - ) + if "deepfloyd" in args.model_type: + from diffusers import DiffusionPipeline + + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + text_encoder=text_encoder, + unet=unet, + revision=args.revision, + ) + else: + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + text_encoder=text_encoder, + vae=StateTracker.get_vae(), + unet=unet, + revision=args.revision, + ) + pipeline.set_progress_bar_config(disable=True) pipeline.scheduler = SCHEDULER_NAME_MAP[ args.validation_noise_scheduler @@ -1403,12 +1475,12 @@ def main(): timestep_spacing="trailing", rescale_betas_zero_snr=True, ) - if args.model_type == "full": + if "full" in args.model_type: pipeline.save_pretrained( os.path.join(args.output_dir, args.hub_model_id or "pipeline"), safe_serialization=True, ) - elif args.model_type == "lora": + elif "lora" in args.model_type: pipeline.save_lora_weights(args.output_dir) if args.push_to_hub: @@ -1446,7 +1518,8 @@ def main(): else torch.bfloat16 if torch.cuda.is_available() else torch.float32 ), ) - pipeline.components["vae"].to(vae_dtype) + if "vae" in pipeline.components: + pipeline.components["vae"].to(vae_dtype) pipeline.scheduler = SCHEDULER_NAME_MAP[ args.validation_noise_scheduler ].from_pretrained( diff --git a/train_sd2x.sh b/train_sd2x.sh index afccedf5..dab53b40 100644 --- a/train_sd2x.sh +++ b/train_sd2x.sh @@ -203,8 +203,8 @@ if ! [ -f "$DATALOADER_CONFIG" ]; then fi export PURE_BF16_ARGS="" -if ! [ -z "$USE_PURE_BF16" ] && [[ "$USE_PURE_BF16" == "true" ]]; then - PURE_BF16_ARGS="--adamw_bf16" +if ! 
[ -z "$PURE_BF16" ] && [[ "$PURE_BF16" == "true" ]]; then + PURE_BF16_ARGS="--adam_bfloat16" MIXED_PRECISION="bf16" fi @@ -240,7 +240,7 @@ case $OPTIMIZER in export OPTIMIZER_ARG="" ;; "adamw_bf16") - export OPTIMIZER_ARG="--adamw_bf16" + export OPTIMIZER_ARG="--adam_bfloat16" ;; "adamw8bit") export OPTIMIZER_ARG="--use_8bit_adam" @@ -268,12 +268,30 @@ elif [[ "$MODEL_TYPE" == "lora" ]] && [[ "$USE_DORA" != "false" ]]; then DORA_ARGS="--use_dora" fi +export DORA_ARGS="" +if [[ "$MODEL_TYPE" == "deepfloyd-full" ]] && [[ "$USE_DORA" != "false" ]]; then + echo "Cannot use DoRA with a full u-net training task. Disabling DoRA." +elif [[ "$MODEL_TYPE" == "deepfloyd-lora" ]] && [[ "$USE_DORA" != "false" ]]; then + echo "Enabling DoRA." + DORA_ARGS="--use_dora" +fi + export BITFIT_ARGS="" if [[ "$MODEL_TYPE" == "full" ]] && [[ "$USE_BITFIT" != "false" ]]; then echo "Enabling BitFit." BITFIT_ARGS="--freeze_unet_strategy=bitfit" elif [[ "$MODEL_TYPE" == "lora" ]] && [[ "$USE_BITFIT" != "false" ]]; then - echo "Cannot use BitFit with a full u-net training task. Disabling." + echo "Cannot use BitFit with a LoRA training task. Disabling." +elif [[ "$MODEL_TYPE" == "deepfloyd-full" ]] && [[ "$USE_BITFIT" != "false" ]]; then + echo "Enabling BitFit." + BITFIT_ARGS="--freeze_unet_strategy=bitfit" +elif [[ "$MODEL_TYPE" == "deepfloyd-stage2" ]] && [[ "$USE_BITFIT" != "false" ]]; then + echo "Enabling BitFit." + BITFIT_ARGS="--freeze_unet_strategy=bitfit" +elif [[ "$MODEL_TYPE" == "deepfloyd-lora" ]] && [[ "$USE_BITFIT" != "false" ]]; then + echo "Cannot use BitFit with a LoRA training task. Disabling." +elif [[ "$MODEL_TYPE" == "deepfloyd-stage2-lora" ]] && [[ "$USE_BITFIT" != "false" ]]; then + echo "Cannot use BitFit with a LoRA training task. Disabling." fi export ASPECT_BUCKET_ROUNDING_ARGS="" diff --git a/train_sdxl.py b/train_sdxl.py index 9300b7a6..b348e1a1 100644 --- a/train_sdxl.py +++ b/train_sdxl.py @@ -187,7 +187,7 @@ def main(): hasattr(accelerator.state, "deepspeed_plugin") and accelerator.state.deepspeed_plugin is not None ): - if args.model_type == "lora": + if "lora" in args.model_type: logger.error( "LoRA can not be trained with DeepSpeed. Please disable DeepSpeed via 'accelerate config' before reattempting." ) @@ -348,7 +348,7 @@ def main(): "xformers is not available. Make sure it is installed correctly" ) - if args.model_type == "lora": + if "lora" in args.model_type: logger.info("Using LoRA training mode.") # now we will add new LoRA weights to the attention layers # Set correct lora layers @@ -685,7 +685,7 @@ def main(): raise ValueError( "Full model tuning does not currently support text encoder training." 
) - elif args.model_type == "lora": + elif "lora" in args.model_type: params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters())) if args.train_text_encoder: params_to_optimize = ( @@ -859,7 +859,7 @@ def main(): train_dataloaders.append(accelerator.prepare(backend["train_dataloader"])) idx_count = 0 - if args.model_type == "lora" and args.train_text_encoder: + if "lora" in args.model_type and args.train_text_encoder: logger.info("Preparing text encoders for training.") text_encoder_1, text_encoder_2 = accelerator.prepare( text_encoder_1, text_encoder_2 @@ -1082,7 +1082,7 @@ def main(): unet.train() training_models = [unet] - if args.model_type == "lora" and args.train_text_encoder: + if "lora" in args.model_type and args.train_text_encoder: text_encoder_1.train() text_encoder_2.train() training_models.append(text_encoder_1) @@ -1478,7 +1478,7 @@ def main(): accelerator.wait_for_everyone() if accelerator.is_main_process: unet = unwrap_model(accelerator, unet) - if args.model_type == "lora": + if "lora" in args.model_type: unet_lora_layers = convert_state_dict_to_diffusers( get_peft_model_state_dict(unet) ) @@ -1581,7 +1581,7 @@ def main(): validation_type="finish", pipeline=pipeline, ) - elif args.model_type == "lora": + elif "lora" in args.model_type: # load attention processors. They were saved earlier. pipeline.load_lora_weights(args.output_dir) log_validations( diff --git a/train_sdxl.sh b/train_sdxl.sh index 286b7528..15c20d45 100644 --- a/train_sdxl.sh +++ b/train_sdxl.sh @@ -27,8 +27,8 @@ if [ -z "${MIXED_PRECISION}" ]; then fi export PURE_BF16_ARGS="" -if ! [ -z "$USE_PURE_BF16" ] && [[ "$USE_PURE_BF16" == "true" ]]; then - PURE_BF16_ARGS="--adamw_bf16" +if ! [ -z "$PURE_BF16" ] && [[ "$PURE_BF16" == "true" ]]; then + PURE_BF16_ARGS="--adam_bfloat16" MIXED_PRECISION="bf16" fi @@ -254,12 +254,19 @@ fi export DORA_ARGS="" if [[ "$MODEL_TYPE" == "full" ]] && [[ "$USE_DORA" != "false" ]]; then echo "Cannot use DoRA with a full u-net training task. Disabling DoRA." - export USE_DORA="false" elif [[ "$MODEL_TYPE" == "lora" ]] && [[ "$USE_DORA" != "false" ]]; then echo "Enabling DoRA." DORA_ARGS="--use_dora" fi +export BITFIT_ARGS="" +if [[ "$MODEL_TYPE" == "full" ]] && [[ "$USE_BITFIT" != "false" ]]; then + echo "Enabling BitFit." + BITFIT_ARGS="--freeze_unet_strategy=bitfit" +elif [[ "$MODEL_TYPE" == "lora" ]] && [[ "$USE_BITFIT" != "false" ]]; then + echo "Cannot use BitFit with a LoRA training task. Disabling." +fi + export BITFIT_ARGS="" if [[ "$MODEL_TYPE" == "full" ]] && [[ "$USE_DORA" != "false" ]]; then echo "Enabling BitFit."