Commit c9de7c4

WIP: new latents caching

kohya-ss committed Jul 8, 2024
1 parent 50e3d62 commit c9de7c4
Showing 3 changed files with 270 additions and 8 deletions.
94 changes: 93 additions & 1 deletion library/sd3_train_utils.py
@@ -1,7 +1,7 @@
import argparse
import math
import os
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple

import torch
from safetensors.torch import save_file
@@ -283,6 +283,98 @@ def sample_images(*args, **kwargs):
    return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs)


class Sd3LatensCachingStrategy(train_util.LatentsCachingStrategy):

[Review comment] rockerBOO (Contributor), Jul 8, 2024:

    *Sd3LatentsCachingStrategy

[Review comment] kohya-ss (Author, Owner), Jul 9, 2024:

    Thank you! I fixed it :)

    SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz"

    def __init__(self, vae: sd3_models.SDVAE, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
        self.vae = vae

    def get_latents_npz_path(self, absolute_path: str):
        return os.path.splitext(absolute_path)[0] + self.SD3_LATENTS_NPZ_SUFFIX

    def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
        if not self.cache_to_disk:
            return False
        if not os.path.exists(npz_path):
            return False
        if self.skip_disk_cache_validity_check:
            return True

        expected_latents_size = (bucket_reso[1] // 8, bucket_reso[0] // 8)  # bucket_reso is (W, H)

        try:
            npz = np.load(npz_path)
            if npz["latents"].shape[1:3] != expected_latents_size:
                return False

            if flip_aug:
                if "latents_flipped" not in npz:
                    return False
                if npz["latents_flipped"].shape[1:3] != expected_latents_size:
                    return False

            if alpha_mask:
                if "alpha_mask" not in npz:
                    return False
                if npz["alpha_mask"].shape[0:2] != (bucket_reso[1], bucket_reso[0]):
                    return False
            else:
                if "alpha_mask" in npz:
                    return False
        except Exception as e:
            logger.error(f"Error loading file: {npz_path}")
            raise e

        return True

    def cache_batch_latents(self, image_infos: List[train_util.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
        img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching(
            image_infos, alpha_mask, random_crop
        )
        img_tensor = img_tensor.to(device=self.vae.device, dtype=self.vae.dtype)

        with torch.no_grad():
            latents = self.vae.encode(img_tensor).to("cpu")
        if flip_aug:
            img_tensor = torch.flip(img_tensor, dims=[3])
            with torch.no_grad():
                flipped_latents = self.vae.encode(img_tensor).to("cpu")
        else:
            flipped_latents = [None] * len(latents)

        for info, latent, flipped_latent, alpha_mask, original_size, crop_ltrb in zip(
            image_infos, latents, flipped_latents, alpha_masks, original_sizes, crop_ltrbs
        ):
            if self.cache_to_disk:
                # save the npz directly (replaces train_util.save_latents_to_disk)
                kwargs = {}
                if flipped_latent is not None:
                    kwargs["latents_flipped"] = flipped_latent.float().cpu().numpy()
                if alpha_mask is not None:
                    kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy()
                np.savez(
                    info.latents_npz,
                    latents=latent.float().cpu().numpy(),  # the per-image latent, not the whole batch
                    original_size=np.array(original_size),
                    crop_ltrb=np.array(crop_ltrb),
                    **kwargs,
                )
            else:
                info.latents = latent
                if flip_aug:
                    info.latents_flipped = flipped_latent
                info.alpha_mask = alpha_mask

        if not train_util.HIGH_VRAM:
            clean_memory_on_device(self.vae.device)


# region Diffusers


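For orientation, here is a minimal sketch of reading back one of the _sd3.npz files written by cache_batch_latents above. The key set (latents, original_size, crop_ltrb, plus optional latents_flipped and alpha_mask) mirrors the np.savez call; the file name is hypothetical:

    import numpy as np

    # hypothetical cache file written next to an image: img0001.png -> img0001_sd3.npz
    npz = np.load("img0001_sd3.npz")
    print(npz["latents"].shape)          # (C, H/8, W/8); validated against the bucket resolution
    print(npz["original_size"])          # (W, H) of the source image
    print(npz["crop_ltrb"])              # (L, T, R, B) crop applied before encoding
    if "latents_flipped" in npz:         # only present when flip_aug was enabled
        print(npz["latents_flipped"].shape)
    if "alpha_mask" in npz:              # only present when alpha_mask was enabled
        print(npz["alpha_mask"].shape)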
147 changes: 144 additions & 3 deletions library/train_util.py
@@ -359,6 +359,30 @@ def get_augmentor(self, use_color_aug: bool):  # -> Optional[Callable[[np.ndarray], np.ndarray]]
        return self.color_aug if use_color_aug else None


class LatentsCachingStrategy:
    def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
        self._cache_to_disk = cache_to_disk
        self._batch_size = batch_size
        self.skip_disk_cache_validity_check = skip_disk_cache_validity_check

    @property
    def cache_to_disk(self):
        return self._cache_to_disk

    @property
    def batch_size(self):
        return self._batch_size

    def get_latents_npz_path(self, absolute_path: str):
        raise NotImplementedError

    def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
        raise NotImplementedError

    def cache_batch_latents(self, batch: List[ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
        raise NotImplementedError


class BaseSubset:
    def __init__(
        self,
@@ -986,6 +1010,69 @@ def is_text_encoder_output_cacheable(self):
            ]
        )

    def new_cache_latents(self, is_main_process: bool, caching_strategy: LatentsCachingStrategy):
        r"""
        A new method to cache latents. It caches latents using the given caching strategy.
        The regular cache_latents method is used by default; this one is used when a caching strategy is specified.
        """
        logger.info("caching latents with caching strategy.")
        image_infos = list(self.image_data.values())

        # sort by resolution
        image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1])

        # split by resolution
        batches = []
        batch = []
        logger.info("checking cache validity...")
        for info in tqdm(image_infos):
            subset = self.image_to_subset[info.image_key]

            if info.latents_npz is not None:  # fine tuning dataset
                continue

            # check whether the disk cache exists and the size of the latents matches
            if caching_strategy.cache_to_disk:
                info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path)
                if not is_main_process:  # prepare for multi-gpu, only store to info
                    continue

                cache_available = caching_strategy.is_disk_cached_latents_expected(
                    info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask
                )
                if cache_available:  # do not add to batch
                    continue

            # if the last member of the batch has a different resolution, flush the batch
            if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso:
                batches.append(batch)
                batch = []

            batch.append(info)

            # if the batch is full, flush it
            if len(batch) >= caching_strategy.batch_size:
                batches.append(batch)
                batch = []

        if len(batch) > 0:
            batches.append(batch)

        # if caching to disk, don't cache latents in a non-main process; only set the path to info
        if caching_strategy.cache_to_disk and not is_main_process:
            return

        if len(batches) == 0:
            logger.info("no latents to cache")
            return

        # iterate batches: a batch doesn't hold images here; they are loaded in cache_batch_latents and discarded
        logger.info("caching latents...")
        for batch in tqdm(batches, smoothing=1, total=len(batches)):
            # look up the subset from the first image in the batch (was: the stale `subset` left over from the loop above)
            subset = self.image_to_subset[batch[0].image_key]
            caching_strategy.cache_batch_latents(batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)

    def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"):
        # multi-GPU is not supported here; use tools/cache_latents.py for that
        logger.info("caching latents.")
@@ -1086,7 +1173,7 @@ def cache_text_encoder_outputs_common(

        if batch_size is None:
            batch_size = self.batch_size

        image_infos = list(self.image_data.values())

        logger.info("checking cache existence...")
@@ -2207,6 +2294,11 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"):
            logger.info(f"[Dataset {i}]")
            dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, file_suffix)

    def new_cache_latents(self, is_main_process: bool, strategy: LatentsCachingStrategy):
        for i, dataset in enumerate(self.datasets):
            logger.info(f"[Dataset {i}]")
            dataset.new_cache_latents(is_main_process, strategy)

    def cache_text_encoder_outputs(
        self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
    ):
@@ -2550,6 +2642,51 @@ def trim_and_resize_if_required(
    return image, original_size, crop_ltrb


# for new_cache_latents
def load_images_and_masks_for_caching(
    image_infos: List[ImageInfo], use_alpha_mask: bool, random_crop: bool
) -> Tuple[torch.Tensor, List[np.ndarray], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]:
    r"""
    requires image_infos to have: [absolute_path or image], bucket_reso, resized_size
    returns: image_tensor, alpha_masks, original_sizes, crop_ltrbs
        image_tensor: torch.Tensor of shape [B, 3, H, W], normalized to [-1, 1]
        alpha_masks: List[np.ndarray] = [np.ndarray([H, W]), ...], normalized to [0, 1]
        original_sizes: List[Tuple[int, int]] = [(W, H), ...]
        crop_ltrbs: List[Tuple[int, int, int, int]] = [(L, T, R, B), ...]
    """
    images: List[torch.Tensor] = []
    alpha_masks: List[np.ndarray] = []
    original_sizes: List[Tuple[int, int]] = []
    crop_ltrbs: List[Tuple[int, int, int, int]] = []
    for info in image_infos:
        image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
        # TODO: the image metadata may be broken, so the bucket assigned from the metadata may not match
        # the actual image size; a check should be added here
        image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)

        original_sizes.append(original_size)
        crop_ltrbs.append(crop_ltrb)

        if use_alpha_mask:
            if image.shape[2] == 4:
                alpha_mask = image[:, :, 3]  # [H,W]
                alpha_mask = alpha_mask.astype(np.float32) / 255.0
                alpha_mask = torch.FloatTensor(alpha_mask)  # [H,W]
            else:
                # image is a numpy array here, so build the all-ones mask explicitly
                # (torch.ones_like does not accept numpy arrays)
                alpha_mask = torch.ones((image.shape[0], image.shape[1]), dtype=torch.float32)  # [H,W]
        else:
            alpha_mask = None
        alpha_masks.append(alpha_mask)

        image = image[:, :, :3]  # remove the alpha channel if it exists
        image = IMAGE_TRANSFORMS(image)
        images.append(image)

    img_tensor = torch.stack(images, dim=0)
    return img_tensor, alpha_masks, original_sizes, crop_ltrbs


def cache_batch_latents(
    vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, use_alpha_mask: bool, random_crop: bool
) -> None:
@@ -2661,7 +2798,7 @@ def cache_batch_text_encoder_outputs_sd3(
):
    # make input_ids for each text encoder
    l_tokens, g_tokens, t5_tokens = input_ids

    clip_l, clip_g, t5xxl = text_encoders
    with torch.no_grad():
        b_lg_out, b_t5_out, b_pool = sd3_utils.get_cond_from_tokens(

@@ -2670,8 +2807,12 @@ def cache_batch_text_encoder_outputs_sd3(
        b_lg_out = b_lg_out.detach()
        b_t5_out = b_t5_out.detach()
        b_pool = b_pool.detach()

    for info, lg_out, t5_out, pool in zip(image_infos, b_lg_out, b_t5_out, b_pool):
        # debug: NaN check
        if torch.isnan(lg_out).any() or torch.isnan(t5_out).any() or torch.isnan(pool).any():
            raise RuntimeError(f"NaN detected in text encoder outputs: {info.absolute_path}")

        if cache_to_disk:
            save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, lg_out, t5_out, pool)
        else:
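To make the strategy contract concrete, here is a minimal hypothetical subclass of LatentsCachingStrategy. The class name, suffix, and trivial validity check are illustrative only; the three overridden methods are the ones DatasetGroup.new_cache_latents calls:

    import os
    from typing import List, Tuple

    class MyLatentsCachingStrategy(LatentsCachingStrategy):
        SUFFIX = "_my.npz"  # hypothetical cache-file suffix

        def __init__(self, vae, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
            super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
            self.vae = vae

        def get_latents_npz_path(self, absolute_path: str):
            return os.path.splitext(absolute_path)[0] + self.SUFFIX

        def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
            # simplest possible check: trust any existing file (a real strategy should
            # also verify the stored latent shapes, as the SD3 strategy does)
            return self.cache_to_disk and os.path.exists(npz_path)

        def cache_batch_latents(self, batch: List[ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
            # load images via load_images_and_masks_for_caching, encode them with self.vae,
            # then write npz files or attach the latents to each ImageInfo
            ...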
37 changes: 33 additions & 4 deletions sd3_train.py
@@ -204,11 +204,22 @@ def train(args):
        vae.to(accelerator.device, dtype=vae_dtype)
        vae.requires_grad_(False)
        vae.eval()
-       vae_wrapper = sd3_models.VAEWrapper(vae)  # make SD/SDXL compatible
-       with torch.no_grad():
-           train_dataset_group.cache_latents(
-               vae_wrapper, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process, file_suffix="_sd3.npz"
-           )
+
+       if not args.new_caching:
+           vae_wrapper = sd3_models.VAEWrapper(vae)  # make SD/SDXL compatible
+           with torch.no_grad():
+               train_dataset_group.cache_latents(
+                   vae_wrapper,
+                   args.vae_batch_size,
+                   args.cache_latents_to_disk,
+                   accelerator.is_main_process,
+                   file_suffix="_sd3.npz",
+               )
+       else:
+           strategy = sd3_train_utils.Sd3LatensCachingStrategy(
+               vae, args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check
+           )
+           train_dataset_group.new_cache_latents(accelerator.is_main_process, strategy)
        vae.to("cpu")
        clean_memory_on_device(accelerator.device)

@@ -699,6 +710,17 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
                sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=weight_dtype)
                noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents

                # debug: NaN check for all inputs
                if torch.any(torch.isnan(noisy_model_input)):
                    accelerator.print("NaN found in noisy_model_input, replacing with zeros")
                    noisy_model_input = torch.nan_to_num(noisy_model_input, 0, out=noisy_model_input)
                if torch.any(torch.isnan(context)):
                    accelerator.print("NaN found in context, replacing with zeros")
                    context = torch.nan_to_num(context, 0, out=context)
                if torch.any(torch.isnan(pool)):
                    accelerator.print("NaN found in pool, replacing with zeros")
                    pool = torch.nan_to_num(pool, 0, out=pool)

                # call model
                with accelerator.autocast():
                    model_pred = mmdit(noisy_model_input, timesteps, context=context, y=pool)
@@ -908,6 +930,13 @@ def setup_parser() -> argparse.ArgumentParser:
        default=None,
        help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数",
    )

    parser.add_argument("--new_caching", action="store_true", help="use new caching method / 新しいキャッシング方法を使う")
    parser.add_argument(
        "--skip_latents_validity_check",
        action="store_true",
        help="skip latents validity check / latentsの正当性チェックをスキップする",
    )
    return parser


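Usage note: with the two flags added above, a run opting into the new caching path might be launched along the lines of the hypothetical command below. --cache_latents_to_disk is referenced by the diff above; --cache_latents is assumed to be the usual option that enables latent caching at all, and every other argument stays as in a normal run.

    python sd3_train.py --new_caching --skip_latents_validity_check --cache_latents --cache_latents_to_disk [remaining training arguments]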
