Update submodules as well since MPS changes were needed #1606

Draft · wants to merge 1 commit into main
170 changes: 159 additions & 11 deletions README.md

Large diffs are not rendered by default.

20 changes: 15 additions & 5 deletions fine_tune.py
@@ -310,7 +310,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
accelerator.init_trackers(
"finetuning" if args.log_tracker_name is None else args.log_tracker_name,
config=train_util.get_sanitized_config_or_none(args),
init_kwargs=init_kwargs,
)

# For --sample_at_first
train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
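Note on the tracker hunk above: `accelerator.init_trackers` now also receives the run configuration through `config=train_util.get_sanitized_config_or_none(args)`. The helper itself is not part of this diff; the sketch below is a hypothetical reading of what a function with that name plausibly does, inferred only from the call site.

```python
# Hypothetical sketch -- the real helper lives in train_util and is not shown
# in this PR. Trackers such as wandb expect JSON-serializable values, so the
# argparse namespace presumably has to be reduced to plain scalars.
import argparse
from typing import Optional

def get_sanitized_config_or_none(args: argparse.Namespace) -> Optional[dict]:
    config = {}
    for key, value in vars(args).items():
        if isinstance(value, (int, float, str, bool, type(None))):
            config[key] = value  # already tracker-safe
        else:
            config[key] = str(value)  # fall back to a string representation
    return config
```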
@@ -354,7 +358,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)

# Predict the noise residual
with accelerator.autocast():
@@ -368,7 +374,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
loss = loss.mean([1, 2, 3])

if args.min_snr_gamma:
@@ -380,7 +388,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

loss = loss.mean() # mean over batch dimension
else:
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
)

accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
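Both reformatted call sites use `train_util.conditional_loss`, whose body is outside this diff. A hedged sketch of a helper with this signature, assuming the plain-L2 and pseudo-Huber variants that `--loss_type` and the scheduled `huber_c` usually select between:

```python
# Sketch under assumptions; signature inferred from the call sites above.
import torch
import torch.nn.functional as F

def conditional_loss(pred: torch.Tensor, target: torch.Tensor,
                     reduction: str = "mean", loss_type: str = "l2",
                     huber_c: float = 0.1) -> torch.Tensor:
    if loss_type == "l2":
        return F.mse_loss(pred, target, reduction=reduction)
    if loss_type == "huber":
        # pseudo-Huber: quadratic near zero, roughly linear for large residuals
        loss = 2 * huber_c * (torch.sqrt((pred - target) ** 2 + huber_c**2) - huber_c)
        return loss.mean() if reduction == "mean" else loss  # "none" keeps per-element loss
    raise NotImplementedError(f"unknown loss_type: {loss_type}")
```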
@@ -471,7 +481,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

accelerator.end_training()

if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)

del accelerator  # memory is needed after this point, so delete the accelerator to free it
55 changes: 40 additions & 15 deletions gen_img.py
@@ -1435,6 +1435,7 @@ class BatchDataBase(NamedTuple):
clip_prompt: str
guide_image: Any
raw_prompt: str
file_name: Optional[str]


class BatchDataExt(NamedTuple):
@@ -2316,7 +2317,7 @@ def scale_and_round(x):
# extract the information for this batch
(
return_latents,
(step_first, _, _, _, init_image, mask_image, _, guide_image, _),
(step_first, _, _, _, init_image, mask_image, _, guide_image, _, _),
(
width,
height,
@@ -2339,6 +2340,7 @@ def scale_and_round(x):
prompts = []
negative_prompts = []
raw_prompts = []
filenames = []
start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype)
noises = [
torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype)
@@ -2371,14 +2373,15 @@ def scale_and_round(x):
all_guide_images_are_same = True
for i, (
_,
(_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt),
(_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt, filename),
_,
) in enumerate(batch):
prompts.append(prompt)
negative_prompts.append(negative_prompt)
seeds.append(seed)
clip_prompts.append(clip_prompt)
raw_prompts.append(raw_prompt)
filenames.append(filename)

if init_image is not None:
init_images.append(init_image)
@@ -2478,8 +2481,8 @@ def scale_and_round(x):
# save image
highres_prefix = ("0" if highres_1st else "1") if highres_fix else ""
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate(
zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts)
for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt, filename) in enumerate(
zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts, filenames)
):
if highres_fix:
seed -= 1 # record original seed
@@ -2505,17 +2508,23 @@ def scale_and_round(x):
metadata.add_text("crop-top", str(crop_top))
metadata.add_text("crop-left", str(crop_left))

if args.use_original_file_name and init_images is not None:
if type(init_images) is list:
fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png"
else:
fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png"
elif args.sequential_file_name:
fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png"
if filename is not None:
fln = filename
else:
fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png"
if args.use_original_file_name and init_images is not None:
if type(init_images) is list:
fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png"
else:
fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png"
elif args.sequential_file_name:
fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png"
else:
fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png"

image.save(os.path.join(args.outdir, fln), pnginfo=metadata)
if fln.endswith(".webp"):
image.save(os.path.join(args.outdir, fln), pnginfo=metadata, quality=100) # lossy
else:
image.save(os.path.join(args.outdir, fln), pnginfo=metadata)

if not args.no_preview and not highres_1st and args.interactive:
try:
@@ -2562,6 +2571,7 @@ def scale_and_round(x):
# repeat prompt
for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)):
raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0]
filename = None

if pi == 0 or len(raw_prompts) > 1:
# parse prompt: if prompt is not changed, skip parsing
@@ -2783,6 +2793,12 @@ def scale_and_round(x):
logger.info(f"gradual latent unsharp params: {gl_unsharp_params}")
continue

m = re.match(r"f (.+)", parg, re.IGNORECASE)
if m: # filename
filename = m.group(1)
logger.info(f"filename: {filename}")
continue

except ValueError as ex:
logger.error(f"Exception in parsing / 解析エラー: {parg}")
logger.error(f"{ex}")
@@ -2873,7 +2889,16 @@ def scale_and_round(x):
b1 = BatchData(
False,
BatchDataBase(
global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt
global_step,
prompt,
negative_prompt,
seed,
init_image,
mask_image,
clip_prompt,
guide_image,
raw_prompt,
filename,
),
BatchDataExt(
width,
@@ -2916,7 +2941,7 @@ def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()

add_logging_arguments(parser)

parser.add_argument(
"--sdxl", action="store_true", help="load Stable Diffusion XL model / Stable Diffusion XLのモデルを読み込む"
)
Binary file added library/.DS_Store
106 changes: 106 additions & 0 deletions library/adafactor_fused.py
@@ -0,0 +1,106 @@
import math
import torch
from transformers import Adafactor

@torch.no_grad()
def adafactor_step_param(self, p, group):
if p.grad is None:
return
grad = p.grad
if grad.dtype in {torch.float16, torch.bfloat16}:
grad = grad.float()
if grad.is_sparse:
raise RuntimeError("Adafactor does not support sparse gradients.")

state = self.state[p]
grad_shape = grad.shape

factored, use_first_moment = Adafactor._get_options(group, grad_shape)
# State Initialization
if len(state) == 0:
state["step"] = 0

if use_first_moment:
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(grad)
if factored:
state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
else:
state["exp_avg_sq"] = torch.zeros_like(grad)

state["RMS"] = 0
else:
if use_first_moment:
state["exp_avg"] = state["exp_avg"].to(grad)
if factored:
state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
else:
state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)

p_data_fp32 = p
if p.dtype in {torch.float16, torch.bfloat16}:
p_data_fp32 = p_data_fp32.float()

state["step"] += 1
state["RMS"] = Adafactor._rms(p_data_fp32)
lr = Adafactor._get_lr(group, state)

beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
update = (grad ** 2) + group["eps"][0]
if factored:
exp_avg_sq_row = state["exp_avg_sq_row"]
exp_avg_sq_col = state["exp_avg_sq_col"]

exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))

# Approximation of exponential moving average of square of gradient
update = Adafactor._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update.mul_(grad)
else:
exp_avg_sq = state["exp_avg_sq"]

exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
update = exp_avg_sq.rsqrt().mul_(grad)

update.div_((Adafactor._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
update.mul_(lr)

if use_first_moment:
exp_avg = state["exp_avg"]
exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
update = exp_avg

if group["weight_decay"] != 0:
p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))

p_data_fp32.add_(-update)

if p.dtype in {torch.float16, torch.bfloat16}:
p.copy_(p_data_fp32)


@torch.no_grad()
def adafactor_step(self, closure=None):
"""
Performs a single optimization step

Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()

for group in self.param_groups:
for p in group["params"]:
adafactor_step_param(self, p, group)

return loss

def patch_adafactor_fused(optimizer: Adafactor):
optimizer.step_param = adafactor_step_param.__get__(optimizer)
optimizer.step = adafactor_step.__get__(optimizer)
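Nothing in this file shows how the patched optimizer is meant to be driven. A plausible wiring, assuming a fused backward pass where each parameter is updated and its gradient freed as soon as that gradient has been accumulated (hook-based sketch requiring PyTorch >= 2.1, not necessarily the trainer's exact code):

```python
import torch
from transformers import Adafactor
from library.adafactor_fused import patch_adafactor_fused

model = torch.nn.Linear(128, 128)
# with an explicit lr, Adafactor requires relative_step=False
optimizer = Adafactor(model.parameters(), lr=1e-4, relative_step=False, scale_parameter=False)
patch_adafactor_fused(optimizer)

for group in optimizer.param_groups:
    for p in group["params"]:
        if p.requires_grad:
            def step_and_release(param: torch.Tensor, group=group) -> None:
                optimizer.step_param(param, group)  # update this parameter immediately
                param.grad = None                   # free the gradient right away
            p.register_post_accumulate_grad_hook(step_and_release)

loss = model(torch.randn(4, 128)).mean()
loss.backward()  # each hook fires once its parameter's gradient is ready
```

Stepping per parameter this way avoids holding every gradient until a global `optimizer.step()`, which is the point of fusing the step into the backward pass.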
7 changes: 7 additions & 0 deletions library/config_util.py
@@ -86,11 +86,13 @@ class DreamBoothSubsetParams(BaseSubsetParams):
class_tokens: Optional[str] = None
caption_extension: str = ".caption"
cache_info: bool = False
alpha_mask: bool = False


@dataclass
class FineTuningSubsetParams(BaseSubsetParams):
metadata_file: Optional[str] = None
alpha_mask: bool = False


@dataclass
@@ -191,6 +193,7 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
"keep_tokens": int,
"keep_tokens_separator": str,
"secondary_separator": str,
"caption_separator": str,
"enable_wildcard": bool,
"token_warmup_min": int,
"token_warmup_step": Any(float, int),
@@ -212,11 +215,13 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
DB_SUBSET_DISTINCT_SCHEMA = {
Required("image_dir"): str,
"is_reg": bool,
"alpha_mask": bool,
}
# FT means FineTuning
FT_SUBSET_DISTINCT_SCHEMA = {
Required("metadata_file"): str,
"image_dir": str,
"alpha_mask": bool,
}
CN_SUBSET_ASCENDABLE_SCHEMA = {
"caption_extension": str,
@@ -523,6 +528,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
shuffle_caption: {subset.shuffle_caption}
keep_tokens: {subset.keep_tokens}
keep_tokens_separator: {subset.keep_tokens_separator}
caption_separator: {subset.caption_separator}
secondary_separator: {subset.secondary_separator}
enable_wildcard: {subset.enable_wildcard}
caption_dropout_rate: {subset.caption_dropout_rate}
Expand All @@ -536,6 +542,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
random_crop: {subset.random_crop}
token_warmup_min: {subset.token_warmup_min},
token_warmup_step: {subset.token_warmup_step},
alpha_mask: {subset.alpha_mask},
"""
),
" ",
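With these schema additions, `alpha_mask` (and the new `caption_separator`) become per-subset keys in the dataset config. A hypothetical TOML fragment, with illustrative paths:

```toml
[[datasets]]
resolution = 1024

  [[datasets.subsets]]
  image_dir = "train/images"   # illustrative path
  caption_extension = ".txt"
  caption_separator = ","      # now accepted per subset
  alpha_mask = true            # use the image's alpha channel as a loss mask
```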
14 changes: 11 additions & 3 deletions library/custom_train_functions.py
@@ -480,12 +480,20 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):


def apply_masked_loss(loss, batch):
# mask image is -1 to 1. we need to convert it to 0 to 1
mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
if "conditioning_images" in batch:
# conditioning image is -1 to 1. we need to convert it to 0 to 1
mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
mask_image = mask_image / 2 + 0.5
# print(f"conditioning_image: {mask_image.shape}")
elif "alpha_masks" in batch and batch["alpha_masks"] is not None:
# alpha mask is 0 to 1
mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension
# print(f"mask_image: {mask_image.shape}, {mask_image.mean()}")
else:
return loss

# resize to the same size as the loss
mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
mask_image = mask_image / 2 + 0.5
loss = loss * mask_image
return loss

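The reworked `apply_masked_loss` prefers a conditioning image (R channel, rescaled from -1..1 to 0..1), falls back to the new `alpha_masks` entry (already 0..1), and returns the loss untouched when neither key is present. A minimal shape-level check, assuming the package imports cleanly:

```python
import torch
from library.custom_train_functions import apply_masked_loss

loss = torch.rand(2, 4, 64, 64)                   # per-element loss at latent resolution
batch = {"alpha_masks": torch.rand(2, 512, 512)}  # one 0..1 mask per sample

masked = apply_masked_loss(loss, batch)  # mask gains a channel dim, is resized to 64x64, then multiplied in
print(masked.shape)                      # torch.Size([2, 4, 64, 64])
```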
2 changes: 1 addition & 1 deletion library/ipex/attention.py
@@ -5,7 +5,7 @@

# pylint: disable=protected-access, missing-function-docstring, line-too-long

# ARC GPUs can't allocate more than 4GB to a single block so we slice the attetion layers
# ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers

sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4))
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
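Both thresholds are read from the environment when this module is imported, so they have to be set beforehand. The values appear to be in GB, judging by the 4 GB comment above:

```python
# Set before library.ipex.attention is imported; 4 is the default for both.
import os
os.environ["IPEX_SDPA_SLICE_TRIGGER_RATE"] = "6"  # assumed unit: GB
os.environ["IPEX_ATTENTION_SLICE_RATE"] = "4"
```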