Merge pull request #346 from bghira/main
multigpu fixes | regression fixes | better image detection
bghira authored Apr 9, 2024
2 parents a9b8499 + f2fc3bf commit 42fd40e
Showing 18 changed files with 1,651 additions and 146 deletions.
2 changes: 1 addition & 1 deletion OPTIONS.md
@@ -82,7 +82,7 @@ This guide provides a user-friendly breakdown of the command-line options available
### `--crop_style`

- **What**: When `--crop=true`, the trainer may be instructed to crop in different ways.
- **Why**: The `crop_style` option can be set to `center` (or `centre`) for a classic centre-crop, `corner` to elect for the lowest-right corner, and `random` for a random image slice. Default: random.
- **Why**: The `crop_style` option can be set to `center` (or `centre`) for a classic centre-crop, `corner` to elect for the lowest-right corner, `face` to detect and centre upon the largest subject face, and `random` for a random image slice. Default: random.

### `--crop_aspect`

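
Note: the `face` style added above is only described, not illustrated. As a rough sketch of what a face-centred crop involves, and not necessarily how the trainer implements it, the largest detected face can serve as the crop centre. The OpenCV Haar cascade, the fixed crop size, and the centre-crop fallback below are assumptions made for this illustration.

```python
# Hypothetical sketch of a face-centred square crop; NOT SimpleTuner's implementation.
# Assumes opencv-python, numpy, and Pillow; the cascade choice is an assumption.
import cv2
import numpy as np
from PIL import Image


def face_centre_crop(image: Image.Image, size: int = 1024) -> Image.Image:
    """Centre a square crop on the largest detected face, falling back to a centre crop."""
    grey = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
    cascade = cv2.CascadeClassifier(
        cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
    )
    faces = cascade.detectMultiScale(grey, scaleFactor=1.1, minNeighbors=5)
    if len(faces) > 0:
        # Pick the largest face by area and use its centre as the crop centre.
        x, y, w, h = max(faces, key=lambda f: f[2] * f[3])
        cx, cy = x + w // 2, y + h // 2
    else:
        # No face found: fall back to a plain centre crop.
        cx, cy = image.width // 2, image.height // 2
    # Clamp the crop box to the image; images smaller than `size` are padded by PIL.
    left = min(max(cx - size // 2, 0), max(image.width - size, 0))
    top = min(max(cy - size // 2, 0), max(image.height - size, 0))
    return image.crop((left, top, left + size, top + size))
```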
4 changes: 2 additions & 2 deletions documentation/DATALOADER.md
@@ -9,7 +9,7 @@ Here is an example dataloader configuration file, as `multidatabackend.example.json`
"type": "local",
"instance_data_dir": "/path/to/data/tree",
"crop": false,
"crop_style": "random|center|corner",
"crop_style": "random|center|corner|face",
"crop_aspect": "square|preserve",
"resolution": 1.0,
"resolution_type": "area|pixel",
@@ -87,7 +87,7 @@ Here is an example dataloader configuration file, as `multidatabackend.example.json`

### Cropping Options
- `crop`: Enables or disables image cropping.
- `crop_style`: Selects the cropping style (`random`, `center`, `corner`).
- `crop_style`: Selects the cropping style (`random`, `center`, `corner`, `face`).
- `crop_aspect`: Chooses the cropping aspect (`square` or `preserve`).

### `resolution`
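
For quick reference, a dataset entry using the new crop style could be assembled like this. It is a sketch based on the keys shown above; the `id` field, the list-of-datasets layout, and the paths are placeholders rather than documented values.

```python
# Sketch only: writes a single-dataset multidatabackend config using the keys
# documented above. Paths and values are placeholders, not recommendations.
import json

dataset = {
    "id": "my-dataset",            # assumed identifier field
    "type": "local",
    "instance_data_dir": "/path/to/data/tree",
    "crop": True,
    "crop_style": "face",          # new option added in this PR
    "crop_aspect": "square",
    "resolution": 1.0,
    "resolution_type": "area",
}

with open("multidatabackend.json", "w") as f:
    json.dump([dataset], f, indent=2)
```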
9 changes: 8 additions & 1 deletion helpers/arguments.py
@@ -3,7 +3,14 @@
from helpers.training.state_tracker import StateTracker

logger = logging.getLogger("ArgsParser")
logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO"))
# Are we the main process?
if __name__ == "__main__":
logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO"))
else:
logger.setLevel("ERROR")

if torch.cuda.is_available():
os.environ["NCCL_SOCKET_NTIMEO"] = "2000000"


def parse_args(input_args=None):
78 changes: 45 additions & 33 deletions helpers/data_backend/factory.py
@@ -15,7 +15,6 @@
import json, os, torch, logging, io

logger = logging.getLogger("DataBackendFactory")
logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO"))


def init_backend_config(backend: dict, args: dict, accelerator) -> dict:
@@ -206,6 +205,11 @@ def configure_multi_databackend(
"""
Configure a multiple dataloaders based on the provided commandline args.
"""
logger.setLevel(
os.environ.get(
"SIMPLETUNER_LOG_LEVEL", "INFO" if accelerator.is_main_process else "ERROR"
)
)
if args.data_backend_config is None:
raise ValueError(
"Must provide a data backend config file via --data_backend_config"
@@ -270,7 +274,7 @@ def configure_multi_databackend(
raise ValueError(f"Unknown data backend type: {backend['type']}")

preserve_data_backend_cache = backend.get("preserve_data_backend_cache", False)
if not preserve_data_backend_cache:
if not preserve_data_backend_cache and accelerator.is_local_main_process:
StateTracker.delete_cache_files(
data_backend_id=init_backend["id"],
preserve_data_backend_cache=preserve_data_backend_cache,
@@ -298,7 +302,7 @@ def configure_multi_databackend(
and args.caption_dropout_probability > 0
):
logger.info("Pre-computing null embedding for caption dropout")
with accelerator.main_process_first():
if accelerator.is_local_main_process:
init_backend["text_embed_cache"].compute_embeddings_for_prompts(
[""], return_concat=False, load_from_cache=False
)
@@ -412,29 +416,32 @@ def configure_multi_databackend(
)
else:
raise ValueError(f"Unknown metadata backend type: {metadata_backend}")
init_backend["metadata_backend"] = BucketManager_cls(
id=init_backend["id"],
instance_data_root=init_backend["instance_data_root"],
data_backend=init_backend["data_backend"],
accelerator=accelerator,
resolution=backend.get("resolution", args.resolution),
minimum_image_size=backend.get(
"minimum_image_size", args.minimum_image_size
),
resolution_type=backend.get("resolution_type", args.resolution_type),
batch_size=args.train_batch_size,
metadata_update_interval=backend.get(
"metadata_update_interval", args.metadata_update_interval
),
cache_file=os.path.join(
init_backend["instance_data_root"], "aspect_ratio_bucket_indices.json"
),
metadata_file=os.path.join(
init_backend["instance_data_root"], "aspect_ratio_bucket_metadata.json"
),
delete_problematic_images=args.delete_problematic_images or False,
**metadata_backend_args,
)
with accelerator.main_process_first():
init_backend["metadata_backend"] = BucketManager_cls(
id=init_backend["id"],
instance_data_root=init_backend["instance_data_root"],
data_backend=init_backend["data_backend"],
accelerator=accelerator,
resolution=backend.get("resolution", args.resolution),
minimum_image_size=backend.get(
"minimum_image_size", args.minimum_image_size
),
resolution_type=backend.get("resolution_type", args.resolution_type),
batch_size=args.train_batch_size,
metadata_update_interval=backend.get(
"metadata_update_interval", args.metadata_update_interval
),
cache_file=os.path.join(
init_backend["instance_data_root"],
"aspect_ratio_bucket_indices.json",
),
metadata_file=os.path.join(
init_backend["instance_data_root"],
"aspect_ratio_bucket_metadata.json",
),
delete_problematic_images=args.delete_problematic_images or False,
**metadata_backend_args,
)

if "aspect" not in args.skip_file_discovery or "aspect" not in backend.get(
"skip_file_discovery", ""
@@ -460,6 +467,7 @@ def configure_multi_databackend(
init_backend["metadata_backend"].split_buckets_between_processes(
gradient_accumulation_steps=args.gradient_accumulation_steps,
)
accelerator.wait_for_everyone()

# Check if there is an existing 'config' in the metadata_backend.config
excluded_keys = [
@@ -559,7 +567,7 @@
f"Backend {init_backend['id']} has prepend_instance_prompt=True, but no instance_prompt was provided. You must provide an instance_prompt, or disable this option."
)

with accelerator.main_process_first():
if accelerator.is_local_main_process:
# We get captions from the IMAGE dataset. Not the text embeds dataset.
captions = PromptHandler.get_all_captions(
data_backend=init_backend["data_backend"],
@@ -569,6 +577,7 @@
use_captions=use_captions,
caption_strategy=backend.get("caption_strategy", args.caption_strategy),
)

if "text" not in args.skip_file_discovery and "text" not in backend.get(
"skip_file_discovery", ""
):
@@ -619,10 +628,11 @@
vae_cache_preprocess=args.vae_cache_preprocess,
)

logger.info(f"(id={init_backend['id']}) Discovering cache objects..")
if accelerator.is_local_main_process:
init_backend["vaecache"].discover_all_files()
accelerator.wait_for_everyone()
if args.vae_cache_preprocess:
logger.info(f"(id={init_backend['id']}) Discovering cache objects..")
if accelerator.is_local_main_process:
init_backend["vaecache"].discover_all_files()
accelerator.wait_for_everyone()

if (
(
@@ -651,8 +661,10 @@
init_backend["metadata_backend"].load_image_metadata()
accelerator.wait_for_everyone()

if "vae" not in args.skip_file_discovery and "vae" not in backend.get(
"skip_file_discovery", ""
if (
args.vae_cache_preprocess
and "vae" not in args.skip_file_discovery
and "vae" not in backend.get("skip_file_discovery", "")
):
init_backend["vaecache"].split_cache_between_processes()
if args.vae_cache_preprocess:
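
Note: the factory.py hunks above repeatedly apply one accelerate coordination pattern: destructive or cache-producing work runs on the local main process only, and every rank synchronises before using the result. A condensed sketch of that pattern follows; the three helper functions are placeholders, not SimpleTuner APIs.

```python
# Condensed sketch of the multi-GPU guard pattern used in the factory.py hunks above.
# The helper functions are placeholders standing in for the real cache/discovery work.
from accelerate import Accelerator

accelerator = Accelerator()


def delete_stale_cache() -> None:
    print(f"[rank {accelerator.process_index}] deleting stale cache files")


def discover_files() -> list:
    print(f"[rank {accelerator.process_index}] discovering dataset files")
    return []


def build_metadata() -> dict:
    print(f"[rank {accelerator.process_index}] building (or loading) bucket metadata")
    return {}


# Destructive or one-off work runs on one process per node only.
if accelerator.is_local_main_process:
    delete_stale_cache()
    discover_files()

# Every rank blocks here so nobody reads a half-written cache.
accelerator.wait_for_everyone()

# Cached work: the main process runs first and writes the cache,
# the other ranks enter afterwards and can load what it produced.
with accelerator.main_process_first():
    metadata = build_metadata()
```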
25 changes: 0 additions & 25 deletions helpers/image_manipulation/broken_images.py

This file was deleted.

2 changes: 1 addition & 1 deletion helpers/legacy/validation.py
@@ -6,9 +6,9 @@
from helpers.training.state_tracker import StateTracker
from helpers.training.wrappers import unwrap_model
from helpers.prompts import PromptHandler
from helpers.sdxl.pipeline import StableDiffusionXLPipeline
from diffusers import (
AutoencoderKL,
StableDiffusionXLPipeline,
DDIMScheduler,
DiffusionPipeline,
)
6 changes: 3 additions & 3 deletions helpers/metadata/backends/base.py
@@ -46,7 +46,6 @@ def __init__(
self.metadata_file = Path(metadata_file)
self.aspect_ratio_bucket_indices = {}
self.image_metadata = {} # Store image metadata
self.instance_images_path = set()
self.seen_images = {}
self.config = {}
self.reload_cache()
@@ -154,6 +153,9 @@ def compute_aspect_ratio_bucket_indices(self):
new_files = self._discover_new_files()

existing_files_set = set().union(*self.aspect_ratio_bucket_indices.values())
logger.info(
f"Compressed {len(existing_files_set)} existing files from {len(self.aspect_ratio_bucket_indices.values())}."
)
# Initialize aggregated statistics
aggregated_statistics = {
"total_processed": 0,
@@ -250,7 +252,6 @@ def compute_aspect_ratio_bucket_indices(self):
logger.debug(
f"In-flight metadata update after {processing_duration} seconds. Saving {len(self.image_metadata)} metadata entries and {len(self.aspect_ratio_bucket_indices)} aspect bucket lists."
)
self.instance_images_path.update(written_files)
self.save_cache(enforce_constraints=False)
self.save_image_metadata()
last_write_time = current_time
@@ -260,7 +261,6 @@ def compute_aspect_ratio_bucket_indices(self):
for worker in workers:
worker.join()
logger.info(f"Image processing statistics: {aggregated_statistics}")
self.instance_images_path.update(new_files)
self.save_image_metadata()
self.save_cache(enforce_constraints=True)
logger.info("Completed aspect bucket update.")
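
Note: the hunks above drop the separate `instance_images_path` set; the set of already-processed files is instead derived by flattening `aspect_ratio_bucket_indices`, which is what the new log line reports. A toy illustration of that expression, with made-up buckets:

```python
# Toy illustration of flattening aspect_ratio_bucket_indices into a set of
# already-processed file paths, as the new log line above does. Data is made up.
aspect_ratio_bucket_indices = {
    "1.0": ["img/a.png", "img/b.png"],
    "1.5": ["img/c.png", "img/a.png"],  # duplicates collapse in the set
}

existing_files_set = set().union(*aspect_ratio_bucket_indices.values())
print(len(existing_files_set))  # 3 unique files across 2 buckets
```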
61 changes: 38 additions & 23 deletions helpers/metadata/backends/json.py
@@ -69,33 +69,44 @@ def _discover_new_files(self, for_metadata: bool = False):
"""
all_image_files = StateTracker.get_image_files(
data_backend_id=self.data_backend.id
) or StateTracker.set_image_files(
self.data_backend.list_files(
)
if all_image_files is None:
logger.debug("No image file cache available, retrieving fresh")
all_image_files = self.data_backend.list_files(
instance_data_root=self.instance_data_root,
str_pattern="*.[jJpP][pPnN][gG]",
),
data_backend_id=self.data_backend.id,
)
# Log an excerpt of the all_image_files:
# logger.debug(
# f"Found {len(all_image_files)} images in the instance data root (truncated): {list(all_image_files)[:5]}"
# )
# Extract only the files from the data
)
all_image_files = StateTracker.set_image_files(
all_image_files, data_backend_id=self.data_backend.id
)
else:
logger.debug("Using cached image file list")

# Flatten the list if it contains nested lists
if any(isinstance(i, list) for i in all_image_files):
all_image_files = [item for sublist in all_image_files for item in sublist]

logger.debug(f"All image files: {json.dumps(all_image_files, indent=4)}")

all_image_files_set = set(all_image_files)

if for_metadata:
result = [
file
for file in all_image_files
if self.get_metadata_by_filepath(file) is None
]
# logger.debug(
# f"Found {len(result)} new images for metadata scan (truncated): {list(result)[:5]}"
# )
return result
return [
file
for file in all_image_files
if str(file) not in self.instance_images_path
]
else:
processed_files = set(
path
for paths in self.aspect_ratio_bucket_indices.values()
for path in paths
)
result = [
file for file in all_image_files_set if file not in processed_files
]

return result

def reload_cache(self):
"""
@@ -105,13 +116,13 @@ def reload_cache(self):
dict: The cache data.
"""
# Query our DataBackend to see whether the cache file exists.
logger.info(f"Checking for cache file: {self.cache_file}")
if self.data_backend.exists(self.cache_file):
try:
# Use our DataBackend to actually read the cache file.
logger.debug("Pulling cache file from storage.")
logger.info(f"Pulling cache file from storage")
cache_data_raw = self.data_backend.read(self.cache_file)
cache_data = json.loads(cache_data_raw)
logger.debug("Completed loading cache data.")
except Exception as e:
logger.warning(
f"Error loading aspect bucket cache, creating new one: {e}"
@@ -128,7 +139,11 @@
data_backend_id=self.id,
config=self.config,
)
self.instance_images_path = set(cache_data.get("instance_images_path", []))
logger.debug(
f"(id={self.id}) Loaded {len(self.aspect_ratio_bucket_indices)} aspect ratio buckets"
)
else:
logger.warning("No cache file found, creating new one.")

def save_cache(self, enforce_constraints: bool = False):
"""
@@ -148,7 +163,6 @@ def save_cache(self, enforce_constraints: bool = False):
data_backend_id=self.data_backend.id
),
"aspect_ratio_bucket_indices": aspect_ratio_bucket_indices_str,
"instance_images_path": [str(path) for path in self.instance_images_path],
}
logger.debug(f"save_cache has config to write: {cache_data['config']}")
cache_data_str = json.dumps(cache_data)
@@ -184,6 +198,7 @@ def _process_for_bucket(
logger.debug(
f"Image {image_path_str} was not found on the backend. Skipping image."
)
del image_data
statistics["skipped"]["not_found"] += 1
return aspect_ratio_bucket_indices
with Image.open(BytesIO(image_data)) as image:
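
Note: the rewritten `_discover_new_files` above derives new files by set difference: everything the backend lists minus everything already recorded in a bucket. A toy version of that logic, with invented paths:

```python
# Toy version of the new-file discovery logic in json.py above: anything listed
# by the backend that is not already in a bucket counts as new. Paths are invented.
all_image_files = ["img/a.png", "img/b.png", "img/c.png", "img/d.png"]
aspect_ratio_bucket_indices = {"1.0": ["img/a.png"], "1.5": ["img/c.png"]}

processed_files = set(
    path
    for paths in aspect_ratio_bucket_indices.values()
    for path in paths
)
new_files = [f for f in set(all_image_files) if f not in processed_files]
print(sorted(new_files))  # ['img/b.png', 'img/d.png']
```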
(The remaining changed files in this commit are not shown here.)
