
Commit

wip: omnigen initial training support attempt, batch size 1 only, extremely high loss
bghira committed Nov 2, 2024
1 parent 39c05a7 commit e1122fd
Showing 16 changed files with 1,335 additions and 312 deletions.
8 changes: 6 additions & 2 deletions configure.py
@@ -18,12 +18,13 @@
"full": [
"flux",
"sdxl",
"omnigen",
"pixart_sigma",
"kolors",
"sd3",
"legacy",
],
"lora": ["flux", "sdxl", "kolors", "sd3", "legacy"],
"lora": ["flux", "sdxl", "kolors", "sd3", "legacy", "omnigen"],
"controlnet": ["sdxl", "legacy"],
}

@@ -34,6 +35,7 @@
"kolors": "kwai-kolors/kolors-diffusers",
"terminus": "ptx0/terminus-xl-velocity-v2",
"sd3": "stabilityai/stable-diffusion-3.5-large",
"omnigen": "Shitao/OmniGen-v1",
"legacy": "stabilityai/stable-diffusion-2-1-base",
}

@@ -43,12 +45,14 @@
"pixart_sigma": 3.4,
"kolors": 5.0,
"terminus": 8.0,
"sd3": 5.0,
"omnigen": 3.0,
"sd3": 6.0,
}

model_labels = {
"sd3": "Stable Diffusion 3",
"flux": "FLUX",
"omnigen": "OmniGen",
"pixart_sigma": "PixArt Sigma",
"kolors": "Kwai Kolors",
"terminus": "Terminus",
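For context, the configure.py hunks above only register the new "omnigen" family in a handful of lookup tables. A minimal sketch (not part of this commit) of how those tables might be consumed follows; default_models and model_labels mirror names visible in the diff, while default_cfg is an assumed name for the numeric table whose identifier is not shown in this hunk.

default_models = {
    "omnigen": "Shitao/OmniGen-v1",
    "sd3": "stabilityai/stable-diffusion-3.5-large",
}
model_labels = {
    "omnigen": "OmniGen",
    "sd3": "Stable Diffusion 3",
}
# Assumed name and purpose: the diff shows per-family numbers (omnigen: 3.0, sd3: 6.0)
# that look like default guidance values, but the dict's identifier is not visible here.
default_cfg = {
    "omnigen": 3.0,
    "sd3": 6.0,
}

def describe_family(family: str) -> str:
    """Summarise one supported model family from the lookup tables."""
    if family not in default_models:
        raise ValueError(f"unknown model family: {family}")
    return (
        f"{model_labels[family]}: {default_models[family]} "
        f"(default guidance {default_cfg[family]})"
    )

print(describe_family("omnigen"))  # OmniGen: Shitao/OmniGen-v1 (default guidance 3.0)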
220 changes: 194 additions & 26 deletions helpers/caching/text_embeds.py
@@ -208,21 +208,14 @@ def discover_all_files(self):

def save_to_cache(self, filename, embeddings):
"""Add write requests to the queue instead of writing directly."""
if not self.batch_write_thread.is_alive():
logger.debug("Restarting background write thread.")
# Start the thread again.
self.process_write_batches = True
self.batch_write_thread = Thread(target=self.batch_write_embeddings)
self.batch_write_thread.start()
self.process_write_batches = True
self.write_queue.put((embeddings, filename))
logger.debug(
f"save_to_cache called for {filename}, write queue has {self.write_queue.qsize()} items, and the write thread's status: {self.batch_write_thread.is_alive()}"
)

def batch_write_embeddings(self):
"""Process write requests in batches."""
batch = []
written_elements = 0
while True:
try:
# Block until an item is available or timeout occurs
@@ -233,25 +226,14 @@ def batch_write_embeddings(self):
while (
not self.write_queue.empty() and len(batch) < self.write_batch_size
):
logger.debug("Retrieving more items from the queue.")
items = self.write_queue.get_nowait()
batch.append(items)
logger.debug(f"Batch now contains {len(batch)} items.")

self.process_write_batch(batch)
self.write_thread_bar.update(len(batch))
logger.debug("Processed batch write.")
written_elements += len(batch)

except queue.Empty:
# Timeout occurred, no items were ready
if not self.process_write_batches:
if len(batch) > 0:
self.process_write_batch(batch)
self.write_thread_bar.update(len(batch))
logger.debug(f"Exiting batch write thread, no more work to do after writing {written_elements} elements")
break
logger.debug(f"Queue is empty. Retrieving new entries. Should retrieve? {self.process_write_batches}")
pass
except Exception:
logger.exception("An error occurred while writing embeddings to disk.")
@@ -260,7 +242,6 @@ def process_write_batch(self, batch):
def process_write_batch(self, batch):
"""Write a batch of embeddings to the cache."""
logger.debug(f"Writing {len(batch)} items to disk")
logger.debug(f"Batch: {batch}")
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
futures = [
executor.submit(self.data_backend.torch_save, *args) for args in batch
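The hunks above mostly trim debug logging from the background writer, but the surviving control flow is easy to lose in the diff: a producer thread enqueues (embedding, filename) tuples, and batch_write_embeddings drains the queue in chunks of write_batch_size, handing each chunk to a ThreadPoolExecutor for the actual disk writes. A condensed, self-contained sketch of that pattern is below (standard library plus torch only; names such as WRITE_BATCH_SIZE and _write_one are illustrative, not SimpleTuner's).

import queue
import threading
import time
from concurrent.futures import ThreadPoolExecutor

import torch

write_queue: queue.Queue = queue.Queue()
keep_running = True
WRITE_BATCH_SIZE = 8  # illustrative; SimpleTuner reads this from its configuration

def _write_one(embedding, filename):
    # Stand-in for self.data_backend.torch_save(embedding, filename).
    torch.save(embedding, filename)

def batch_writer():
    while keep_running or not write_queue.empty():
        batch = []
        try:
            # Block briefly so the thread can notice a shutdown request.
            batch.append(write_queue.get(timeout=1.0))
        except queue.Empty:
            continue
        while len(batch) < WRITE_BATCH_SIZE:
            try:
                batch.append(write_queue.get_nowait())
            except queue.Empty:
                break
        with ThreadPoolExecutor(max_workers=4) as executor:
            for embedding, filename in batch:
                executor.submit(_write_one, embedding, filename)

writer = threading.Thread(target=batch_writer, daemon=True)
writer.start()
write_queue.put((torch.randn(77, 768), "/tmp/example_embed.pt"))
time.sleep(1.5)
keep_running = False
writer.join()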
@@ -288,8 +269,8 @@ def encode_flux_prompt(
text_encoders: List of text encoders.
tokenizers: List of tokenizers.
prompt: The prompt to encode.
num_images_per_prompt: The number of images to generate per prompt.
is_validation: Whether the prompt is for validation. No-op for Flux.
zero_padding_tokens: Whether to zero out padding tokens.
Returns:
Tuple of (prompt_embeds, pooled_prompt_embeds).
Expand Down Expand Up @@ -320,6 +301,31 @@ def encode_flux_prompt(

return prompt_embeds, pooled_prompt_embeds, time_ids, masks

def encode_omnigen_prompt(
self, text_encoders, tokenizers, prompt: str, is_validation: bool = False
):
"""
Encode a prompt for an OmniGen model.
Args:
text_encoders: List of text encoders.
tokenizers: List of tokenizers.
prompt: The prompt to encode.
is_validation: Whether the prompt is for validation. No-op for OmniGen.
Returns:
Dict of OmniGen inputs
"""
# it's not a text encoder, it's the MLLM preprocessor / tokeniser.
processed = text_encoders[0](
instructions=prompt,
# use_img_cfg=False,
# separate_cfg_input=False,
# use_input_image_size_as_output=False,
)

return processed

# Adapted from pipelines.StableDiffusion3Pipeline.encode_prompt
def encode_sd3_prompt(
self,
@@ -525,9 +531,7 @@ def encode_prompt(self, prompt: str, is_validation: bool = False):
prompt,
is_validation,
zero_padding_tokens=(
True
if StateTracker.get_args().t5_padding == "zero"
else False
True if StateTracker.get_args().t5_padding == "zero" else False
),
)
else:
@@ -666,6 +670,12 @@ def compute_embeddings_for_prompts(
return_concat=return_concat,
load_from_cache=load_from_cache,
)
elif self.model_type == "omnigen":
output = self.compute_embeddings_for_omnigen_prompts(
raw_prompts,
return_concat=return_concat,
load_from_cache=load_from_cache,
)
else:
raise ValueError(
f"No such text encoding backend for model type '{self.model_type}'"
@@ -1039,6 +1049,164 @@ def compute_embeddings_for_legacy_prompts(
return prompt_embeds_all, attention_masks_all
return prompt_embeds_all

def compute_embeddings_for_omnigen_prompts(
self,
prompts: list = None,
return_concat: bool = True,
is_validation: bool = False,
load_from_cache: bool = True,
):
# print(f"Computing embeddings for Omnigen prompts")
# processed = self.text_encoders[0](
# prompts,
# )

# # processed looks like:
# # {"input_ids": all_input_ids, "pixel_values": input_images, "image_sizes": img_inx}
# print(f"Processed: {processed.keys()}")

# return processed
processed_all = []
should_encode = not load_from_cache
args = StateTracker.get_args()
if should_encode:
local_caption_split = self.split_captions_between_processes(
prompts or self.prompts
)
else:
local_caption_split = prompts or self.prompts
if (
hasattr(args, "cache_clear_validation_prompts")
and args.cache_clear_validation_prompts
and is_validation
):
# If --cache_clear_validation_prompts was provided, we will forcibly overwrite them.
load_from_cache = False
should_encode = True

if self.webhook_handler is not None:
last_reported_index = 0
self.send_progress_update(
type="init_cache_text_embeds_started",
progress=int(0 // len(local_caption_split)),
total=len(local_caption_split),
current=0,
)
self.write_thread_bar = tqdm(
desc="Write embeds to disk",
leave=False,
ncols=125,
disable=return_concat,
total=len(local_caption_split),
position=get_rank(),
)
with torch.no_grad():
last_reported_index = 0
for prompt in tqdm(
local_caption_split,
desc="Processing prompts",
disable=return_concat,
miniters=50,
leave=False,
ncols=125,
position=get_rank() + self.accelerator.num_processes + 1,
):
filename = os.path.join(self.cache_dir, self.hash_prompt(prompt))
debug_msg = f"Processing file: {filename}, prompt: {prompt}"
prompt = PromptHandler.filter_caption(self.data_backend, prompt)
debug_msg = f"{debug_msg}\n -> filtered prompt: {prompt}"
if prompt is None:
logger.error(f"Filename {filename} does not have a caption.")
continue
logger.debug(debug_msg)
if return_concat and load_from_cache:
try:
# We attempt to load.
_processed = self.load_from_cache(filename)
logger.debug(f"Cached OmniGen inputs: {_processed}")
except Exception as e:
# We failed to load. Now encode the prompt.
logger.error(
f"Failed retrieving prompt from cache:"
f"\n-> prompt: {prompt}"
f"\n-> filename: {filename}"
f"\n-> error: {e}"
f"\n-> id: {self.id}, data_backend id: {self.data_backend.id}"
)
should_encode = True
raise Exception(
"Cache retrieval for text embed file failed. Ensure your dataloader config value for skip_file_discovery does not contain 'text', and that preserve_data_backend_cache is disabled or unset."
)
if should_encode:
# If load_from_cache is True, should_encode would be False unless we failed to load.
self.debug_log(f"Encoding prompt: {prompt}")
_processed = self.encode_omnigen_prompt(
self.text_encoders, self.tokenizers, [prompt], is_validation
)
logger.debug(f"OmniGen prompt embeds: {_processed}")
current_size = self.write_queue.qsize()
if current_size >= 2048:
log_msg = str(
f"[WARNING] Write queue size is {current_size}. This is quite large."
" Consider increasing the write batch size. Delaying encode so that writes can catch up."
)
self.write_thread_bar.write(log_msg)
while self.write_queue.qsize() > 100:
time.sleep(0.1)

self.debug_log(f"Adding embed to write queue: {filename}")
self.save_to_cache(filename, _processed)
if (
self.webhook_handler is not None
and int(
self.write_thread_bar.n % self.webhook_progress_interval
)
< 10
):
last_reported_index = int(
self.write_thread_bar.n % self.webhook_progress_interval
)
self.send_progress_update(
type="init_cache_text_embeds_status_update",
progress=int(
self.write_thread_bar.n
// len(local_caption_split)
* 100
),
total=len(local_caption_split),
current=0,
)

if not return_concat:
del _processed
continue

if return_concat:
processed_all.append(_processed)

while self.write_queue.qsize() > 0:
time.sleep(0.1) # Sleep briefly to avoid busy-waiting

if self.webhook_handler is not None:
self.send_progress_update(
type="init_cache_text_embeds_status_complete",
progress=100,
total=len(local_caption_split),
current=len(local_caption_split),
)

# Close the tqdm progress bar after the loop
self.write_thread_bar.close()
self.process_write_batches = False

if not return_concat:
del processed_all
return

logger.debug(f"Returning all prompt embeds: {processed_all}")

return processed_all

def compute_embeddings_for_flux_prompts(
self,
prompts: list = None,
@@ -1320,7 +1488,7 @@ def compute_embeddings_for_sd3_prompts(
)
if should_encode:
# If load_from_cache is True, should_encode would be False unless we failed to load.
self.debug_log(f"Encoding filename {filename} :: device {self.text_encoders[0].device} :: prompt {prompt}")
self.debug_log(f"Encoding prompt: {prompt}")
prompt_embeds, pooled_prompt_embeds = self.encode_sd3_prompt(
self.text_encoders,
self.tokenizers,
@@ -1333,7 +1501,7 @@
),
)
logger.debug(
f"Filename {filename} SD3 prompt embeds: {prompt_embeds.shape}, {pooled_prompt_embeds.shape}"
f"SD3 prompt embeds: {prompt_embeds.shape}, {pooled_prompt_embeds.shape}"
)
add_text_embeds = pooled_prompt_embeds
# StabilityAI say not to zero them out.
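Stripped of the progress bars, webhook updates, and queue throttling, the new compute_embeddings_for_omnigen_prompts boils down to a cache-or-encode loop: hash the prompt to a filename, load the cached entry if present, otherwise run OmniGen's preprocessor (held in self.text_encoders[0], which the in-code comment notes is really the MLLM preprocessor / tokeniser) and queue the result for writing. Unlike the other families, the cached artifact is the processor's tokenised input dict rather than prompt embeddings. A minimal sketch of that flow, with simplified hashing and a synchronous save in place of SimpleTuner's write queue:

import hashlib
import os

import torch

def hash_prompt(prompt: str) -> str:
    # Simplified stand-in for the cache's hash_prompt helper.
    return hashlib.sha256(prompt.encode("utf-8")).hexdigest() + ".pt"

def get_or_encode(prompt: str, cache_dir: str, processor) -> dict:
    """Return cached OmniGen inputs for `prompt`, encoding and caching on a miss."""
    path = os.path.join(cache_dir, hash_prompt(prompt))
    if os.path.exists(path):
        return torch.load(path)
    # Mirrors encode_omnigen_prompt: the "text encoder" slot holds the
    # MLLM preprocessor, called with a list of instruction strings.
    processed = processor(instructions=[prompt])
    torch.save(processed, path)
    return processed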
14 changes: 11 additions & 3 deletions helpers/configuration/cmd_args.py
@@ -85,7 +85,16 @@ def get_argument_parser():
)
parser.add_argument(
"--model_family",
choices=["pixart_sigma", "kolors", "sd3", "flux", "smoldit", "sdxl", "legacy"],
choices=[
"omnigen",
"pixart_sigma",
"kolors",
"sd3",
"flux",
"smoldit",
"sdxl",
"legacy",
],
default=None,
required=True,
help=("The model family to train. This option is required."),
@@ -2079,7 +2088,7 @@ def parse_cmdline_args(input_args=None):

if (
args.pretrained_vae_model_name_or_path is not None
and args.model_family in ["legacy", "flux", "sd3"]
and args.model_family in ["legacy", "flux", "sd3", "omnigen"]
and "sdxl" in args.pretrained_vae_model_name_or_path
and "deepfloyd" not in args.model_type
):
@@ -2109,7 +2118,6 @@
info_log(
f"SD3 embeds for unconditional captions: t5={args.sd3_t5_uncond_behaviour}, clip={args.sd3_clip_uncond_behaviour}"
)

elif "deepfloyd" in args.model_type:
deepfloyd_pixel_alignment = 8
if args.aspect_bucket_alignment != deepfloyd_pixel_alignment:
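The cmd_args.py change simply adds "omnigen" to the accepted --model_family values (and to the SDXL-VAE sanity check). A standalone argparse sketch reproducing that option, for exercising the new value outside SimpleTuner's full parser:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--model_family",
    choices=[
        "omnigen",
        "pixart_sigma",
        "kolors",
        "sd3",
        "flux",
        "smoldit",
        "sdxl",
        "legacy",
    ],
    required=True,
    help="The model family to train. This option is required.",
)

args = parser.parse_args(["--model_family", "omnigen"])
print(args.model_family)  # -> omnigen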
