
Merge pull request #209 from bghira/main
v0.7.1
bghira authored Oct 15, 2023
2 parents 7a54f38 + 09b0426 commit f37884b
Showing 8 changed files with 187 additions and 80 deletions.
152 changes: 99 additions & 53 deletions OPTIONS.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion documentation/DEEPSPEED.md
@@ -13,7 +13,7 @@ SimpleTuner v0.7 includes preliminary support for training SDXL using DeepSpeed
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA GeForce ... Off | 00000000:08:00.0 Off | Off |
| 0% 43C P2 100W / 450W | 9237MiB / 24564MiB | 0% Default |
| 0% 43C P2 100W / 450W | 9237MiB / 24564MiB | 100% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
2 changes: 2 additions & 0 deletions helpers/caching/sdxl_embeds.py
@@ -95,10 +95,12 @@ def encode_sdxl_prompt(
# We are always interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds_output[0]
prompt_embeds = prompt_embeds_output.hidden_states[-2]
del prompt_embeds_output
bs_embed, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)

# Clear out anything we moved to the text encoder device
text_input_ids.to('cpu')
del text_input_ids

prompt_embeds_list.append(prompt_embeds)
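The added `del` statements and the move of `text_input_ids` back to the CPU aim to release text-encoder intermediates as soon as the needed embeddings have been extracted. Below is a minimal, hedged sketch of that clean-up pattern outside SimpleTuner; the SDXL checkpoint name and prompt are illustrative assumptions, not part of the commit.

```python
import torch
from transformers import CLIPTextModelWithProjection, CLIPTokenizer

# Sketch of the "keep only what you need, drop the rest" pattern above,
# assuming a standard SDXL text encoder (names are illustrative).
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = CLIPTokenizer.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="tokenizer_2"
)
text_encoder = CLIPTextModelWithProjection.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder_2"
).to(device)

text_input_ids = tokenizer(
    "a photo of a cat", padding="max_length", truncation=True, return_tensors="pt"
).input_ids.to(device)

with torch.no_grad():
    output = text_encoder(text_input_ids, output_hidden_states=True)

pooled_prompt_embeds = output[0]          # pooled projection, kept
prompt_embeds = output.hidden_states[-2]  # penultimate hidden state, kept
del output                                # drop the full output object promptly

# Note: tensor.to("cpu") returns a *new* tensor; without reassignment the
# device copy is only released once the original reference is deleted.
text_input_ids = text_input_ids.to("cpu")
del text_input_ids
```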
60 changes: 45 additions & 15 deletions helpers/caching/vae.py
@@ -8,6 +8,7 @@
from helpers.data_backend.base import BaseDataBackend
from helpers.training.state_tracker import StateTracker
from helpers.training.multi_process import _get_rank as get_rank
from helpers.training.multi_process import rank_info

logger = logging.getLogger("VAECache")
logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL") or "INFO")
@@ -39,6 +40,10 @@ def __init__(
self.vae_batch_size = vae_batch_size
self.instance_data_root = instance_data_root
self.transform = MultiaspectImage.get_image_transforms()
self.rank_info = rank_info()

def debug_log(self, msg: str):
logger.debug(f"{self.rank_info}{msg}")

def generate_vae_cache_filename(self, filepath: str) -> tuple:
"""Get the cache filename for a given image filepath and its base name."""
@@ -74,31 +79,32 @@ def discover_all_files(self, directory: str = None):
)
)
)
logger.debug(f"VAECache discover_all_files found {len(all_image_files)} images")
self.debug_log(
f"VAECache discover_all_files found {len(all_image_files)} images"
)
return all_image_files

def discover_unprocessed_files(self, directory: str = None):
"""Identify files that haven't been processed yet."""
all_image_files = StateTracker.get_image_files()
existing_cache_files = StateTracker.get_vae_cache_files()
logger.debug(
self.debug_log(
f"discover_unprocessed_files found {len(all_image_files)} images from StateTracker (truncated): {list(all_image_files)[:5]}"
)
logger.debug(
self.debug_log(
f"discover_unprocessed_files found {len(existing_cache_files)} already-processed cache files (truncated): {list(existing_cache_files)[:5]}"
)
cache_filenames = {
self.generate_vae_cache_filename(file)[1] for file in all_image_files
}
logger.debug(
self.debug_log(
f"discover_unprocessed_files found {len(cache_filenames)} cache filenames (truncated): {list(cache_filenames)[:5]}"
)
unprocessed_files = {
f"{os.path.splitext(file)[0]}.png"
for file in cache_filenames
if file not in existing_cache_files
}

return list(unprocessed_files)
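The reworked `discover_unprocessed_files` builds the expected cache filename for every known image and keeps only those with no existing cache entry, i.e. a plain set difference. A simplified, hedged version of that logic follows; the filename conventions are assumptions read off the diff rather than the exact helper behaviour.

```python
import os


def discover_unprocessed_files(all_image_files, existing_cache_files):
    """Return image-like paths whose VAE latent cache entry does not exist yet.

    all_image_files: iterable of source image paths.
    existing_cache_files: iterable of cache basenames already on disk.
    """
    # Assumed convention: latents are named after the image basename with .pt.
    cache_filenames = {
        f"{os.path.splitext(os.path.basename(f))[0]}.pt" for f in all_image_files
    }
    existing = set(existing_cache_files)
    # Anything without a matching cache entry still needs to be encoded.
    unprocessed = {
        f"{os.path.splitext(name)[0]}.png"
        for name in cache_filenames
        if name not in existing
    }
    return list(unprocessed)


# Example: only b.png still needs its latent computed.
print(discover_unprocessed_files(["a.png", "b.png"], ["a.pt"]))  # -> ['b.png']
```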

def _list_cached_images(self):
@@ -107,9 +113,12 @@ def _list_cached_images(self):
"""
# Extract array of tuple into just, an array of files:
pt_files = StateTracker.get_vae_cache_files()
logging.debug(f"Found {len(pt_files)} cached files in {self.cache_dir}")
# Extract just the base filename without the extension
return {os.path.splitext(f)[0] for f in pt_files}
results = {os.path.splitext(f)[0] for f in pt_files}
logging.debug(
f"Found {len(pt_files)} cached files in {self.cache_dir} (truncated): {list(results)[:5]}"
)
return results

def encode_image(self, image, filepath):
"""
@@ -177,14 +186,16 @@ def encode_images(self, images, filepaths, load_from_cache=True):

def split_cache_between_processes(self):
all_unprocessed_files = self.discover_unprocessed_files(self.cache_dir)
logger.debug(f"All unprocessed files: {all_unprocessed_files[:5]} (truncated)")
self.debug_log(
f"All unprocessed files: {all_unprocessed_files[:5]} (truncated)"
)
# Use the accelerator to split the data
with self.accelerator.split_between_processes(
all_unprocessed_files
) as split_files:
self.local_unprocessed_files = split_files
# Print the first 5 as a debug log:
logger.debug(
self.debug_log(
f"Local unprocessed files: {self.local_unprocessed_files[:5]} (truncated)"
)
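`Accelerator.split_between_processes` shards the unprocessed file list so each rank only encodes its own slice. A minimal usage sketch is below; the file list is a stand-in for the output of `discover_unprocessed_files`.

```python
from accelerate import Accelerator

accelerator = Accelerator()

# Stand-in for the list returned by discover_unprocessed_files().
all_unprocessed_files = [f"image_{i}.png" for i in range(10)]

# Each process receives a roughly equal slice of the list.
with accelerator.split_between_processes(all_unprocessed_files) as split_files:
    local_unprocessed_files = split_files
    print(
        f"Rank {accelerator.process_index} handles "
        f"{len(local_unprocessed_files)} files"
    )
```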

@@ -237,16 +248,21 @@ def process_buckets(self, bucket_manager):
aspect_bucket_cache = bucket_manager.read_cache().copy()

# Extract and shuffle the keys of the dictionary
shuffled_keys = list(aspect_bucket_cache.keys())
shuffle(shuffled_keys)
do_shuffle = os.environ.get('SIMPLETUNER_SHUFFLE_ASPECTS', 'true').lower() == 'true'
if do_shuffle:
shuffled_keys = list(aspect_bucket_cache.keys())
shuffle(shuffled_keys)

for bucket in shuffled_keys:
relevant_files = [
f
for f in aspect_bucket_cache[bucket]
if os.path.splitext(os.path.basename(f))[0] not in processed_images
and f in self.local_unprocessed_files
]
logger.debug(
if do_shuffle:
shuffle(relevant_files)
self.debug_log(
f"Reduced bucket {bucket} down from {len(aspect_bucket_cache[bucket])} to {len(relevant_files)} relevant files"
)
if len(relevant_files) == 0:
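Bucket and file ordering is now only randomised when `SIMPLETUNER_SHUFFLE_ASPECTS` is left at its default of `true`, which allows deterministic cache-building order when desired. A hedged sketch of that toggle follows, with the non-shuffled branch spelled out explicitly (the sample bucket data is illustrative).

```python
import os
from random import shuffle

aspect_bucket_cache = {"1.0": ["a.png"], "1.5": ["b.png", "c.png"]}

# Opt out of shuffling by exporting SIMPLETUNER_SHUFFLE_ASPECTS=false.
do_shuffle = (
    os.environ.get("SIMPLETUNER_SHUFFLE_ASPECTS", "true").lower() == "true"
)

bucket_keys = list(aspect_bucket_cache.keys())
if do_shuffle:
    shuffle(bucket_keys)

for bucket in bucket_keys:
    relevant_files = list(aspect_bucket_cache[bucket])
    if do_shuffle:
        shuffle(relevant_files)
    print(bucket, relevant_files)
```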
@@ -271,12 +287,20 @@ def process_buckets(self, bucket_manager):
)
test_filepath = f"{os.path.splitext(self.generate_vae_cache_filename(filepath)[1])[0]}.png"
if test_filepath not in self.local_unprocessed_files:
logger.debug(
self.debug_log(
f"Skipping {test_filepath} because it is not in local unprocessed files"
)
continue
try:
logger.debug(
# Does it exist on the backend?
if self.data_backend.exists(
self.generate_vae_cache_filename(filepath)[0]
):
self.debug_log(
f"Skipping {filepath} because it is already in the cache"
)
continue
self.debug_log(
f"Processing {filepath} because it is in local unprocessed files"
)
image = self.data_backend.read_image(filepath)
Expand All @@ -289,6 +313,9 @@ def process_buckets(self, bucket_manager):
)
vae_input_images.append(pixel_values)
vae_input_filepaths.append(filepath)
self.debug_log(
f"Completed processing {filepath}"
)
except ValueError as e:
logger.error(f"Received fatal error: {e}")
raise e
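Before reading and transforming an image, the loop now asks the data backend whether the corresponding latent file already exists and skips it if so, avoiding redundant decode/encode work on resumed runs. A rough sketch of that guard against a generic backend interface is below; the backend class and path convention are illustrative stand-ins, not SimpleTuner's actual classes.

```python
import os


class LocalBackendSketch:
    """Illustrative stand-in for a data backend with an exists() check."""

    def exists(self, path: str) -> bool:
        return os.path.exists(path)


def cache_path_for(filepath: str, cache_dir: str = "vae_cache") -> str:
    # Assumed convention: latents stored as <cache_dir>/<image basename>.pt
    base = os.path.splitext(os.path.basename(filepath))[0]
    return os.path.join(cache_dir, f"{base}.pt")


backend = LocalBackendSketch()
for filepath in ["images/a.png", "images/b.png"]:
    if backend.exists(cache_path_for(filepath)):
        # Already encoded on a previous run; nothing to do.
        continue
    print(f"Would read and encode {filepath}")
```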
@@ -304,7 +331,7 @@ def process_buckets(self, bucket_manager):

# If VAE input batch is ready
if len(vae_input_images) >= self.vae_batch_size:
logger.debug(
self.debug_log(
f"Reached a VAE batch size of {self.vae_batch_size} pixel groups, so we will now encode them into latents."
)
latents_batch = self.encode_images(
@@ -331,6 +358,9 @@

# Handle remainders after processing the bucket
if vae_input_images: # If there are images left to be encoded
self.debug_log(
f"Processing the remainder, {len(vae_input_images)} images"
)
latents_batch = self.encode_images(
vae_input_images, vae_input_filepaths, load_from_cache=False
)
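The encoding loop accumulates pixel tensors until `vae_batch_size` is reached, encodes the batch, then flushes whatever remains once the bucket is exhausted, which is what the new "Processing the remainder" log line refers to. A simplified sketch of that accumulate-then-flush pattern follows; the `encode_images` callable here is a placeholder rather than the real VAE call.

```python
def encode_images(images, filepaths):
    # Placeholder for the real VAE encode; returns one latent per image.
    return [f"latent({p})" for p in filepaths]


def process_bucket(filepaths, vae_batch_size=4):
    vae_input_images, vae_input_filepaths, latents = [], [], []
    for filepath in filepaths:
        vae_input_images.append(f"pixels({filepath})")
        vae_input_filepaths.append(filepath)
        # Encode as soon as a full batch has accumulated.
        if len(vae_input_images) >= vae_batch_size:
            latents.extend(encode_images(vae_input_images, vae_input_filepaths))
            vae_input_images, vae_input_filepaths = [], []
    # Flush the remainder that never filled a complete batch.
    if vae_input_images:
        latents.extend(encode_images(vae_input_images, vae_input_filepaths))
    return latents


print(process_bucket([f"img_{i}.png" for i in range(10)], vae_batch_size=4))
```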
2 changes: 1 addition & 1 deletion helpers/data_backend/aws.py
@@ -183,7 +183,7 @@ def list_files(self, str_pattern: str, instance_data_root: str = None):
)

# Paginating over the entire bucket objects
for page in paginator.paginate(Bucket=self.bucket_name):
for page in paginator.paginate(Bucket=self.bucket_name, MaxKeys=10000):
for obj in page.get("Contents", []):
# Filter based on the provided pattern
if fnmatch.fnmatch(obj["Key"], pattern):
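The S3 listing now passes `MaxKeys` to each paginated request. For reference, here is a hedged boto3 sketch of paginated listing with an explicit page size; note that S3 itself caps a single response at 1,000 keys, and boto3's documented paginator knob for page size is `PaginationConfig`. The bucket name and pattern below are placeholders.

```python
import fnmatch

import boto3

s3 = boto3.client("s3")
paginator = s3.get_paginator("list_objects_v2")

bucket_name = "my-training-data"  # placeholder
pattern = "*.png"                 # placeholder

matches = []
# PageSize tunes how many keys are requested per page; S3 still caps a
# single response at 1,000 keys, so larger values ask for the maximum.
for page in paginator.paginate(
    Bucket=bucket_name, PaginationConfig={"PageSize": 1000}
):
    for obj in page.get("Contents", []):
        if fnmatch.fnmatch(obj["Key"], pattern):
            matches.append(obj["Key"])

print(f"Found {len(matches)} matching objects")
```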
22 changes: 19 additions & 3 deletions helpers/log_format.py
@@ -27,7 +27,9 @@ def format(self, record):
accel_logger = logging.getLogger("DeepSpeed")
accel_logger.setLevel(logging.WARNING)
new_handler = logging.StreamHandler()
new_handler.setFormatter(ColorizedFormatter("%(asctime)s [%(levelname)s] (%(name)s) %(message)s"))
new_handler.setFormatter(
ColorizedFormatter("%(asctime)s [%(levelname)s] (%(name)s) %(message)s")
)
# Remove existing handlers
for handler in logger.handlers[:]:
logger.removeHandler(handler)
@@ -40,6 +42,20 @@ def format(self, record):
pil_logger = logging.getLogger("PIL")
pil_logger.setLevel(logging.INFO)
pil_logger = logging.getLogger("PIL.Image")
pil_logger.setLevel("WARNING")
pil_logger.setLevel("ERROR")
pil_logger = logging.getLogger("PIL.PngImagePlugin")
pil_logger.setLevel("WARNING")
pil_logger.setLevel("ERROR")
transformers_logger = logging.getLogger("transformers.configuration_utils")
transformers_logger.setLevel("ERROR")
diffusers_logger = logging.getLogger("diffusers.configuration_utils")
diffusers_logger.setLevel("ERROR")

import warnings

# Suppress specific PIL warning
warnings.filterwarnings(
"ignore",
category=UserWarning,
module="PIL",
message="Palette images with Transparency expressed in bytes should be converted to RGBA images",
)
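The logging changes quiet chatty third-party loggers by raising their levels to ERROR and suppress one specific PIL palette-transparency warning. A condensed standalone sketch of the same quieting pattern:

```python
import logging
import warnings

# Raise the level on noisy third-party loggers so only errors surface.
for noisy in (
    "PIL.Image",
    "PIL.PngImagePlugin",
    "transformers.configuration_utils",
    "diffusers.configuration_utils",
):
    logging.getLogger(noisy).setLevel(logging.ERROR)

# Drop one known-noisy PIL warning without touching other UserWarnings.
warnings.filterwarnings(
    "ignore",
    category=UserWarning,
    module="PIL",
    message="Palette images with Transparency expressed in bytes should be converted to RGBA images",
)
```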
12 changes: 10 additions & 2 deletions helpers/multiaspect/bucket.py
@@ -2,7 +2,7 @@
from helpers.multiaspect.image import MultiaspectImage
from helpers.data_backend.base import BaseDataBackend
from pathlib import Path
import json, logging, os
import json, logging, os, time
from multiprocessing import Manager
from tqdm import tqdm
from multiprocessing import Process, Queue
@@ -217,7 +217,13 @@ def compute_aspect_ratio_bucket_indices(self):
for worker in workers:
worker.start()

with tqdm(total=len(new_files), leave=False, ncols=100) as pbar:
with tqdm(
desc="Generating aspect bucket cache",
total=len(new_files),
leave=False,
ncols=100,
miniters=int(len(new_files) / 100),
) as pbar:
while any(worker.is_alive() for worker in workers):
while not tqdm_queue.empty():
pbar.update(tqdm_queue.get())
@@ -237,6 +243,8 @@ def compute_aspect_ratio_bucket_indices(self):
filepath=filepath, metadata=meta, update_json=False
)

time.sleep(0.1)

for worker in workers:
worker.join()
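The aspect-bucket progress bar now has a description, coarser update granularity via `miniters`, and the polling loop sleeps briefly between queue checks so it no longer spins at full CPU. A stripped-down sketch of that worker/queue/tqdm polling pattern follows; the worker function and file list are illustrative.

```python
import time
from multiprocessing import Process, Queue

from tqdm import tqdm


def worker(files, tqdm_queue):
    # Illustrative worker: pretend to compute metadata for each file.
    for _ in files:
        time.sleep(0.01)
        tqdm_queue.put(1)


if __name__ == "__main__":
    new_files = [f"img_{i}.png" for i in range(200)]
    tqdm_queue = Queue()
    workers = [
        Process(target=worker, args=(new_files[i::4], tqdm_queue)) for i in range(4)
    ]
    for w in workers:
        w.start()

    with tqdm(
        desc="Generating aspect bucket cache",
        total=len(new_files),
        leave=False,
        ncols=100,
        miniters=int(len(new_files) / 100),
    ) as pbar:
        while any(w.is_alive() for w in workers):
            while not tqdm_queue.empty():
                pbar.update(tqdm_queue.get())
            # Avoid a busy-wait while workers are still producing.
            time.sleep(0.1)
        # Drain anything left in the queue after the workers exit.
        while not tqdm_queue.empty():
            pbar.update(tqdm_queue.get())

    for w in workers:
        w.join()
```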

15 changes: 10 additions & 5 deletions train_sdxl.py
@@ -817,11 +817,15 @@ def main():
accelerator.register_load_state_pre_hook(model_hooks.load_model_hook)

# Prepare everything with our `accelerator`.
logger.info(f"Loading our accelerator...")
unet, train_dataloader, lr_scheduler, optimizer = accelerator.prepare(unet, train_dataloader, lr_scheduler, optimizer)
if args.use_ema:
logger.info("Moving EMA model weights to accelerator...")
ema_unet.to(accelerator.device, dtype=weight_dtype)
disable_accelerator = os.environ.get('SIMPLETUNER_DISABLE_ACCELERATOR', False)
if not disable_accelerator:
logger.info(f"Loading our accelerator...")
unet, train_dataloader, lr_scheduler, optimizer = accelerator.prepare(
unet, train_dataloader, lr_scheduler, optimizer
)
if args.use_ema:
logger.info("Moving EMA model weights to accelerator...")
ema_unet.to(accelerator.device, dtype=weight_dtype)

# Move vae, unet and text_encoder to device and cast to weight_dtype
# The VAE is in float32 to avoid NaN losses.
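Wrapping `accelerator.prepare(...)` in a `SIMPLETUNER_DISABLE_ACCELERATOR` check lets the trainer skip Accelerate's wrapping, e.g. for debugging. One general caveat with this style of flag: `os.environ.get()` returns a string whenever the variable is set, so any non-empty value (including "false") is truthy. Below is a hedged sketch of a stricter way to parse such a flag; it is an illustration of env-var parsing in general, not the project's own helper.

```python
import os


def env_flag(name: str, default: bool = False) -> bool:
    """Interpret an environment variable as a boolean flag.

    Any of 1/true/yes/on (case-insensitive) enables the flag; unset or other
    values fall back to the default. Comparing against known values avoids
    treating the string "false" as truthy.
    """
    value = os.environ.get(name)
    if value is None:
        return default
    return value.strip().lower() in ("1", "true", "yes", "on")


# Example: export SIMPLETUNER_DISABLE_ACCELERATOR=true to skip prepare().
if not env_flag("SIMPLETUNER_DISABLE_ACCELERATOR"):
    print("Would call accelerator.prepare(unet, dataloader, scheduler, optimizer)")
else:
    print("Accelerator preparation skipped")
```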
Expand Down Expand Up @@ -1022,6 +1026,7 @@ def main():
accelerator.wait_for_everyone()
timesteps_buffer = []
train_loss = 0.0
step = global_step
training_luminance_values = []

for epoch in range(first_epoch, args.num_train_epochs):
