Merge pull request #1029 from bghira/main
merge
bghira authored Oct 5, 2024
2 parents 565aedd + 1d71033 commit 01de5d0
Showing 18 changed files with 371 additions and 70 deletions.
7 changes: 4 additions & 3 deletions README.md
@@ -48,7 +48,7 @@ For memory-constrained systems, see the [DeepSpeed document](/documentation/DEEP
- Most models are trainable on a 24G GPU, or even down to 16G at lower base resolutions.
- LoRA/LyCORIS training for PixArt, SDXL, SD3, and SD 2.x that uses less than 16G VRAM
- DeepSpeed integration allowing for [training SDXL's full u-net on 12G of VRAM](/documentation/DEEPSPEED.md), albeit very slowly.
- Quantised LoRA training, using low-precision base model or text encoder weights to reduce VRAM consumption while still allowing DreamBooth.
- Quantised NF4/INT8/FP8 LoRA training, using a low-precision base model to reduce VRAM consumption.
- Optional EMA (Exponential moving average) weight network to counteract model overfitting and improve training stability. **Note:** This does not apply to LoRA.
- Train directly from an S3-compatible storage provider, eliminating the requirement for expensive local storage. (Tested with Cloudflare R2 and Wasabi S3)
- For only SDXL and SD 1.x/2.x, full [ControlNet model training](/documentation/CONTROLNET.md) (not ControlLoRA or ControlLite)
@@ -105,7 +105,7 @@ RunwayML's SD 1.5 and StabilityAI's SD 2.x are both trainable under the `legacy`

### NVIDIA

Pretty much anything 3090 and up is a safe bet. YMMV.
Pretty much anything 3080 and up is a safe bet. YMMV.

### AMD

@@ -124,7 +124,8 @@ LoRA and full-rank tuning are tested to work on an M3 Max with 128G memory, taki
- A100-80G (Full tune with DeepSpeed)
- A100-40G (LoRA, LoKr)
- 3090 24G (LoRA, LoKr)
- 4060 Ti, 3080 16G (int8, LoRA, LoKr)
- 4060 Ti 16G, 4070 Ti 16G, 3080 16G (int8, LoRA, LoKr)
- 4070 Super 12G, 3080 10G, 3060 12GB (nf4, LoRA, LoKr)

Flux prefers being trained with multiple large GPUs but a single 16G card should be able to do it with quantisation of the transformer and text encoders.

19 changes: 18 additions & 1 deletion documentation/quickstart/FLUX.md
@@ -370,6 +370,23 @@ Inferencing the CFG-distilled LoRA is as easy as using a lower guidance_scale ar

## Notes & troubleshooting tips

### Lowest VRAM config

Currently, the lowest VRAM utilisation (9090M) can be attained with:

- OS: Ubuntu Linux 24
- GPU: A single NVIDIA CUDA device (10G, 12G)
- System memory: approximately 50G
- Base model precision: `bnb-nf4`
- Optimiser: Lion 8Bit Paged, `bnb-lion8bit-paged`
- Resolution: 512px
- 1024px requires >= 12G VRAM
- Batch size: 1, zero gradient accumulation steps
- DeepSpeed: disabled / unconfigured
- PyTorch: 2.6 Nightly (Sept 29th build)

Speed was approximately 1.4 iterations per second on a 4090.
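
For reference, the settings above roughly translate into trainer options like the following. This is a sketch only: `--base_model_precision` mirrors `args.base_model_precision` used in `helpers/training/diffusion_model.py`, while the other flag spellings and the exact launch command are assumptions and should be checked against `--help`.

```python
# Hypothetical flag list for a ~9GB Flux LoRA run; verify flag names before use.
low_vram_flags = [
    "--model_family", "flux",
    "--base_model_precision", "nf4-bnb",    # 4-bit (NF4) base weights via bitsandbytes
    "--optimizer", "bnb-lion8bit-paged",    # paged 8-bit Lion optimiser
    "--resolution", "512",                  # 1024px needs >= 12G VRAM
    "--train_batch_size", "1",
    "--gradient_accumulation_steps", "1",   # "zero" accumulation: every step applies immediately
]
print(" ".join(low_vram_flags))
```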

### Classifier-free guidance

#### Problem
@@ -402,7 +419,7 @@ We can partially reintroduce distillation to a de-distilled model by continuing
- It allows you to push higher batch sizes and possibly obtain a better result
- Behaves the same as full-precision training - fp32 won't make your model any better than bf16+int8.
- **int8** has hardware acceleration and `torch.compile()` support on newer NVIDIA hardware (3090 or better)
- **nf4** does not seem to benefit training as much as it benefits inference
- **nf4-bnb** brings VRAM requirements down to 9GB, fitting on a 10G card (with bfloat16 support)
- When loading the LoRA in ComfyUI later, you **must** use the same base model precision as you trained your LoRA on.
- **int4** is weird and really only works on A100 and H100 cards due to a reliance on custom bf16 kernels
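
As an illustration of the int8 route mentioned in the list above, here is a minimal sketch using optimum-quanto (presumably the library behind the `*-quanto` precision levels listed in `helpers/training/__init__.py`). The model id and dtype are placeholder choices, and this is not the trainer's own quantisation path:

```python
import torch
from diffusers import FluxTransformer2DModel
from optimum.quanto import freeze, qint8, quantize

# Load the transformer in bf16, then quantise its weights to int8 in place.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # example model id
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
quantize(transformer, weights=qint8)  # swap Linear weights for int8 tensors + scales
freeze(transformer)                   # drop the original weights to reclaim memory
```

The nf4 path instead relies on bitsandbytes at load time; the relevant change is in `helpers/training/diffusion_model.py` further down this commit.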

24 changes: 24 additions & 0 deletions helpers/configuration/cmd_args.py
@@ -148,6 +148,30 @@ def get_argument_parser():
" which has improved results in short experiments. Thanks to @mhirki for the contribution."
),
)
parser.add_argument(
"--flux_use_beta_schedule",
action="store_true",
help=(
"Whether or not to use a beta schedule with Flux instead of sigmoid. The default values of alpha"
" and beta approximate a sigmoid."
),
)
parser.add_argument(
"--flux_beta_schedule_alpha",
type=float,
default=2.0,
help=(
"The alpha value of the flux beta schedule. Default is 2.0"
),
)
parser.add_argument(
"--flux_beta_schedule_beta",
type=float,
default=2.0,
help=(
"The beta value of the flux beta schedule. Default is 2.0"
),
)
parser.add_argument(
"--flux_schedule_shift",
type=float,
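
The `--flux_use_beta_schedule`, `--flux_beta_schedule_alpha` and `--flux_beta_schedule_beta` flags added above select a Beta-distribution timestep sampler for Flux flow-matching training in place of the sigmoid (logit-normal) default. A rough sketch of what such a sampler could look like, assuming timesteps are drawn in [0, 1]; the trainer's actual sampling code lives elsewhere and may differ:

```python
import torch

def sample_flux_timesteps(batch_size, use_beta_schedule=False, alpha=2.0, beta=2.0, device="cpu"):
    """Draw per-sample timesteps in [0, 1] for flow-matching training (illustrative only)."""
    if use_beta_schedule:
        # Beta(alpha, beta); with alpha = beta = 2.0 this is a symmetric bump peaked at 0.5.
        dist = torch.distributions.Beta(float(alpha), float(beta))
        return dist.sample((batch_size,)).to(device)
    # Default behaviour: sigmoid of a standard normal draw (logit-normal sampling).
    return torch.sigmoid(torch.randn(batch_size, device=device))
```

With the default alpha = beta = 2.0, the Beta density 6t(1 - t) is a symmetric bump centred at 0.5, which is why the help text describes the defaults as approximating the sigmoid schedule.
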
7 changes: 3 additions & 4 deletions helpers/data_backend/aws.py
@@ -45,7 +45,7 @@

class S3DataBackend(BaseDataBackend):
# Storing the list_files output in a local dict.
_list_cache = {}
_list_cache: dict = {}

def __init__(
self,
@@ -301,9 +301,8 @@ def torch_load(self, s3_key):
try:
stored_tensor = self._decompress_torch(stored_tensor)
except Exception as e:
logger.error(
f"Failed to decompress torch file, falling back to passthrough: {e}"
)
pass

if hasattr(stored_tensor, "seek"):
stored_tensor.seek(0)

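
This change (and the matching one in `helpers/data_backend/local.py` below) silently falls back to treating the stored bytes as an uncompressed torch payload when decompression fails. A minimal sketch of that pattern, assuming `_decompress_torch` wraps a standard compressor such as gzip:

```python
import gzip
import io
import torch

def load_maybe_compressed(raw_bytes: bytes):
    """Load a torch object whose stored form may or may not be compressed."""
    try:
        payload = gzip.decompress(raw_bytes)  # compressed cache entry
    except Exception:
        payload = raw_bytes                   # not compressed: pass the bytes through as-is
    buffer = io.BytesIO(payload)
    buffer.seek(0)
    return torch.load(buffer)
```
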
5 changes: 2 additions & 3 deletions helpers/data_backend/local.py
@@ -194,9 +194,8 @@ def torch_load(self, filename):
try:
stored_tensor = self._decompress_torch(stored_tensor)
except Exception as e:
logger.error(
f"Failed to decompress torch file, falling back to passthrough: {e}"
)
pass

if hasattr(stored_tensor, "seek"):
stored_tensor.seek(0)
try:
2 changes: 1 addition & 1 deletion helpers/models/flux/__init__.py
@@ -118,4 +118,4 @@ def prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_id_channels,
)

return latent_image_ids.to(device=device, dtype=dtype)
return latent_image_ids.to(device=device, dtype=dtype)[0]
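
The added `[0]` drops the batch dimension from the returned position ids, in line with the same indexing applied to `latent_image_ids` and `text_ids` in the pipeline changes below, so the transformer receives a single 2-D `(sequence, 3)` id tensor shared across the batch. A shape-only sketch of the change (sizes are placeholders):

```python
import torch

batch_size, tokens = 4, 1024                        # example latent token count
batched_ids = torch.zeros(batch_size, tokens, 3)    # old layout: (batch, sequence, 3)

ids = batched_ids[0]                                # new layout: (sequence, 3)
assert ids.shape == (tokens, 3)
```
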
26 changes: 13 additions & 13 deletions helpers/models/flux/pipeline.py
@@ -25,7 +25,7 @@
)

from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import SD3LoraLoaderMixin
from diffusers.loaders import FluxLoraLoaderMixin
from diffusers.models.autoencoders import AutoencoderKL
from diffusers.models.transformers import FluxTransformer2DModel
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
@@ -147,7 +147,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps


class FluxPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
r"""
The Flux pipeline for text-to-image generation.
@@ -361,7 +361,7 @@ def encode_prompt(

# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
self._lora_scale = lora_scale

# dynamically adjust the LoRA scale
@@ -395,12 +395,12 @@ def encode_prompt(
)

if self.text_encoder is not None:
if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)

if self.text_encoder_2 is not None:
if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder_2, lora_scale)

@@ -794,9 +794,9 @@ def __call__(
self._num_timesteps = len(timesteps)

latents = latents.to(self.transformer.device)
latent_image_ids = latent_image_ids.to(self.transformer.device)
latent_image_ids = latent_image_ids.to(self.transformer.device)[0]
timesteps = timesteps.to(self.transformer.device)
text_ids = text_ids.to(self.transformer.device)
text_ids = text_ids.to(self.transformer.device)[0]

# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -824,16 +824,16 @@

noise_pred = self.transformer(
hidden_states=latents.to(
device=self.transformer.device, dtype=self.transformer.dtype
device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
),
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds.to(
device=self.transformer.device, dtype=self.transformer.dtype
device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
),
encoder_hidden_states=prompt_embeds.to(
device=self.transformer.device, dtype=self.transformer.dtype
device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
),
txt_ids=text_ids,
img_ids=latent_image_ids,
@@ -846,16 +846,16 @@
if guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
noise_pred_uncond = self.transformer(
hidden_states=latents.to(
device=self.transformer.device, dtype=self.transformer.dtype
device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
),
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=negative_pooled_prompt_embeds.to(
device=self.transformer.device, dtype=self.transformer.dtype
device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
),
encoder_hidden_states=negative_prompt_embeds.to(
device=self.transformer.device, dtype=self.transformer.dtype
device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
),
txt_ids=negative_text_ids.to(device=self.transformer.device),
img_ids=latent_image_ids.to(device=self.transformer.device),
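
The commented-out `dtype` arguments above follow the in-line note that activations cannot simply be cast to `self.transformer.dtype` once the weights are NF4-quantised, so only the device move is kept. A hedged sketch of a guard that would preserve the cast for unquantised models (`maybe_cast` is a hypothetical helper, not part of this commit):

```python
import torch

def maybe_cast(tensor: torch.Tensor, module) -> torch.Tensor:
    """Move inputs to the module's device, casting dtype only when it is safe to do so."""
    dtype = getattr(module, "dtype", None)
    device = getattr(module, "device", tensor.device)
    if isinstance(dtype, torch.dtype) and dtype.is_floating_point:
        return tensor.to(device=device, dtype=dtype)
    # Quantised (e.g. NF4) modules: leave the activation dtype alone.
    return tensor.to(device=device)
```

In the pipeline this would replace each paired `device=`/`dtype=` call with something like `maybe_cast(latents, self.transformer)`.
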
1 change: 1 addition & 0 deletions helpers/training/__init__.py
@@ -1,5 +1,6 @@
quantised_precision_levels = [
"no_change",
"nf4-bnb",
# "fp4-bnb",
# "fp8-bnb",
"fp8-quanto",
15 changes: 11 additions & 4 deletions helpers/training/diffusion_model.py
@@ -17,10 +17,21 @@ def load_diffusion_model(args, weight_dtype):
pretrained_load_args = {
"revision": args.revision,
"variant": args.variant,
"torch_dtype": weight_dtype,
}
unet = None
transformer = None

if "nf4-bnb" == args.base_model_precision:
import torch
from diffusers import BitsAndBytesConfig
pretrained_load_args["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=weight_dtype,
)

if args.model_family == "sd3":
# Stable Diffusion 3 uses a Diffusion transformer.
logger.info("Loading Stable Diffusion 3 diffusion transformer..")
@@ -45,7 +56,6 @@ def load_diffusion_model(args, weight_dtype):
args.pretrained_transformer_model_name_or_path
or args.pretrained_model_name_or_path,
subfolder=determine_subfolder(args.pretrained_transformer_subfolder),
torch_dtype=weight_dtype,
**pretrained_load_args,
)
elif args.model_family.lower() == "flux" and args.flux_attention_masked_training:
@@ -56,7 +66,6 @@ def load_diffusion_model(args, weight_dtype):
transformer = FluxTransformer2DModelWithMasking.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="transformer",
torch_dtype=weight_dtype,
**pretrained_load_args,
)
elif args.model_family == "pixart_sigma":
@@ -66,7 +75,6 @@ def load_diffusion_model(args, weight_dtype):
args.pretrained_transformer_model_name_or_path
or args.pretrained_model_name_or_path,
subfolder=determine_subfolder(args.pretrained_transformer_subfolder),
torch_dtype=weight_dtype,
**pretrained_load_args,
)
elif args.model_family == "smoldit":
@@ -100,7 +108,6 @@ def load_diffusion_model(args, weight_dtype):
args.pretrained_unet_model_name_or_path
or args.pretrained_model_name_or_path,
subfolder=determine_subfolder(args.pretrained_unet_subfolder),
torch_dtype=weight_dtype,
**pretrained_load_args,
)

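
The new `nf4-bnb` branch above builds a `BitsAndBytesConfig` and hands it to `from_pretrained`; note that the blanket `torch_dtype` entry is removed, with the compute dtype supplied through `bnb_4bit_compute_dtype` instead. A standalone sketch of the same loading path, with the model id and dtype as placeholder choices:

```python
import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,     # second quantisation pass over the quant constants
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",     # example model id
    subfolder="transformer",
    quantization_config=quant_config,
)
# Weights are now held as packed 4-bit tensors; pair with a LoRA/LoKr for training.
```
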