diff --git a/configure.py b/configure.py
index f14e13b9..83aa5564 100644
--- a/configure.py
+++ b/configure.py
@@ -18,12 +18,13 @@
"full": [
"flux",
"sdxl",
+ "omnigen",
"pixart_sigma",
"kolors",
"sd3",
"legacy",
],
- "lora": ["flux", "sdxl", "kolors", "sd3", "legacy"],
+ "lora": ["flux", "sdxl", "kolors", "sd3", "legacy", "omnigen"],
"controlnet": ["sdxl", "legacy"],
}
@@ -34,6 +35,7 @@
"kolors": "kwai-kolors/kolors-diffusers",
"terminus": "ptx0/terminus-xl-velocity-v2",
"sd3": "stabilityai/stable-diffusion-3.5-large",
+ "omnigen": "Shitao/OmniGen-v1",
"legacy": "stabilityai/stable-diffusion-2-1-base",
}
@@ -43,12 +45,14 @@
"pixart_sigma": 3.4,
"kolors": 5.0,
"terminus": 8.0,
- "sd3": 5.0,
+ "omnigen": 3.0,
+ "sd3": 6.0,
}
model_labels = {
"sd3": "Stable Diffusion 3",
"flux": "FLUX",
+ "omnigen": "OmniGen",
"pixart_sigma": "PixArt Sigma",
"kolors": "Kwai Kolors",
"terminus": "Terminus",
diff --git a/helpers/caching/text_embeds.py b/helpers/caching/text_embeds.py
index 3ee330db..78ef578d 100644
--- a/helpers/caching/text_embeds.py
+++ b/helpers/caching/text_embeds.py
@@ -208,12 +208,7 @@ def discover_all_files(self):
def save_to_cache(self, filename, embeddings):
"""Add write requests to the queue instead of writing directly."""
- if not self.batch_write_thread.is_alive():
- logger.debug("Restarting background write thread.")
- # Start the thread again.
- self.process_write_batches = True
- self.batch_write_thread = Thread(target=self.batch_write_embeddings)
- self.batch_write_thread.start()
+ self.process_write_batches = True
self.write_queue.put((embeddings, filename))
logger.debug(
f"save_to_cache called for {filename}, write queue has {self.write_queue.qsize()} items, and the write thread's status: {self.batch_write_thread.is_alive()}"
@@ -221,8 +216,6 @@ def save_to_cache(self, filename, embeddings):
def batch_write_embeddings(self):
"""Process write requests in batches."""
- batch = []
- written_elements = 0
while True:
try:
# Block until an item is available or timeout occurs
@@ -233,25 +226,14 @@ def batch_write_embeddings(self):
while (
not self.write_queue.empty() and len(batch) < self.write_batch_size
):
- logger.debug("Retrieving more items from the queue.")
items = self.write_queue.get_nowait()
batch.append(items)
- logger.debug(f"Batch now contains {len(batch)} items.")
self.process_write_batch(batch)
self.write_thread_bar.update(len(batch))
- logger.debug("Processed batch write.")
- written_elements += len(batch)
except queue.Empty:
# Timeout occurred, no items were ready
- if not self.process_write_batches:
- if len(batch) > 0:
- self.process_write_batch(batch)
- self.write_thread_bar.update(len(batch))
- logger.debug(f"Exiting batch write thread, no more work to do after writing {written_elements} elements")
- break
- logger.debug(f"Queue is empty. Retrieving new entries. Should retrieve? {self.process_write_batches}")
pass
except Exception:
logger.exception("An error occurred while writing embeddings to disk.")
@@ -260,7 +242,6 @@ def batch_write_embeddings(self):
def process_write_batch(self, batch):
"""Write a batch of embeddings to the cache."""
logger.debug(f"Writing {len(batch)} items to disk")
- logger.debug(f"Batch: {batch}")
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
futures = [
executor.submit(self.data_backend.torch_save, *args) for args in batch
@@ -288,8 +269,8 @@ def encode_flux_prompt(
text_encoders: List of text encoders.
tokenizers: List of tokenizers.
prompt: The prompt to encode.
- num_images_per_prompt: The number of images to generate per prompt.
is_validation: Whether the prompt is for validation. No-op for SD3.
+ zero_padding_tokens: Whether to zero out padding tokens.
Returns:
Tuple of (prompt_embeds, pooled_prompt_embeds).
@@ -320,6 +301,31 @@ def encode_flux_prompt(
return prompt_embeds, pooled_prompt_embeds, time_ids, masks
+ def encode_omnigen_prompt(
+ self, text_encoders, tokenizers, prompt: str, is_validation: bool = False
+ ):
+ """
+ Encode a prompt for an OmniGen model.
+
+ Args:
+ text_encoders: List of text encoders.
+ tokenizers: List of tokenizers.
+ prompt: The prompt to encode.
+ is_validation: Whether the prompt is for validation. No-op for OmniGen.
+
+ Returns:
+ Dict of OmniGen inputs
+ """
+ # it's not a text encoder, it's the MLLM preprocessor / tokeniser.
+ processed = text_encoders[0](
+ instructions=prompt,
+ # use_img_cfg=False,
+ # separate_cfg_input=False,
+ # use_input_image_size_as_output=False,
+ )
+
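+        # `processed` is the collated dict produced by OmniGenTrainingProcessor's
+        # collator, i.e. {"input_ids": ..., "attention_mask": ..., "pixel_values": ...}.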
+ return processed
+
# Adapted from pipelines.StableDiffusion3Pipeline.encode_prompt
def encode_sd3_prompt(
self,
@@ -525,9 +531,7 @@ def encode_prompt(self, prompt: str, is_validation: bool = False):
prompt,
is_validation,
zero_padding_tokens=(
- True
- if StateTracker.get_args().t5_padding == "zero"
- else False
+ True if StateTracker.get_args().t5_padding == "zero" else False
),
)
else:
@@ -666,6 +670,12 @@ def compute_embeddings_for_prompts(
return_concat=return_concat,
load_from_cache=load_from_cache,
)
+ elif self.model_type == "omnigen":
+ output = self.compute_embeddings_for_omnigen_prompts(
+ raw_prompts,
+ return_concat=return_concat,
+ load_from_cache=load_from_cache,
+ )
else:
raise ValueError(
f"No such text encoding backend for model type '{self.model_type}'"
@@ -1039,6 +1049,164 @@ def compute_embeddings_for_legacy_prompts(
return prompt_embeds_all, attention_masks_all
return prompt_embeds_all
+ def compute_embeddings_for_omnigen_prompts(
+ self,
+ prompts: list = None,
+ return_concat: bool = True,
+ is_validation: bool = False,
+ load_from_cache: bool = True,
+ ):
+ processed_all = []
+ should_encode = not load_from_cache
+ args = StateTracker.get_args()
+ if should_encode:
+ local_caption_split = self.split_captions_between_processes(
+ prompts or self.prompts
+ )
+ else:
+ local_caption_split = prompts or self.prompts
+ if (
+ hasattr(args, "cache_clear_validation_prompts")
+ and args.cache_clear_validation_prompts
+ and is_validation
+ ):
+ # If --cache_clear_validation_prompts was provided, we will forcibly overwrite them.
+ load_from_cache = False
+ should_encode = True
+
+ if self.webhook_handler is not None:
+ last_reported_index = 0
+ self.send_progress_update(
+ type="init_cache_text_embeds_started",
+                progress=0,
+ total=len(local_caption_split),
+ current=0,
+ )
+ self.write_thread_bar = tqdm(
+ desc="Write embeds to disk",
+ leave=False,
+ ncols=125,
+ disable=return_concat,
+ total=len(local_caption_split),
+ position=get_rank(),
+ )
+ with torch.no_grad():
+ last_reported_index = 0
+ for prompt in tqdm(
+ local_caption_split,
+ desc="Processing prompts",
+ disable=return_concat,
+ miniters=50,
+ leave=False,
+ ncols=125,
+ position=get_rank() + self.accelerator.num_processes + 1,
+ ):
+ filename = os.path.join(self.cache_dir, self.hash_prompt(prompt))
+ debug_msg = f"Processing file: {filename}, prompt: {prompt}"
+ prompt = PromptHandler.filter_caption(self.data_backend, prompt)
+ debug_msg = f"{debug_msg}\n -> filtered prompt: {prompt}"
+ if prompt is None:
+ logger.error(f"Filename {filename} does not have a caption.")
+ continue
+ logger.debug(debug_msg)
+ if return_concat and load_from_cache:
+ try:
+ # We attempt to load.
+ _processed = self.load_from_cache(filename)
+ logger.debug(f"Cached OmniGen inputs: {_processed}")
+ except Exception as e:
+ # We failed to load. Now encode the prompt.
+ logger.error(
+ f"Failed retrieving prompt from cache:"
+ f"\n-> prompt: {prompt}"
+ f"\n-> filename: {filename}"
+ f"\n-> error: {e}"
+ f"\n-> id: {self.id}, data_backend id: {self.data_backend.id}"
+ )
+ should_encode = True
+ raise Exception(
+ "Cache retrieval for text embed file failed. Ensure your dataloader config value for skip_file_discovery does not contain 'text', and that preserve_data_backend_cache is disabled or unset."
+ )
+ if should_encode:
+ # If load_from_cache is True, should_encode would be False unless we failed to load.
+ self.debug_log(f"Encoding prompt: {prompt}")
+ _processed = self.encode_omnigen_prompt(
+ self.text_encoders, self.tokenizers, [prompt], is_validation
+ )
+ logger.debug(f"OmniGen prompt embeds: {_processed}")
+ current_size = self.write_queue.qsize()
+ if current_size >= 2048:
+ log_msg = str(
+ f"[WARNING] Write queue size is {current_size}. This is quite large."
+ " Consider increasing the write batch size. Delaying encode so that writes can catch up."
+ )
+ self.write_thread_bar.write(log_msg)
+ while self.write_queue.qsize() > 100:
+ time.sleep(0.1)
+
+ self.debug_log(f"Adding embed to write queue: {filename}")
+ self.save_to_cache(filename, _processed)
+ if (
+ self.webhook_handler is not None
+ and int(
+ self.write_thread_bar.n % self.webhook_progress_interval
+ )
+ < 10
+ ):
+ last_reported_index = int(
+ self.write_thread_bar.n % self.webhook_progress_interval
+ )
+ self.send_progress_update(
+ type="init_cache_text_embeds_status_update",
+                            progress=int(
+                                self.write_thread_bar.n
+                                / len(local_caption_split)
+                                * 100
+                            ),
+ total=len(local_caption_split),
+ current=0,
+ )
+
+ if not return_concat:
+ del _processed
+ continue
+
+ if return_concat:
+ processed_all.append(_processed)
+
+ while self.write_queue.qsize() > 0:
+ time.sleep(0.1) # Sleep briefly to avoid busy-waiting
+
+ if self.webhook_handler is not None:
+ self.send_progress_update(
+ type="init_cache_text_embeds_status_complete",
+ progress=100,
+ total=len(local_caption_split),
+ current=len(local_caption_split),
+ )
+
+ # Close the tqdm progress bar after the loop
+ self.write_thread_bar.close()
+ self.process_write_batches = False
+
+ if not return_concat:
+ del processed_all
+ return
+
+ logger.debug(f"Returning all prompt embeds: {processed_all}")
+
+ return processed_all
+
def compute_embeddings_for_flux_prompts(
self,
prompts: list = None,
@@ -1320,7 +1488,7 @@ def compute_embeddings_for_sd3_prompts(
)
if should_encode:
# If load_from_cache is True, should_encode would be False unless we failed to load.
- self.debug_log(f"Encoding filename {filename} :: device {self.text_encoders[0].device} :: prompt {prompt}")
+ self.debug_log(f"Encoding prompt: {prompt}")
prompt_embeds, pooled_prompt_embeds = self.encode_sd3_prompt(
self.text_encoders,
self.tokenizers,
@@ -1333,7 +1501,7 @@ def compute_embeddings_for_sd3_prompts(
),
)
logger.debug(
- f"Filename {filename} SD3 prompt embeds: {prompt_embeds.shape}, {pooled_prompt_embeds.shape}"
+ f"SD3 prompt embeds: {prompt_embeds.shape}, {pooled_prompt_embeds.shape}"
)
add_text_embeds = pooled_prompt_embeds
# StabilityAI say not to zero them out.
diff --git a/helpers/configuration/cmd_args.py b/helpers/configuration/cmd_args.py
index ee3b4f79..13f3fa5e 100644
--- a/helpers/configuration/cmd_args.py
+++ b/helpers/configuration/cmd_args.py
@@ -85,7 +85,16 @@ def get_argument_parser():
)
parser.add_argument(
"--model_family",
- choices=["pixart_sigma", "kolors", "sd3", "flux", "smoldit", "sdxl", "legacy"],
+ choices=[
+ "omnigen",
+ "pixart_sigma",
+ "kolors",
+ "sd3",
+ "flux",
+ "smoldit",
+ "sdxl",
+ "legacy",
+ ],
default=None,
required=True,
help=("The model family to train. This option is required."),
@@ -2079,7 +2088,7 @@ def parse_cmdline_args(input_args=None):
if (
args.pretrained_vae_model_name_or_path is not None
- and args.model_family in ["legacy", "flux", "sd3"]
+ and args.model_family in ["legacy", "flux", "sd3", "omnigen"]
and "sdxl" in args.pretrained_vae_model_name_or_path
and "deepfloyd" not in args.model_type
):
@@ -2109,7 +2118,6 @@ def parse_cmdline_args(input_args=None):
info_log(
f"SD3 embeds for unconditional captions: t5={args.sd3_t5_uncond_behaviour}, clip={args.sd3_clip_uncond_behaviour}"
)
-
elif "deepfloyd" in args.model_type:
deepfloyd_pixel_alignment = 8
if args.aspect_bucket_alignment != deepfloyd_pixel_alignment:
diff --git a/helpers/models/omnigen/pipeline.py b/helpers/models/omnigen/pipeline.py
new file mode 100644
index 00000000..88c61ec9
--- /dev/null
+++ b/helpers/models/omnigen/pipeline.py
@@ -0,0 +1,363 @@
+import os
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+import gc
+
+from PIL import Image
+import numpy as np
+import torch
+from huggingface_hub import snapshot_download
+from peft import LoraConfig, PeftModel
+from diffusers.models import AutoencoderKL
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from safetensors.torch import load_file
+
+from OmniGen import OmniGen, OmniGenProcessor, OmniGenScheduler
+
+
+logger = logging.get_logger(__name__)
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> from OmniGen import OmniGenPipeline
+    >>> pipe = OmniGenPipeline.from_pretrained(
+ ... base_model
+ ... )
+ >>> prompt = "A woman holds a bouquet of flowers and faces the camera"
+ >>> image = pipe(
+ ... prompt,
+ ... guidance_scale=2.5,
+ ... num_inference_steps=50,
+ ... ).images[0]
+ >>> image.save("t2i.png")
+ ```
+"""
+
+
+class OmniGenPipeline:
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ model: OmniGen,
+ processor: OmniGenProcessor,
+ device: Union[str, torch.device],
+ ):
+ self.vae = vae
+ self.model = model
+ self.processor = processor
+ self.device = device
+
+ self.model.to(torch.bfloat16)
+ self.model.eval()
+ self.vae.eval()
+
+ self.model_cpu_offload = False
+
+ @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path,
+        vae_path: str = None,
+        device: Union[str, torch.device] = None,
+        **kwargs,
+    ):
+        if device is None:
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ if not os.path.exists(pretrained_model_name_or_path) or (
+ not os.path.exists(
+ os.path.join(pretrained_model_name_or_path, "model.safetensors")
+ )
+ and pretrained_model_name_or_path == "Shitao/OmniGen-v1"
+ ):
+ logger.info("Model not found, downloading...")
+ cache_folder = os.getenv("HF_HUB_CACHE")
+ pretrained_model_name_or_path = snapshot_download(
+ repo_id=pretrained_model_name_or_path,
+ cache_dir=cache_folder,
+ ignore_patterns=[
+ "flax_model.msgpack",
+ "rust_model.ot",
+ "tf_model.h5",
+ "model.pt",
+ ],
+ )
+ logger.info(f"Downloaded model to {pretrained_model_name_or_path}")
+ model = OmniGen.from_pretrained(pretrained_model_name_or_path)
+ processor = OmniGenProcessor.from_pretrained(pretrained_model_name_or_path)
+
+ if os.path.exists(os.path.join(pretrained_model_name_or_path, "vae")):
+ vae = AutoencoderKL.from_pretrained(
+ os.path.join(pretrained_model_name_or_path, "vae")
+ )
+ elif vae_path is not None:
+ vae = AutoencoderKL.from_pretrained(vae_path).to(device)
+ else:
+ logger.info(
+ f"No VAE found in {pretrained_model_name_or_path}, downloading stabilityai/sdxl-vae from HF"
+ )
+ vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(device)
+
+        if kwargs:
+            logger.debug(
+                f"OmniGenPipeline received unexpected arguments: {kwargs.keys()}"
+            )
+
+        return cls(vae, model, processor, device)
+
+ def merge_lora(self, lora_path: str):
+ model = PeftModel.from_pretrained(self.model, lora_path)
+ model.merge_and_unload()
+
+ self.model = model
+
+ def to(self, device: Union[str, torch.device]):
+ if isinstance(device, str):
+ device = torch.device(device)
+ self.model.to(device)
+ self.vae.to(device)
+ self.device = device
+
+ def vae_encode(self, x, dtype):
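+        # SD3-style VAEs define a shift_factor; SDXL-style VAEs only scale.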
+ if self.vae.config.shift_factor is not None:
+ x = self.vae.encode(x).latent_dist.sample()
+ x = (x - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+ else:
+ x = (
+ self.vae.encode(x)
+ .latent_dist.sample()
+ .mul_(self.vae.config.scaling_factor)
+ )
+ x = x.to(dtype)
+ return x
+
+ def move_to_device(self, data):
+ if isinstance(data, list):
+ return [x.to(self.device) for x in data]
+ return data.to(self.device)
+
+ def enable_model_cpu_offload(self):
+ self.model_cpu_offload = True
+ self.model.to("cpu")
+ self.vae.to("cpu")
+ torch.cuda.empty_cache() # Clear VRAM
+ gc.collect() # Run garbage collection to free system RAM
+
+ def disable_model_cpu_offload(self):
+ self.model_cpu_offload = False
+ self.model.to(self.device)
+ self.vae.to(self.device)
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ input_images: Union[List[str], List[List[str]]] = None,
+ height: int = 1024,
+ width: int = 1024,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 3,
+ use_img_guidance: bool = True,
+ img_guidance_scale: float = 1.6,
+ max_input_image_size: int = 1024,
+ separate_cfg_infer: bool = True,
+ offload_model: bool = False,
+ use_kv_cache: bool = True,
+ offload_kv_cache: bool = True,
+ use_input_image_size_as_output: bool = False,
+ dtype: torch.dtype = torch.bfloat16,
+ seed: int = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ input_images (`List[str]` or `List[List[str]]`, *optional*):
+            The list of input images. Each "<|image_i|>" placeholder in the prompt is replaced with the i-th image in the list.
+ height (`int`, *optional*, defaults to 1024):
+ The height in pixels of the generated image. The number must be a multiple of 16.
+ width (`int`, *optional*, defaults to 1024):
+ The width in pixels of the generated image. The number must be a multiple of 16.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 3.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ use_img_guidance (`bool`, *optional*, defaults to True):
+                Defined as equation 3 in [InstructPix2Pix](https://arxiv.org/pdf/2211.09800).
+            img_guidance_scale (`float`, *optional*, defaults to 1.6):
+                Defined as equation 3 in [InstructPix2Pix](https://arxiv.org/pdf/2211.09800).
+ max_input_image_size (`int`, *optional*, defaults to 1024): the maximum size of input image, which will be used to crop the input image to the maximum size
+            separate_cfg_infer (`bool`, *optional*, defaults to True):
+ Perform inference on images with different guidance separately; this can save memory when generating images of large size at the expense of slower inference.
+ use_kv_cache (`bool`, *optional*, defaults to True): enable kv cache to speed up the inference
+            offload_kv_cache (`bool`, *optional*, defaults to True): offload the cached key and value to cpu, which can save memory but slow down the generation slightly
+ offload_model (`bool`, *optional*, defaults to False): offload the model to cpu, which can save memory but slow down the generation
+ use_input_image_size_as_output (bool, defaults to False): whether to use the input image size as the output image size, which can be used for single-image input, e.g., image editing task
+ seed (`int`, *optional*):
+ A random seed for generating output.
+ dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
+ data type for the model
+ Examples:
+
+ Returns:
+ A list with the generated images.
+ """
+ # check inputs:
+ if use_input_image_size_as_output:
+ assert (
+ isinstance(prompt, str) and len(input_images) == 1
+        ), "if you want the output image to have the same size as the input image, please provide only one input image"
+ else:
+ assert (
+ height % 16 == 0 and width % 16 == 0
+ ), "The height and width must be a multiple of 16."
+ if input_images is None:
+ use_img_guidance = False
+ if isinstance(prompt, str):
+ prompt = [prompt]
+ input_images = [input_images] if input_images is not None else None
+
+ # set model and processor
+ if max_input_image_size != self.processor.max_image_size:
+ self.processor = OmniGenProcessor(
+ self.processor.text_tokenizer, max_image_size=max_input_image_size
+ )
+ if offload_model:
+ self.enable_model_cpu_offload()
+ else:
+ self.disable_model_cpu_offload()
+
+ input_data = self.processor(
+ prompt,
+ input_images,
+ height=height,
+ width=width,
+ use_img_cfg=use_img_guidance,
+ separate_cfg_input=separate_cfg_infer,
+ use_input_image_size_as_output=use_input_image_size_as_output,
+ )
+        logger.debug(f"Input shapes: {input_data['attention_mask'][0].shape}")
+
+ num_prompt = len(prompt)
+ num_cfg = 2 if use_img_guidance else 1
+ if use_input_image_size_as_output:
+ if separate_cfg_infer:
+ height, width = input_data["input_pixel_values"][0][0].shape[-2:]
+ else:
+ height, width = input_data["input_pixel_values"][0].shape[-2:]
+ latent_size_h, latent_size_w = height // 8, width // 8
+
+ if seed is not None:
+ generator = torch.Generator(device=self.device).manual_seed(seed)
+ else:
+ generator = None
+ latents = torch.randn(
+ num_prompt,
+ 4,
+ latent_size_h,
+ latent_size_w,
+ device=self.device,
+ generator=generator,
+ )
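+        # one latent copy per guidance branch (conditional + text CFG, plus
+        # image CFG when input images are supplied)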
+ latents = torch.cat([latents] * (1 + num_cfg), 0).to(dtype)
+
+ if input_images is not None and self.model_cpu_offload:
+ self.vae.to(self.device)
+ input_img_latents = []
+ if separate_cfg_infer:
+ for temp_pixel_values in input_data["input_pixel_values"]:
+ temp_input_latents = []
+ for img in temp_pixel_values:
+ img = self.vae_encode(img.to(self.device), dtype)
+ temp_input_latents.append(img)
+ input_img_latents.append(temp_input_latents)
+ else:
+ for img in input_data["input_pixel_values"]:
+ img = self.vae_encode(img.to(self.device), dtype)
+ input_img_latents.append(img)
+ if input_images is not None and self.model_cpu_offload:
+ self.vae.to("cpu")
+ torch.cuda.empty_cache() # Clear VRAM
+ gc.collect() # Run garbage collection to free system RAM
+
+ model_kwargs = dict(
+ input_ids=self.move_to_device(input_data["input_ids"]),
+ input_img_latents=input_img_latents,
+ input_image_sizes=input_data["input_image_sizes"],
+ attention_mask=self.move_to_device(input_data["attention_mask"]),
+ position_ids=self.move_to_device(input_data["position_ids"]),
+ cfg_scale=guidance_scale,
+ img_cfg_scale=img_guidance_scale,
+ use_img_cfg=use_img_guidance,
+ use_kv_cache=use_kv_cache,
+ offload_model=offload_model,
+ )
+
+ if separate_cfg_infer:
+ func = self.model.forward_with_separate_cfg
+ else:
+ func = self.model.forward_with_cfg
+ self.model.to(dtype)
+
+ if self.model_cpu_offload:
+ for name, param in self.model.named_parameters():
+ if "layers" in name and "layers.0" not in name:
+ param.data = param.data.cpu()
+ else:
+ param.data = param.data.to(self.device)
+ for buffer_name, buffer in self.model.named_buffers():
+ setattr(self.model, buffer_name, buffer.to(self.device))
+ # else:
+ # self.model.to(self.device)
+
+ scheduler = OmniGenScheduler(num_steps=num_inference_steps)
+ samples = scheduler(
+ latents,
+ func,
+ model_kwargs,
+ use_kv_cache=use_kv_cache,
+ offload_kv_cache=offload_kv_cache,
+ )
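+        # keep only the fully-conditioned chunk of the guidance stack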
+ samples = samples.chunk((1 + num_cfg), dim=0)[0]
+
+ if self.model_cpu_offload:
+ self.model.to("cpu")
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ self.vae.to(self.device)
+ samples = samples.to(torch.float32)
+ if self.vae.config.shift_factor is not None:
+ samples = (
+ samples / self.vae.config.scaling_factor + self.vae.config.shift_factor
+ )
+ else:
+ samples = samples / self.vae.config.scaling_factor
+ samples = self.vae.decode(
+ samples.to(dtype=self.vae.dtype, device=self.vae.device)
+ ).sample
+
+ if self.model_cpu_offload:
+ self.vae.to("cpu")
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ output_samples = (samples * 0.5 + 0.5).clamp(0, 1) * 255
+ output_samples = (
+ output_samples.permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
+ )
+ output_images = []
+ for i, sample in enumerate(output_samples):
+ output_images.append(Image.fromarray(sample))
+
+ torch.cuda.empty_cache() # Clear VRAM
+ gc.collect() # Run garbage collection to free system RAM
+ return output_images
diff --git a/helpers/models/omnigen/processor.py b/helpers/models/omnigen/processor.py
new file mode 100644
index 00000000..bc634d26
--- /dev/null
+++ b/helpers/models/omnigen/processor.py
@@ -0,0 +1,187 @@
+import os
+import re
+from typing import Dict, List
+import json
+
+import torch
+import numpy as np
+import random
+from PIL import Image
+from torchvision import transforms
+from transformers import AutoTokenizer
+from huggingface_hub import snapshot_download
+
+from OmniGen.utils import (
+ create_logger,
+ update_ema,
+ requires_grad,
+ center_crop_arr,
+ crop_arr,
+)
+
+
+class OmniGenCollator:
+ def __init__(self, pad_token_id=2):
+ self.pad_token_id = pad_token_id
+
+ def __call__(self, features):
+ input_ids = [f[0]["input_ids"] for f in features]
+ attention_masks = []
+ max_length = max(len(ids) for ids in input_ids)
+
+ # Pad input_ids and create attention masks
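+        # (left-padded: pad tokens are prepended so the real tokens sit at the
+        # end of the sequence, as a causal LM expects)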
+ padded_input_ids = []
+ for ids in input_ids:
+ pad_length = max_length - len(ids)
+ padded_ids = [self.pad_token_id] * pad_length + ids
+ attention_mask = [0] * pad_length + [1] * len(ids)
+ padded_input_ids.append(padded_ids)
+ attention_masks.append(attention_mask)
+
+ padded_input_ids = torch.tensor(padded_input_ids)
+ attention_masks = torch.tensor(attention_masks)
+
+ # Handle pixel values
+ pixel_values = [
+ f[0]["pixel_values"] for f in features if f[0]["pixel_values"] is not None
+ ]
+ if pixel_values:
+ pixel_values = [pv for sublist in pixel_values for pv in sublist]
+ pixel_values = torch.stack(pixel_values)
+ else:
+ pixel_values = None
+
+ return {
+ "input_ids": padded_input_ids,
+ "attention_mask": attention_masks,
+ "pixel_values": pixel_values,
+ # Include other necessary fields
+ }
+
+
+class OmniGenTrainingProcessor:
+ def __init__(self, text_tokenizer, max_image_size: int = 1024):
+ self.text_tokenizer = text_tokenizer
+ self.max_image_size = max_image_size
+
+ self.image_transform = transforms.Compose(
+ [
+ transforms.Lambda(
+ lambda pil_image: crop_arr(pil_image, max_image_size)
+ ),
+ transforms.ToTensor(),
+ transforms.Normalize(
+ mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True
+ ),
+ ]
+ )
+
+ self.collator = OmniGenCollator()
+
+ @classmethod
+ def from_pretrained(cls, model_name):
+ if not os.path.exists(model_name):
+ cache_folder = os.getenv("HF_HUB_CACHE")
+ model_name = snapshot_download(
+ repo_id=model_name, cache_dir=cache_folder, allow_patterns="*.json"
+ )
+ text_tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+ return cls(text_tokenizer)
+
+ def process_image(self, image):
+ image = Image.open(image).convert("RGB")
+ return self.image_transform(image)
+
+ def process_multi_modal_prompt(self, text, input_images):
+ text = self.add_prefix_instruction(text)
+ if input_images is None or len(input_images) == 0:
+ model_inputs = self.text_tokenizer(text)
+ return {
+ "input_ids": model_inputs.input_ids,
+ "pixel_values": None,
+ "image_sizes": None,
+ }
+
+ pattern = r"<\|image_\d+\|>"
+ prompt_chunks = [
+ self.text_tokenizer(chunk).input_ids for chunk in re.split(pattern, text)
+ ]
+
+ for i in range(1, len(prompt_chunks)):
+ if prompt_chunks[i][0] == 1:
+ prompt_chunks[i] = prompt_chunks[i][1:]
+
+ image_tags = re.findall(pattern, text)
+ image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
+
+ unique_image_ids = sorted(list(set(image_ids)))
+ assert unique_image_ids == list(
+ range(1, len(unique_image_ids) + 1)
+ ), f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}"
+ # total images must be the same as the number of image tags
+ assert len(unique_image_ids) == len(
+ input_images
+ ), f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images"
+
+ input_images = [input_images[x - 1] for x in image_ids]
+
+ all_input_ids = []
+ img_inx = []
+ idx = 0
+ for i in range(len(prompt_chunks)):
+ all_input_ids.extend(prompt_chunks[i])
+ if i != len(prompt_chunks) - 1:
+ start_inx = len(all_input_ids)
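+                # each input image contributes (H/16)*(W/16) placeholder tokens,
+                # i.e. one token per 16x16 pixel patch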
+ size = input_images[i].size(-2) * input_images[i].size(-1) // 16 // 16
+ img_inx.append([start_inx, start_inx + size])
+ all_input_ids.extend([0] * size)
+
+ return {
+ "input_ids": all_input_ids,
+ "pixel_values": input_images,
+ "image_sizes": img_inx,
+ }
+
+ def add_prefix_instruction(self, prompt):
+ user_prompt = "<|user|>\n"
+ generation_prompt = (
+ "Generate an image according to the following instructions\n"
+ )
+ assistant_prompt = "<|assistant|>\n<|diffusion|>"
+ prompt_suffix = "<|end|>\n"
+ prompt = (
+ f"{user_prompt}{generation_prompt}{prompt}{prompt_suffix}{assistant_prompt}"
+ )
+ return prompt
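+        # e.g. add_prefix_instruction("a cat") returns:
+        # "<|user|>\nGenerate an image according to the following instructions\na cat<|end|>\n<|assistant|>\n<|diffusion|>"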
+
+ def __call__(
+ self,
+ instructions: List[str],
+ input_images: List[List[str]] = None,
+ height: int = 1024,
+ width: int = 1024,
+ ) -> Dict:
+
+ if isinstance(instructions, str):
+ instructions = [instructions]
+ input_images = [input_images]
+
+ input_data = []
+ for i in range(len(instructions)):
+ cur_instruction = instructions[i]
+ cur_input_images = None if input_images is None else input_images[i]
+ if cur_input_images is not None and len(cur_input_images) > 0:
+ cur_input_images = [self.process_image(x) for x in cur_input_images]
+ else:
+ cur_input_images = None
+                assert "<|image_1|>" not in cur_instruction
+
+ mllm_input = self.process_multi_modal_prompt(
+ cur_instruction, cur_input_images
+ )
+
+ input_data.append((mllm_input, [height, width]))
+
+ return self.collator(input_data)
diff --git a/helpers/training/adapter.py b/helpers/training/adapter.py
index 04b99069..06676d06 100644
--- a/helpers/training/adapter.py
+++ b/helpers/training/adapter.py
@@ -9,6 +9,9 @@ def determine_adapter_target_modules(args, unet, transformer):
elif transformer is not None:
target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
+ if args.model_family.lower() == "omnigen":
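+        # OmniGen's LLM backbone exposes fused attention projections
+        # (qkv_proj/o_proj) rather than diffusers-style to_q/to_k/to_v layers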
+ target_modules = ["qkv_proj", "o_proj"]
+
if args.model_family.lower() == "flux" and args.flux_lora_target == "all":
# target_modules = mmdit layers here
target_modules = [
diff --git a/helpers/training/collate.py b/helpers/training/collate.py
index c1549e3a..f9c9fb40 100644
--- a/helpers/training/collate.py
+++ b/helpers/training/collate.py
@@ -1,5 +1,6 @@
import torch
import logging
+import random
import concurrent.futures
import numpy as np
from os import environ
@@ -213,7 +214,12 @@ def compute_latents(filepaths, data_backend_id: str):
def compute_single_embedding(
- caption, text_embed_cache, is_sdxl, is_sd3: bool = False, is_flux: bool = False
+ caption,
+ text_embed_cache,
+ is_sdxl,
+ is_sd3: bool = False,
+ is_flux: bool = False,
+ is_omnigen: bool = False,
):
"""Worker function to compute embedding for a single caption."""
if caption == "" or not caption:
@@ -246,6 +252,11 @@ def compute_single_embedding(
time_ids[0],
masks[0] if masks is not None else None,
)
+ elif is_omnigen:
+ processed = text_embed_cache.compute_embeddings_for_omnigen_prompts(
+ prompts=[caption]
+ )
+ return processed
else:
prompt_embeds = text_embed_cache.compute_embeddings_for_legacy_prompts(
[caption]
@@ -284,6 +295,7 @@ def compute_prompt_embeddings(captions, text_embed_cache):
is_pixart_sigma = text_embed_cache.model_type == "pixart_sigma"
is_smoldit = text_embed_cache.model_type == "smoldit"
is_flux = text_embed_cache.model_type == "flux"
+ is_omnigen = text_embed_cache.model_type == "omnigen"
# Use a thread pool to compute embeddings concurrently
with ThreadPoolExecutor() as executor:
@@ -295,6 +307,7 @@ def compute_prompt_embeddings(captions, text_embed_cache):
[is_sdxl] * len(captions),
[is_sd3] * len(captions),
[is_flux] * len(captions),
+ [is_omnigen] * len(captions),
)
)
@@ -331,6 +344,9 @@ def compute_prompt_embeddings(captions, text_embed_cache):
torch.stack(time_ids),
torch.stack(masks) if None not in masks else None,
)
+ elif is_omnigen:
+ embeddings = [e[0] for e in embeddings]
+ return embeddings
else:
# Separate the tuples
prompt_embeds = [t[0] for t in embeddings]
@@ -426,8 +442,17 @@ def collate_fn(batch):
"This trainer is not designed to handle multiple batches in a single collate."
)
debug_log("Begin collate_fn on batch")
-
- # SDXL Dropout
+ (
+ latent_batch,
+ prompt_embeds_all,
+ add_text_embeds_all,
+ input_ids,
+ batch_time_ids,
+ batch_luminance,
+ conditioning_pixel_values,
+ attn_mask,
+ conditioning_type,
+ ) = (None, None, None, None, None, None, None, None, None)
dropout_probability = StateTracker.get_args().caption_dropout_probability
batch = batch[0]
examples = batch["training_samples"]
@@ -528,11 +553,70 @@ def collate_fn(batch):
attn_mask = None
batch_time_ids = None
+ input_ids = None
+ extra_batch_inputs = {}
if StateTracker.get_model_family() == "flux":
debug_log("Compute and stack Flux time ids")
prompt_embeds_all, add_text_embeds_all, batch_time_ids, attn_mask = (
compute_prompt_embeddings(captions, text_embed_cache)
)
+ elif StateTracker.get_model_family() == "omnigen":
+        # instruction, input_images, output_image = example['instruction'], example['input_images'], example['output_image']
+ omnigen_processed_embeddings = compute_prompt_embeddings(
+ captions, text_embed_cache
+ )
+ from OmniGen.processor import OmniGenCollator
+
+ attn_mask = [e.get("attention_mask") for e in omnigen_processed_embeddings]
+ attn_mask_len = len(attn_mask[0][0])
+ attn_mask = torch.stack(attn_mask, dim=0)
+
+ # we can use the OmniGenCollator.create_position to make positional ids
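+        # the output image occupies (H/16)*(W/16) latent-patch tokens; subtracting
+        # the text length gives the remaining positions to allocate per sample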
+ num_tokens_for_output_images = []
+        for img_size in [
+            [
+                latent_batch.shape[3] * 8,
+                latent_batch.shape[2] * 8,
+            ]
+        ] * len(latent_batch):
+ num_img_tokens = img_size[0] * img_size[1] // 16 // 16
+ num_text_tokens = attn_mask_len
+ total_num_tokens = num_img_tokens - num_text_tokens
+ num_tokens_for_output_images.append(total_num_tokens)
+ position_ids = OmniGenCollator.create_position(
+ attn_mask, num_tokens_for_output_images
+ )
+ # pad attn_mask to match the position_ids, eg. mask [1, 1, 1, 57] -> [1, 1, 1, 4097]
+ attn_mask = torch.cat(
+ [
+ attn_mask,
+ torch.zeros(
+ (
+ attn_mask.shape[0],
+ attn_mask.shape[1],
+ num_tokens_for_output_images[0] + 1,
+ )
+ ),
+ ],
+ dim=-1,
+ )
+
+ # TODO: support "input images" for OmniGen which behave as conditioning images, eg. ControlNet Canny, Depth, etc.
+ # conditioning_pixel_values = torch.stack([e.get('input_pixel_values') for e in omnigen_processed_embeddings], dim=0)
+ # input_image_sizes = [e.get('input_image_size') for e in omnigen_processed_embeddings]
+ # extra_batch_inputs['conditioning_pixel_values'] = conditioning_pixel_values
+ # extra_batch_inputs['input_image_sizes'] = input_image_sizes
+ # input_ids = [e.get('input_ids') for e in omnigen_processed_embeddings]
+ # input_ids = torch.stack(input_ids, dim=0)
+ # TODO: Support instruction/conditioning image dropout for OmniGen.
+ # if random.random() < StateTracker.get_args().caption_dropout_probability:
+ # instruction = ''
+ # latent_batch = None
+ padding_images = [e.get("padding_image") for e in omnigen_processed_embeddings]
+ extra_batch_inputs["position_ids"] = position_ids
+ extra_batch_inputs["padding_images"] = padding_images
+ extra_batch_inputs["input_ids"] = input_ids
else:
prompt_embeds_all, add_text_embeds_all = compute_prompt_embeddings(
captions, text_embed_cache
@@ -566,4 +650,5 @@ def collate_fn(batch):
"encoder_attention_mask": attn_mask,
"is_regularisation_data": is_regularisation_data,
"conditioning_type": conditioning_type,
+ **extra_batch_inputs,
}
diff --git a/helpers/training/diffusion_model.py b/helpers/training/diffusion_model.py
index 78611dc8..6ef9c86b 100644
--- a/helpers/training/diffusion_model.py
+++ b/helpers/training/diffusion_model.py
@@ -34,7 +34,24 @@ def load_diffusion_model(args, weight_dtype):
bnb_4bit_compute_dtype=weight_dtype,
)
- if args.model_family == "sd3":
+ if args.model_family == "omnigen":
+ try:
+ from OmniGen import OmniGen
+ except ImportError:
+ logger.error(
+                "Could not import OmniGen. Please install OmniGen to use this model."
+ )
+ raise
+ logger.info("Loading OmniGen model..")
+ transformer = OmniGen.from_pretrained(
+ args.pretrained_transformer_model_name_or_path
+ or args.pretrained_model_name_or_path
+ )
+ transformer.llm.config.use_cache = False
+        logger.info("Enabling gradient checkpointing..")
+ transformer.llm.gradient_checkpointing_enable()
+
+ elif args.model_family == "sd3":
# Stable Diffusion 3 uses a Diffusion transformer.
logger.info("Loading Stable Diffusion 3 diffusion transformer..")
try:
diff --git a/helpers/training/save_hooks.py b/helpers/training/save_hooks.py
index 51fb85de..38d66d73 100644
--- a/helpers/training/save_hooks.py
+++ b/helpers/training/save_hooks.py
@@ -164,6 +164,12 @@ def __init__(
elif args.model_family == "smoldit":
self.denoiser_class = SmolDiT2DModel
self.pipeline_class = SmolDiTPipeline
+ elif args.model_family == "omnigen":
+ from OmniGen import OmniGen
+ from helpers.models.omnigen.pipeline import OmniGenPipeline
+
+ self.denoiser_class = OmniGen
+ self.pipeline_class = OmniGenPipeline
self.denoiser_subdir = "transformer"
if args.controlnet:
diff --git a/helpers/training/schedulers.py b/helpers/training/schedulers.py
index b7636189..83e37ccf 100644
--- a/helpers/training/schedulers.py
+++ b/helpers/training/schedulers.py
@@ -22,6 +22,10 @@ def load_scheduler_from_args(args):
subfolder="scheduler",
shift=1 if args.model_family == "sd3" else 3,
)
+ elif args.model_family == "omnigen":
+ from OmniGen import OmniGenScheduler
+
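+        # OmniGen ships its own scheduler rather than a diffusers scheduler class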
+ noise_scheduler = OmniGenScheduler()
else:
if args.model_family == "legacy":
args.rescale_betas_zero_snr = True
diff --git a/helpers/training/state_tracker.py b/helpers/training/state_tracker.py
index d069a2fb..51c9fb57 100644
--- a/helpers/training/state_tracker.py
+++ b/helpers/training/state_tracker.py
@@ -117,6 +117,7 @@ def set_model_family(cls, model_type: str):
"legacy",
"sdxl",
"sd3",
+ "omnigen",
"pixart_sigma",
"kolors",
"smoldit",
diff --git a/helpers/training/text_encoding.py b/helpers/training/text_encoding.py
index 5261b6fb..9c082482 100644
--- a/helpers/training/text_encoding.py
+++ b/helpers/training/text_encoding.py
@@ -15,6 +15,19 @@ def import_model_class_from_model_name_or_path(
args,
subfolder: str = "text_encoder",
):
+ if args.model_family.lower() == "omnigen":
+ try:
+ from helpers.models.omnigen.processor import (
+ OmniGenTrainingProcessor as OmniGenProcessor,
+ )
+ except ImportError:
+ logger.error(
+                "Could not import OmniGen. Please install OmniGen to use this model."
+ )
+ raise
+
+ return OmniGenProcessor
+
if args.model_family.lower() == "smoldit":
from transformers import AutoModelForSeq2SeqLM
@@ -51,6 +64,9 @@ def import_model_class_from_model_name_or_path(
def get_tokenizers(args):
tokenizer_1, tokenizer_2, tokenizer_3 = None, None, None
try:
+ if args.model_family.lower() == "omnigen":
+ return None, None, None
+
if args.model_family.lower() == "smoldit":
from transformers import AutoTokenizer
@@ -225,6 +241,11 @@ def load_tes(
f"Loading ChatGLM language model from {text_encoder_path}/{text_encoder_subfolder}.."
)
text_encoder_variant = "fp16"
+ elif args.model_family.lower() == "omnigen":
+ logger.info(f"Loading OmniGen processor from {text_encoder_path}..")
+ text_encoder_1 = text_encoder_cls_1.from_pretrained(text_encoder_path)
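+        # the "text encoder" slot holds the OmniGen processor; there are no
+        # CLIP/T5 encoders to load or move between devices for this family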
+
+ return text_encoder_variant, text_encoder_1, text_encoder_2, text_encoder_3
else:
logger.info(
f"Loading CLIP text encoder from {text_encoder_path}/{text_encoder_subfolder}.."
diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py
index a5e70d44..92a844c4 100644
--- a/helpers/training/trainer.py
+++ b/helpers/training/trainer.py
@@ -505,12 +505,26 @@ def init_text_tokenizer(self):
def init_text_encoder(self, move_to_accelerator: bool = True):
self.init_text_tokenizer()
+ self.text_encoders = []
+ self.tokenizers = []
self.text_encoder_1, self.text_encoder_2, self.text_encoder_3 = None, None, None
self.text_encoder_cls_1, self.text_encoder_cls_2, self.text_encoder_cls_3 = (
None,
None,
None,
)
+ if self.config.model_family.lower() == "omnigen":
+ # omnigen uses a preprocessor w/ a simple tokeniser, not a text encoder.
+ from helpers.models.omnigen.processor import (
+ OmniGenTrainingProcessor as OmniGenProcessor,
+ )
+
+ self.text_encoder_1 = OmniGenProcessor.from_pretrained(
+ self.config.pretrained_transformer_model_name_or_path
+ or self.config.pretrained_model_name_or_path
+ )
+ self.text_encoders.append(self.text_encoder_1)
+ self.tokenizers.append(self.text_encoder_1.text_tokenizer)
if self.tokenizer_1 is not None:
self.text_encoder_cls_1 = import_model_class_from_model_name_or_path(
self.config.text_encoder_path,
@@ -555,8 +569,6 @@ def init_text_encoder(self, move_to_accelerator: bool = True):
if not move_to_accelerator:
logger.debug("Not moving text encoders to accelerator.")
return
- self.text_encoders = []
- self.tokenizers = []
if self.tokenizer_1 is not None:
logger.info("Moving text encoder to GPU.")
self.text_encoder_1.to(
@@ -579,6 +591,9 @@ def init_text_encoder(self, move_to_accelerator: bool = True):
self.tokenizers.append(self.tokenizer_3)
self.text_encoders.append(self.text_encoder_3)
+ if not any(self.text_encoders):
+ logger.warning("No text encoders loaded. This may cause issues.")
+
def init_freeze_models(self):
# Freeze vae and text_encoders
if self.vae is not None:
@@ -676,6 +691,10 @@ def init_data_backend(self):
self.accelerator.wait_for_everyone()
def init_validation_prompts(self):
+ self.validation_prompts = None
+ self.validation_shortnames = None
+ self.validation_negative_prompt_embeds = None
+ self.validation_negative_pooled_embeds = None
if self.accelerator.is_main_process:
if self.config.model_family == "flux":
(
@@ -688,6 +707,14 @@ def init_validation_prompts(self):
args=self.config,
embed_cache=StateTracker.get_default_text_embed_cache(),
)
+ elif self.config.model_family == "omnigen":
+ (
+ self.validation_prompts,
+ self.validation_shortnames,
+ ) = prepare_validation_prompt_list(
+ args=self.config,
+ embed_cache=StateTracker.get_default_text_embed_cache(),
+ )
else:
(
self.validation_prompts,
@@ -698,11 +725,6 @@ def init_validation_prompts(self):
args=self.config,
embed_cache=StateTracker.get_default_text_embed_cache(),
)
- else:
- self.validation_prompts = None
- self.validation_shortnames = None
- self.validation_negative_prompt_embeds = None
- self.validation_negative_pooled_embeds = None
self.accelerator.wait_for_everyone()
def stats_memory_used(self):
@@ -869,6 +891,9 @@ def init_trainable_peft_adapter(self):
target_modules=target_modules,
use_dora=self.config.use_dora,
)
+ if self.config.model_family == "omnigen":
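+            # needed so gradients reach the LoRA layers through the frozen
+            # embeddings when gradient checkpointing is active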
+ self.transformer.llm.enable_input_require_grads()
+
self.transformer.add_adapter(transformer_lora_config)
if self.config.init_lora:
addkeys, misskeys = load_lora_weights(
@@ -974,7 +999,11 @@ def init_post_load_freeze(self):
unwrap_model(
self.accelerator, self.unet
).enable_gradient_checkpointing()
- if self.transformer is not None and self.config.model_family != "smoldit":
+ if (
+ self.transformer is not None
+ and self.config.model_family != "smoldit"
+ and hasattr(self.transformer, "enable_gradient_checkpointing")
+ ):
unwrap_model(
self.accelerator, self.transformer
).enable_gradient_checkpointing()
@@ -1996,6 +2025,29 @@ def model_predict(
),
}
model_pred = self.transformer(**inputs).sample
+ elif self.config.model_family == "omnigen":
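+                # OmniGen's transformer takes the noisy latents as `x`, plus the
+                # processor outputs (ids, mask, positions) routed through the batch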
+ inputs = {
+ "x": noisy_latents,
+ "timestep": timesteps,
+ "input_ids": (
+ batch.get("input_ids").to(self.accelerator.device)
+ if batch.get("input_ids") is not None
+ else None
+ ),
+ "input_img_latents": (
+ batch.get("input_img_latents").to(self.accelerator.device)
+ if batch.get("input_img_latents") is not None
+ else None
+ ),
+ "input_image_sizes": batch.get("input_image_sizes"),
+ "attention_mask": batch.get("encoder_attention_mask").to(
+ self.accelerator.device
+ ),
+ "position_ids": batch.get("position_ids").to(
+ self.accelerator.device
+ ),
+ }
+ model_pred = self.transformer(**inputs)[0]
elif self.unet is not None:
if self.config.model_family == "legacy":
# SD 1.5 or 2.x
@@ -2182,7 +2234,26 @@ def train(self):
f"Received {bsz} latents, but expected {self.config.train_batch_size}. Processing short batch."
)
training_logger.debug(f"Working on batch size: {bsz}")
- if self.config.flow_matching:
+ if self.config.model_family == "omnigen":
+                        # x1 is the data sample (the clean latents)
+ x1 = latents
+
+ # Sample x0 from a standard normal distribution with the same shape as latents
+ x0 = torch.randn_like(latents)
+
+ # Sample t for each sample in the batch using the specified distribution
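+                        # (a logit-normal draw: the sigmoid of a standard normal)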
+ u = torch.randn(bsz, device=latents.device)
+ t = 1 / (1 + torch.exp(-u)) # t ∈ (0, 1)
+ t = t.to(latents.device, dtype=latents.dtype)
+
+ # Convert t to timesteps compatible with the model (scaled appropriately)
+ timesteps = t * 999
+ timesteps = timesteps.to(
+ self.accelerator.device, dtype=latents.dtype
+ )
+ timesteps = timesteps.long()
+
+ elif self.config.flow_matching:
if (
not self.config.flux_fast_schedule
and not self.config.flux_use_beta_schedule
@@ -2281,6 +2352,14 @@ def train(self):
if self.config.flow_matching:
noisy_latents = (1 - sigmas) * latents + sigmas * input_noise
+ elif self.config.model_family == "omnigen":
+ # Reshape t to match the dimensions of latents for broadcasting
+ dims = [1] * (latents.dim() - 1)
+ t_reshaped = t.view(-1, *dims)
+
+ # Compute noisy_latents (xt) using the Omnigen sampling formula
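+                            # xt = t*x1 + (1-t)*x0: pure noise at t=0, data at t=1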
+ noisy_latents = t_reshaped * x1 + (1 - t_reshaped) * x0
+
else:
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
@@ -2291,19 +2370,25 @@ def train(self):
dtype=self.config.weight_dtype,
)
- encoder_hidden_states = batch["prompt_embeds"].to(
- dtype=self.config.weight_dtype, device=self.accelerator.device
- )
- training_logger.debug(
- f"Encoder hidden states: {encoder_hidden_states.shape}"
- )
+ encoder_hidden_states = None
+ if hasattr(batch["prompt_embeds"], "to"):
+ encoder_hidden_states = batch["prompt_embeds"].to(
+ dtype=self.config.weight_dtype,
+ device=self.accelerator.device,
+ )
+ training_logger.debug(
+ f"Encoder hidden states: {encoder_hidden_states.shape}"
+ )
add_text_embeds = batch["add_text_embeds"]
training_logger.debug(
f"Pooled embeds: {add_text_embeds.shape if add_text_embeds is not None else None}"
)
# Get the target for loss depending on the prediction type
- if self.config.flow_matching:
+ if (
+ self.config.flow_matching
+ or self.config.model_family == "omnigen"
+ ):
# This is the flow-matching target for vanilla SD3.
# If self.config.flow_matching_loss == "diffusion", we will instead use v_prediction (see below)
if self.config.flow_matching_loss == "diffusers":
@@ -2311,9 +2396,10 @@ def train(self):
elif self.config.flow_matching_loss == "compatible":
target = noise - latents
elif self.config.flow_matching_loss == "sd35":
- sigma_reshaped = sigmas.view(-1, 1, 1, 1) # Ensure sigma has the correct shape
+ sigma_reshaped = sigmas.view(
+ -1, 1, 1, 1
+ ) # Ensure sigma has the correct shape
target = (noisy_latents - latents) / sigma_reshaped
-
elif self.noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif (
@@ -2420,7 +2506,10 @@ def train(self):
parent_loss = None
# Compute the per-pixel loss without reducing over spatial dimensions
- if self.config.flow_matching:
+ if (
+ self.config.flow_matching
+ or self.config.model_family == "omnigen"
+ ):
# For flow matching, compute the per-pixel squared differences
loss = (
model_pred.float() - target.float()
diff --git a/helpers/training/validation.py b/helpers/training/validation.py
index b9406a77..65c17817 100644
--- a/helpers/training/validation.py
+++ b/helpers/training/validation.py
@@ -280,6 +280,8 @@ def prepare_validation_prompt_list(args, embed_cache):
validation_negative_pooled_embeds,
validation_negative_time_ids,
)
+ elif model_type == "omnigen":
+ return (validation_prompts, validation_shortnames)
else:
raise ValueError(f"Unknown model type '{model_type}'")
@@ -590,6 +592,10 @@ def _pipeline_cls(self):
from helpers.models.smoldit import SmolDiTPipeline
return SmolDiTPipeline
+ elif model_type == "omnigen":
+ from helpers.models.omnigen.pipeline import OmniGenPipeline
+
+ return OmniGenPipeline
else:
raise NotImplementedError(
f"Model type {model_type} not implemented for validation."
@@ -598,6 +604,9 @@ def _pipeline_cls(self):
def _gather_prompt_embeds(self, validation_prompt: str):
prompt_embeds = {}
current_validation_prompt_mask = None
+ current_validation_prompt_embeds = None
+ current_validation_pooled_embeds = None
+ current_validation_time_ids = None
if (
StateTracker.get_model_family() == "sdxl"
or StateTracker.get_model_family() == "sd3"
@@ -680,25 +689,33 @@ def _gather_prompt_embeds(self, validation_prompt: str):
# logger.debug(
# f"Dtypes: {current_validation_prompt_embeds.dtype}, {self.validation_negative_prompt_embeds.dtype}"
# )
+ elif StateTracker.get_model_family() == "omnigen":
+ # no special treatment needed here.
+ pass
else:
raise NotImplementedError(
f"Model type {StateTracker.get_model_family()} not implemented for validation."
)
- current_validation_prompt_embeds = current_validation_prompt_embeds.to(
- device=self.inference_device, dtype=self.weight_dtype
- )
- self.validation_negative_prompt_embeds = (
- self.validation_negative_prompt_embeds.to(
+ if current_validation_prompt_embeds is not None:
+ current_validation_prompt_embeds = current_validation_prompt_embeds.to(
device=self.inference_device, dtype=self.weight_dtype
)
- )
+ if self.validation_negative_prompt_embeds is not None:
+ self.validation_negative_prompt_embeds = (
+ self.validation_negative_prompt_embeds.to(
+ device=self.inference_device, dtype=self.weight_dtype
+ )
+ )
# when sampling unconditional guidance, you should only zero one or the other prompt, and not both.
# we'll assume that the user has a negative prompt, so that the unconditional sampling works.
# the positive prompt embed is zeroed out for SDXL at the time of it being placed into the cache.
# the embeds are not zeroed out for any other model, including Stable Diffusion 3.
- prompt_embeds["prompt_embeds"] = current_validation_prompt_embeds
- prompt_embeds["negative_prompt_embeds"] = self.validation_negative_prompt_embeds
+ if current_validation_prompt_embeds is not None:
+ prompt_embeds["prompt_embeds"] = current_validation_prompt_embeds
+ prompt_embeds["negative_prompt_embeds"] = (
+ self.validation_negative_prompt_embeds
+ )
if (
StateTracker.get_model_family() == "pixart_sigma"
or StateTracker.get_model_family() == "smoldit"
@@ -1063,6 +1080,19 @@ def setup_pipeline(self, validation_type, enable_ema_model: bool = True):
text_encoder=self.text_encoder_1,
scheduler=self.setup_scheduler(),
)
+ elif self.args.model_family == "omnigen":
+ # we use upstream processor to get negative prompting.
+ from OmniGen import OmniGenProcessor
+
+ self.pipeline = pipeline_cls(
+ vae=self.vae,
+ processor=OmniGenProcessor.from_pretrained(
+ self.args.pretrained_transformer_model_name_or_path
+ or self.args.pretrained_model_name_or_path
+ ),
+ model=self.transformer,
+ device=self.accelerator.device,
+ )
else:
self.pipeline = pipeline_cls.from_pretrained(**pipeline_kwargs)
except Exception as e:
@@ -1261,7 +1291,7 @@ def validate_prompt(
for key, value in pipeline_kwargs.items():
if hasattr(value, "device"):
logger.debug(f"Device for {key}: {value.device}")
- for key, value in self.pipeline.components.items():
+ for key, value in getattr(self.pipeline, "components", {}).items():
if hasattr(value, "device"):
logger.debug(f"Device for {key}: {value.device}")
if StateTracker.get_model_family() == "flux":
@@ -1281,8 +1311,26 @@ def validate_prompt(
pipeline_kwargs["negative_prompt_attention_mask"] = torch.unsqueeze(
pipeline_kwargs.pop("negative_mask")[0], dim=0
).to(device=self.inference_device, dtype=self.weight_dtype)
+ if StateTracker.get_model_family() == "omnigen":
+ pipeline_kwargs["prompt"] = prompt
+ del pipeline_kwargs["negative_prompt"]
+ del pipeline_kwargs["num_images_per_prompt"]
+ del pipeline_kwargs["generator"]
+ del pipeline_kwargs["guidance_rescale"]
+ if "image" in pipeline_kwargs:
+ pipeline_kwargs["input_image"] = pipeline_kwargs.pop("image")
+ pipeline_kwargs["seed"] = self.args.validation_seed
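+            # assumption: OmniGen's kv-cache paths rely on CUDA-style memory
+            # management, so they are disabled when running on Apple MPS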
+            pipeline_kwargs["use_kv_cache"] = not torch.backends.mps.is_available()
+            pipeline_kwargs["offload_kv_cache"] = not torch.backends.mps.is_available()
+ logger.debug(f"OmniGen pipeline kwargs: {pipeline_kwargs}")
- validation_image_results = self.pipeline(**pipeline_kwargs).images
+ validation_image_results = self.pipeline(**pipeline_kwargs)
+ if hasattr(validation_image_results, "images"):
+ validation_image_results = validation_image_results.images
if self.args.controlnet:
validation_image_results = self.stitch_conditioning_images(
validation_image_results, extra_validation_kwargs["image"]
diff --git a/install/apple/poetry.lock b/install/apple/poetry.lock
index dde26187..fb299253 100644
--- a/install/apple/poetry.lock
+++ b/install/apple/poetry.lock
@@ -1559,50 +1559,46 @@ files = [
[[package]]
name = "nvidia-cublas-cu12"
-version = "12.4.5.8"
+version = "12.1.3.1"
description = "CUBLAS native runtime libraries"
optional = false
python-versions = ">=3"
files = [
- {file = "nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0f8aa1706812e00b9f19dfe0cdb3999b092ccb8ca168c0db5b8ea712456fd9b3"},
- {file = "nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:2fc8da60df463fdefa81e323eef2e36489e1c94335b5358bcb38360adf75ac9b"},
- {file = "nvidia_cublas_cu12-12.4.5.8-py3-none-win_amd64.whl", hash = "sha256:5a796786da89203a0657eda402bcdcec6180254a8ac22d72213abc42069522dc"},
+ {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728"},
+ {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-win_amd64.whl", hash = "sha256:2b964d60e8cf11b5e1073d179d85fa340c120e99b3067558f3cf98dd69d02906"},
]
[[package]]
name = "nvidia-cuda-cupti-cu12"
-version = "12.4.127"
+version = "12.1.105"
description = "CUDA profiling tools runtime libs."
optional = false
python-versions = ">=3"
files = [
- {file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:79279b35cf6f91da114182a5ce1864997fd52294a87a16179ce275773799458a"},
- {file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9dec60f5ac126f7bb551c055072b69d85392b13311fcc1bcda2202d172df30fb"},
- {file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:5688d203301ab051449a2b1cb6690fbe90d2b372f411521c86018b950f3d7922"},
+ {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e"},
+ {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:bea8236d13a0ac7190bd2919c3e8e6ce1e402104276e6f9694479e48bb0eb2a4"},
]
[[package]]
name = "nvidia-cuda-nvrtc-cu12"
-version = "12.4.127"
+version = "12.1.105"
description = "NVRTC native runtime libraries"
optional = false
python-versions = ">=3"
files = [
- {file = "nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0eedf14185e04b76aa05b1fea04133e59f465b6f960c0cbf4e37c3cb6b0ea198"},
- {file = "nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a178759ebb095827bd30ef56598ec182b85547f1508941a3d560eb7ea1fbf338"},
- {file = "nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:a961b2f1d5f17b14867c619ceb99ef6fcec12e46612711bcec78eb05068a60ec"},
+ {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2"},
+ {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:0a98a522d9ff138b96c010a65e145dc1b4850e9ecb75a0172371793752fd46ed"},
]
[[package]]
name = "nvidia-cuda-runtime-cu12"
-version = "12.4.127"
+version = "12.1.105"
description = "CUDA Runtime native Libraries"
optional = false
python-versions = ">=3"
files = [
- {file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:961fe0e2e716a2a1d967aab7caee97512f71767f852f67432d572e36cb3a11f3"},
- {file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64403288fa2136ee8e467cdc9c9427e0434110899d07c779f25b5c068934faa5"},
- {file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:09c2e35f48359752dfa822c09918211844a3d93c100a715d79b59591130c5e1e"},
+ {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40"},
+ {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:dfb46ef84d73fababab44cf03e3b83f80700d27ca300e537f85f636fac474344"},
]

[[package]]
@@ -1621,41 +1617,35 @@ nvidia-cublas-cu12 = "*"

[[package]]
name = "nvidia-cufft-cu12"
-version = "11.2.1.3"
+version = "11.0.2.54"
description = "CUFFT native runtime libraries"
optional = false
python-versions = ">=3"
files = [
- {file = "nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_aarch64.whl", hash = "sha256:5dad8008fc7f92f5ddfa2101430917ce2ffacd86824914c82e28990ad7f00399"},
- {file = "nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f083fc24912aa410be21fa16d157fed2055dab1cc4b6934a0e03cba69eb242b9"},
- {file = "nvidia_cufft_cu12-11.2.1.3-py3-none-win_amd64.whl", hash = "sha256:d802f4954291101186078ccbe22fc285a902136f974d369540fd4a5333d1440b"},
+ {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56"},
+ {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-win_amd64.whl", hash = "sha256:d9ac353f78ff89951da4af698f80870b1534ed69993f10a4cf1d96f21357e253"},
]

-[package.dependencies]
-nvidia-nvjitlink-cu12 = "*"
-
[[package]]
name = "nvidia-curand-cu12"
-version = "10.3.5.147"
+version = "10.3.2.106"
description = "CURAND native runtime libraries"
optional = false
python-versions = ">=3"
files = [
- {file = "nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1f173f09e3e3c76ab084aba0de819c49e56614feae5c12f69883f4ae9bb5fad9"},
- {file = "nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a88f583d4e0bb643c49743469964103aa59f7f708d862c3ddb0fc07f851e3b8b"},
- {file = "nvidia_curand_cu12-10.3.5.147-py3-none-win_amd64.whl", hash = "sha256:f307cc191f96efe9e8f05a87096abc20d08845a841889ef78cb06924437f6771"},
+ {file = "nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0"},
+ {file = "nvidia_curand_cu12-10.3.2.106-py3-none-win_amd64.whl", hash = "sha256:75b6b0c574c0037839121317e17fd01f8a69fd2ef8e25853d826fec30bdba74a"},
]

[[package]]
name = "nvidia-cusolver-cu12"
-version = "11.6.1.9"
+version = "11.4.5.107"
description = "CUDA solver native runtime libraries"
optional = false
python-versions = ">=3"
files = [
- {file = "nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_aarch64.whl", hash = "sha256:d338f155f174f90724bbde3758b7ac375a70ce8e706d70b018dd3375545fc84e"},
- {file = "nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:19e33fa442bcfd085b3086c4ebf7e8debc07cfe01e11513cc6d332fd918ac260"},
- {file = "nvidia_cusolver_cu12-11.6.1.9-py3-none-win_amd64.whl", hash = "sha256:e77314c9d7b694fcebc84f58989f3aa4fb4cb442f12ca1a9bde50f5e8f6d1b9c"},
+ {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd"},
+ {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-win_amd64.whl", hash = "sha256:74e0c3a24c78612192a74fcd90dd117f1cf21dea4822e66d89e8ea80e3cd2da5"},
]

[package.dependencies]
@@ -1665,14 +1655,13 @@ nvidia-nvjitlink-cu12 = "*"

[[package]]
name = "nvidia-cusparse-cu12"
-version = "12.3.1.170"
+version = "12.1.0.106"
description = "CUSPARSE native runtime libraries"
optional = false
python-versions = ">=3"
files = [
- {file = "nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_aarch64.whl", hash = "sha256:9d32f62896231ebe0480efd8a7f702e143c98cfaa0e8a76df3386c1ba2b54df3"},
- {file = "nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1"},
- {file = "nvidia_cusparse_cu12-12.3.1.170-py3-none-win_amd64.whl", hash = "sha256:9bc90fb087bc7b4c15641521f31c0371e9a612fc2ba12c338d3ae032e6b6797f"},
+ {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c"},
+ {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-win_amd64.whl", hash = "sha256:b798237e81b9719373e8fae8d4f091b70a0cf09d9d85c95a557e11df2d8e9a5a"},
]

[package.dependencies]
@@ -1680,12 +1669,13 @@ nvidia-nvjitlink-cu12 = "*"

[[package]]
name = "nvidia-nccl-cu12"
-version = "2.21.5"
+version = "2.20.5"
description = "NVIDIA Collective Communication Library (NCCL) Runtime"
optional = false
python-versions = ">=3"
files = [
- {file = "nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:8579076d30a8c24988834445f8d633c697d42397e92ffc3f63fa26766d25e0a0"},
+ {file = "nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1fc150d5c3250b170b29410ba682384b14581db722b2531b0d8d33c595f33d01"},
+ {file = "nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:057f6bf9685f75215d0c53bf3ac4a10b3e6578351de307abad9e18a99182af56"},
]

[[package]]
@@ -1702,16 +1692,41 @@ files = [

[[package]]
name = "nvidia-nvtx-cu12"
-version = "12.4.127"
+version = "12.1.105"
description = "NVIDIA Tools Extension"
optional = false
python-versions = ">=3"
files = [
- {file = "nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7959ad635db13edf4fc65c06a6e9f9e55fc2f92596db928d169c0bb031e88ef3"},
- {file = "nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:781e950d9b9f60d8241ccea575b32f5105a5baf4c2351cab5256a24869f12a1a"},
- {file = "nvidia_nvtx_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:641dccaaa1139f3ffb0d3164b4b84f9d253397e38246a4f2f36728b48566d485"},
+ {file = "nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5"},
+ {file = "nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82"},
]

+[[package]]
+name = "OmniGen"
+version = "1.0.3"
+description = "OmniGen"
+optional = false
+python-versions = "*"
+files = []
+develop = false
+
+[package.dependencies]
+accelerate = ">=0.26.1"
+datasets = "*"
+diffusers = ">=0.30.3"
+peft = ">=0.9.0"
+safetensors = "*"
+setuptools = "*"
+timm = "*"
+torch = "<2.5"
+transformers = ">=4.45.2"
+
+[package.source]
+type = "git"
+url = "https://github.com/bghira/omnigen"
+reference = "dependency-update/peft"
+resolved_reference = "93c149a1a5f6526a98bca5e6cff764e0a4790782"
+
[[package]]
name = "open-clip-torch"
version = "2.26.1"
@@ -3105,111 +3120,111 @@ torchvision = "*"

[[package]]
name = "tokenizers"
-version = "0.19.1"
+version = "0.20.1"
description = ""
optional = false
python-versions = ">=3.7"
files = [
- {file = "tokenizers-0.19.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:952078130b3d101e05ecfc7fc3640282d74ed26bcf691400f872563fca15ac97"},
- {file = "tokenizers-0.19.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:82c8b8063de6c0468f08e82c4e198763e7b97aabfe573fd4cf7b33930ca4df77"},
- {file = "tokenizers-0.19.1-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:f03727225feaf340ceeb7e00604825addef622d551cbd46b7b775ac834c1e1c4"},
- {file = "tokenizers-0.19.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:453e4422efdfc9c6b6bf2eae00d5e323f263fff62b29a8c9cd526c5003f3f642"},
- {file = "tokenizers-0.19.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:02e81bf089ebf0e7f4df34fa0207519f07e66d8491d963618252f2e0729e0b46"},
- {file = "tokenizers-0.19.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b07c538ba956843833fee1190cf769c60dc62e1cf934ed50d77d5502194d63b1"},
- {file = "tokenizers-0.19.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e28cab1582e0eec38b1f38c1c1fb2e56bce5dc180acb1724574fc5f47da2a4fe"},
- {file = "tokenizers-0.19.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b01afb7193d47439f091cd8f070a1ced347ad0f9144952a30a41836902fe09e"},
- {file = "tokenizers-0.19.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:7fb297edec6c6841ab2e4e8f357209519188e4a59b557ea4fafcf4691d1b4c98"},
- {file = "tokenizers-0.19.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:2e8a3dd055e515df7054378dc9d6fa8c8c34e1f32777fb9a01fea81496b3f9d3"},
- {file = "tokenizers-0.19.1-cp310-none-win32.whl", hash = "sha256:7ff898780a155ea053f5d934925f3902be2ed1f4d916461e1a93019cc7250837"},
- {file = "tokenizers-0.19.1-cp310-none-win_amd64.whl", hash = "sha256:bea6f9947e9419c2fda21ae6c32871e3d398cba549b93f4a65a2d369662d9403"},
- {file = "tokenizers-0.19.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:5c88d1481f1882c2e53e6bb06491e474e420d9ac7bdff172610c4f9ad3898059"},
- {file = "tokenizers-0.19.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ddf672ed719b4ed82b51499100f5417d7d9f6fb05a65e232249268f35de5ed14"},
- {file = "tokenizers-0.19.1-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:dadc509cc8a9fe460bd274c0e16ac4184d0958117cf026e0ea8b32b438171594"},
- {file = "tokenizers-0.19.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dfedf31824ca4915b511b03441784ff640378191918264268e6923da48104acc"},
- {file = "tokenizers-0.19.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ac11016d0a04aa6487b1513a3a36e7bee7eec0e5d30057c9c0408067345c48d2"},
- {file = "tokenizers-0.19.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:76951121890fea8330d3a0df9a954b3f2a37e3ec20e5b0530e9a0044ca2e11fe"},
- {file = "tokenizers-0.19.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b342d2ce8fc8d00f376af068e3274e2e8649562e3bc6ae4a67784ded6b99428d"},
- {file = "tokenizers-0.19.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d16ff18907f4909dca9b076b9c2d899114dd6abceeb074eca0c93e2353f943aa"},
- {file = "tokenizers-0.19.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:706a37cc5332f85f26efbe2bdc9ef8a9b372b77e4645331a405073e4b3a8c1c6"},
- {file = "tokenizers-0.19.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:16baac68651701364b0289979ecec728546133e8e8fe38f66fe48ad07996b88b"},
- {file = "tokenizers-0.19.1-cp311-none-win32.whl", hash = "sha256:9ed240c56b4403e22b9584ee37d87b8bfa14865134e3e1c3fb4b2c42fafd3256"},
- {file = "tokenizers-0.19.1-cp311-none-win_amd64.whl", hash = "sha256:ad57d59341710b94a7d9dbea13f5c1e7d76fd8d9bcd944a7a6ab0b0da6e0cc66"},
- {file = "tokenizers-0.19.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:621d670e1b1c281a1c9698ed89451395d318802ff88d1fc1accff0867a06f153"},
- {file = "tokenizers-0.19.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d924204a3dbe50b75630bd16f821ebda6a5f729928df30f582fb5aade90c818a"},
- {file = "tokenizers-0.19.1-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:4f3fefdc0446b1a1e6d81cd4c07088ac015665d2e812f6dbba4a06267d1a2c95"},
- {file = "tokenizers-0.19.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9620b78e0b2d52ef07b0d428323fb34e8ea1219c5eac98c2596311f20f1f9266"},
- {file = "tokenizers-0.19.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:04ce49e82d100594715ac1b2ce87d1a36e61891a91de774755f743babcd0dd52"},
- {file = "tokenizers-0.19.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c5c2ff13d157afe413bf7e25789879dd463e5a4abfb529a2d8f8473d8042e28f"},
- {file = "tokenizers-0.19.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3174c76efd9d08f836bfccaca7cfec3f4d1c0a4cf3acbc7236ad577cc423c840"},
- {file = "tokenizers-0.19.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7c9d5b6c0e7a1e979bec10ff960fae925e947aab95619a6fdb4c1d8ff3708ce3"},
- {file = "tokenizers-0.19.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a179856d1caee06577220ebcfa332af046d576fb73454b8f4d4b0ba8324423ea"},
- {file = "tokenizers-0.19.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:952b80dac1a6492170f8c2429bd11fcaa14377e097d12a1dbe0ef2fb2241e16c"},
- {file = "tokenizers-0.19.1-cp312-none-win32.whl", hash = "sha256:01d62812454c188306755c94755465505836fd616f75067abcae529c35edeb57"},
- {file = "tokenizers-0.19.1-cp312-none-win_amd64.whl", hash = "sha256:b70bfbe3a82d3e3fb2a5e9b22a39f8d1740c96c68b6ace0086b39074f08ab89a"},
- {file = "tokenizers-0.19.1-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:bb9dfe7dae85bc6119d705a76dc068c062b8b575abe3595e3c6276480e67e3f1"},
- {file = "tokenizers-0.19.1-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:1f0360cbea28ea99944ac089c00de7b2e3e1c58f479fb8613b6d8d511ce98267"},
- {file = "tokenizers-0.19.1-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:71e3ec71f0e78780851fef28c2a9babe20270404c921b756d7c532d280349214"},
- {file = "tokenizers-0.19.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b82931fa619dbad979c0ee8e54dd5278acc418209cc897e42fac041f5366d626"},
- {file = "tokenizers-0.19.1-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e8ff5b90eabdcdaa19af697885f70fe0b714ce16709cf43d4952f1f85299e73a"},
- {file = "tokenizers-0.19.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e742d76ad84acbdb1a8e4694f915fe59ff6edc381c97d6dfdd054954e3478ad4"},
- {file = "tokenizers-0.19.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d8c5d59d7b59885eab559d5bc082b2985555a54cda04dda4c65528d90ad252ad"},
- {file = "tokenizers-0.19.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b2da5c32ed869bebd990c9420df49813709e953674c0722ff471a116d97b22d"},
- {file = "tokenizers-0.19.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:638e43936cc8b2cbb9f9d8dde0fe5e7e30766a3318d2342999ae27f68fdc9bd6"},
- {file = "tokenizers-0.19.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:78e769eb3b2c79687d9cb0f89ef77223e8e279b75c0a968e637ca7043a84463f"},
- {file = "tokenizers-0.19.1-cp37-none-win32.whl", hash = "sha256:72791f9bb1ca78e3ae525d4782e85272c63faaef9940d92142aa3eb79f3407a3"},
- {file = "tokenizers-0.19.1-cp37-none-win_amd64.whl", hash = "sha256:f3bbb7a0c5fcb692950b041ae11067ac54826204318922da754f908d95619fbc"},
- {file = "tokenizers-0.19.1-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:07f9295349bbbcedae8cefdbcfa7f686aa420be8aca5d4f7d1ae6016c128c0c5"},
- {file = "tokenizers-0.19.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:10a707cc6c4b6b183ec5dbfc5c34f3064e18cf62b4a938cb41699e33a99e03c1"},
- {file = "tokenizers-0.19.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:6309271f57b397aa0aff0cbbe632ca9d70430839ca3178bf0f06f825924eca22"},
- {file = "tokenizers-0.19.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ad23d37d68cf00d54af184586d79b84075ada495e7c5c0f601f051b162112dc"},
- {file = "tokenizers-0.19.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:427c4f0f3df9109314d4f75b8d1f65d9477033e67ffaec4bca53293d3aca286d"},
- {file = "tokenizers-0.19.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e83a31c9cf181a0a3ef0abad2b5f6b43399faf5da7e696196ddd110d332519ee"},
- {file = "tokenizers-0.19.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c27b99889bd58b7e301468c0838c5ed75e60c66df0d4db80c08f43462f82e0d3"},
- {file = "tokenizers-0.19.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bac0b0eb952412b0b196ca7a40e7dce4ed6f6926489313414010f2e6b9ec2adf"},
- {file = "tokenizers-0.19.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:8a6298bde623725ca31c9035a04bf2ef63208d266acd2bed8c2cb7d2b7d53ce6"},
- {file = "tokenizers-0.19.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:08a44864e42fa6d7d76d7be4bec62c9982f6f6248b4aa42f7302aa01e0abfd26"},
- {file = "tokenizers-0.19.1-cp38-none-win32.whl", hash = "sha256:1de5bc8652252d9357a666e609cb1453d4f8e160eb1fb2830ee369dd658e8975"},
- {file = "tokenizers-0.19.1-cp38-none-win_amd64.whl", hash = "sha256:0bcce02bf1ad9882345b34d5bd25ed4949a480cf0e656bbd468f4d8986f7a3f1"},
- {file = "tokenizers-0.19.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:0b9394bd204842a2a1fd37fe29935353742be4a3460b6ccbaefa93f58a8df43d"},
- {file = "tokenizers-0.19.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4692ab92f91b87769d950ca14dbb61f8a9ef36a62f94bad6c82cc84a51f76f6a"},
- {file = "tokenizers-0.19.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:6258c2ef6f06259f70a682491c78561d492e885adeaf9f64f5389f78aa49a051"},
- {file = "tokenizers-0.19.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c85cf76561fbd01e0d9ea2d1cbe711a65400092bc52b5242b16cfd22e51f0c58"},
- {file = "tokenizers-0.19.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:670b802d4d82bbbb832ddb0d41df7015b3e549714c0e77f9bed3e74d42400fbe"},
- {file = "tokenizers-0.19.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:85aa3ab4b03d5e99fdd31660872249df5e855334b6c333e0bc13032ff4469c4a"},
- {file = "tokenizers-0.19.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cbf001afbbed111a79ca47d75941e9e5361297a87d186cbfc11ed45e30b5daba"},
- {file = "tokenizers-0.19.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b4c89aa46c269e4e70c4d4f9d6bc644fcc39bb409cb2a81227923404dd6f5227"},
- {file = "tokenizers-0.19.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:39c1ec76ea1027438fafe16ecb0fb84795e62e9d643444c1090179e63808c69d"},
- {file = "tokenizers-0.19.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c2a0d47a89b48d7daa241e004e71fb5a50533718897a4cd6235cb846d511a478"},
- {file = "tokenizers-0.19.1-cp39-none-win32.whl", hash = "sha256:61b7fe8886f2e104d4caf9218b157b106207e0f2a4905c9c7ac98890688aabeb"},
- {file = "tokenizers-0.19.1-cp39-none-win_amd64.whl", hash = "sha256:f97660f6c43efd3e0bfd3f2e3e5615bf215680bad6ee3d469df6454b8c6e8256"},
- {file = "tokenizers-0.19.1-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:3b11853f17b54c2fe47742c56d8a33bf49ce31caf531e87ac0d7d13d327c9334"},
- {file = "tokenizers-0.19.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d26194ef6c13302f446d39972aaa36a1dda6450bc8949f5eb4c27f51191375bd"},
- {file = "tokenizers-0.19.1-pp310-pypy310_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:e8d1ed93beda54bbd6131a2cb363a576eac746d5c26ba5b7556bc6f964425594"},
- {file = "tokenizers-0.19.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca407133536f19bdec44b3da117ef0d12e43f6d4b56ac4c765f37eca501c7bda"},
- {file = "tokenizers-0.19.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce05fde79d2bc2e46ac08aacbc142bead21614d937aac950be88dc79f9db9022"},
- {file = "tokenizers-0.19.1-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:35583cd46d16f07c054efd18b5d46af4a2f070a2dd0a47914e66f3ff5efb2b1e"},
- {file = "tokenizers-0.19.1-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:43350270bfc16b06ad3f6f07eab21f089adb835544417afda0f83256a8bf8b75"},
- {file = "tokenizers-0.19.1-pp37-pypy37_pp73-macosx_10_12_x86_64.whl", hash = "sha256:b4399b59d1af5645bcee2072a463318114c39b8547437a7c2d6a186a1b5a0e2d"},
- {file = "tokenizers-0.19.1-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:6852c5b2a853b8b0ddc5993cd4f33bfffdca4fcc5d52f89dd4b8eada99379285"},
- {file = "tokenizers-0.19.1-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bcd266ae85c3d39df2f7e7d0e07f6c41a55e9a3123bb11f854412952deacd828"},
- {file = "tokenizers-0.19.1-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ecb2651956eea2aa0a2d099434134b1b68f1c31f9a5084d6d53f08ed43d45ff2"},
- {file = "tokenizers-0.19.1-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:b279ab506ec4445166ac476fb4d3cc383accde1ea152998509a94d82547c8e2a"},
- {file = "tokenizers-0.19.1-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:89183e55fb86e61d848ff83753f64cded119f5d6e1f553d14ffee3700d0a4a49"},
- {file = "tokenizers-0.19.1-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:b2edbc75744235eea94d595a8b70fe279dd42f3296f76d5a86dde1d46e35f574"},
- {file = "tokenizers-0.19.1-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:0e64bfde9a723274e9a71630c3e9494ed7b4c0f76a1faacf7fe294cd26f7ae7c"},
- {file = "tokenizers-0.19.1-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:0b5ca92bfa717759c052e345770792d02d1f43b06f9e790ca0a1db62838816f3"},
- {file = "tokenizers-0.19.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6f8a20266e695ec9d7a946a019c1d5ca4eddb6613d4f466888eee04f16eedb85"},
- {file = "tokenizers-0.19.1-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63c38f45d8f2a2ec0f3a20073cccb335b9f99f73b3c69483cd52ebc75369d8a1"},
- {file = "tokenizers-0.19.1-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:dd26e3afe8a7b61422df3176e06664503d3f5973b94f45d5c45987e1cb711876"},
- {file = "tokenizers-0.19.1-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:eddd5783a4a6309ce23432353cdb36220e25cbb779bfa9122320666508b44b88"},
- {file = "tokenizers-0.19.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:56ae39d4036b753994476a1b935584071093b55c7a72e3b8288e68c313ca26e7"},
- {file = "tokenizers-0.19.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:f9939ca7e58c2758c01b40324a59c034ce0cebad18e0d4563a9b1beab3018243"},
- {file = "tokenizers-0.19.1-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:6c330c0eb815d212893c67a032e9dc1b38a803eccb32f3e8172c19cc69fbb439"},
- {file = "tokenizers-0.19.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ec11802450a2487cdf0e634b750a04cbdc1c4d066b97d94ce7dd2cb51ebb325b"},
- {file = "tokenizers-0.19.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2b718f316b596f36e1dae097a7d5b91fc5b85e90bf08b01ff139bd8953b25af"},
- {file = "tokenizers-0.19.1-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:ed69af290c2b65169f0ba9034d1dc39a5db9459b32f1dd8b5f3f32a3fcf06eab"},
- {file = "tokenizers-0.19.1-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f8a9c828277133af13f3859d1b6bf1c3cb6e9e1637df0e45312e6b7c2e622b1f"},
- {file = "tokenizers-0.19.1.tar.gz", hash = "sha256:ee59e6680ed0fdbe6b724cf38bd70400a0c1dd623b07ac729087270caeac88e3"},
+ {file = "tokenizers-0.20.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:439261da7c0a5c88bda97acb284d49fbdaf67e9d3b623c0bfd107512d22787a9"},
+ {file = "tokenizers-0.20.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:03dae629d99068b1ea5416d50de0fea13008f04129cc79af77a2a6392792d93c"},
+ {file = "tokenizers-0.20.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b61f561f329ffe4b28367798b89d60c4abf3f815d37413b6352bc6412a359867"},
+ {file = "tokenizers-0.20.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ec870fce1ee5248a10be69f7a8408a234d6f2109f8ea827b4f7ecdbf08c9fd15"},
+ {file = "tokenizers-0.20.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d388d1ea8b7447da784e32e3b86a75cce55887e3b22b31c19d0b186b1c677800"},
+ {file = "tokenizers-0.20.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:299c85c1d21135bc01542237979bf25c32efa0d66595dd0069ae259b97fb2dbe"},
+ {file = "tokenizers-0.20.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e96f6c14c9752bb82145636b614d5a78e9cde95edfbe0a85dad0dd5ddd6ec95c"},
+ {file = "tokenizers-0.20.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc9e95ad49c932b80abfbfeaf63b155761e695ad9f8a58c52a47d962d76e310f"},
+ {file = "tokenizers-0.20.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:f22dee205329a636148c325921c73cf3e412e87d31f4d9c3153b302a0200057b"},
+ {file = "tokenizers-0.20.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a2ffd9a8895575ac636d44500c66dffaef133823b6b25067604fa73bbc5ec09d"},
+ {file = "tokenizers-0.20.1-cp310-none-win32.whl", hash = "sha256:2847843c53f445e0f19ea842a4e48b89dd0db4e62ba6e1e47a2749d6ec11f50d"},
+ {file = "tokenizers-0.20.1-cp310-none-win_amd64.whl", hash = "sha256:f9aa93eacd865f2798b9e62f7ce4533cfff4f5fbd50c02926a78e81c74e432cd"},
+ {file = "tokenizers-0.20.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:4a717dcb08f2dabbf27ae4b6b20cbbb2ad7ed78ce05a829fae100ff4b3c7ff15"},
+ {file = "tokenizers-0.20.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3f84dad1ff1863c648d80628b1b55353d16303431283e4efbb6ab1af56a75832"},
+ {file = "tokenizers-0.20.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:929c8f3afa16a5130a81ab5079c589226273ec618949cce79b46d96e59a84f61"},
+ {file = "tokenizers-0.20.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d10766473954397e2d370f215ebed1cc46dcf6fd3906a2a116aa1d6219bfedc3"},
+ {file = "tokenizers-0.20.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9300fac73ddc7e4b0330acbdda4efaabf74929a4a61e119a32a181f534a11b47"},
+ {file = "tokenizers-0.20.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0ecaf7b0e39caeb1aa6dd6e0975c405716c82c1312b55ac4f716ef563a906969"},
+ {file = "tokenizers-0.20.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5170be9ec942f3d1d317817ced8d749b3e1202670865e4fd465e35d8c259de83"},
+ {file = "tokenizers-0.20.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ef3f1ae08fa9aea5891cbd69df29913e11d3841798e0bfb1ff78b78e4e7ea0a4"},
+ {file = "tokenizers-0.20.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ee86d4095d3542d73579e953c2e5e07d9321af2ffea6ecc097d16d538a2dea16"},
+ {file = "tokenizers-0.20.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:86dcd08da163912e17b27bbaba5efdc71b4fbffb841530fdb74c5707f3c49216"},
+ {file = "tokenizers-0.20.1-cp311-none-win32.whl", hash = "sha256:9af2dc4ee97d037bc6b05fa4429ddc87532c706316c5e11ce2f0596dfcfa77af"},
+ {file = "tokenizers-0.20.1-cp311-none-win_amd64.whl", hash = "sha256:899152a78b095559c287b4c6d0099469573bb2055347bb8154db106651296f39"},
+ {file = "tokenizers-0.20.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:407ab666b38e02228fa785e81f7cf79ef929f104bcccf68a64525a54a93ceac9"},
+ {file = "tokenizers-0.20.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2f13a2d16032ebc8bd812eb8099b035ac65887d8f0c207261472803b9633cf3e"},
+ {file = "tokenizers-0.20.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e98eee4dca22849fbb56a80acaa899eec5b72055d79637dd6aa15d5e4b8628c9"},
+ {file = "tokenizers-0.20.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:47c1bcdd61e61136087459cb9e0b069ff23b5568b008265e5cbc927eae3387ce"},
+ {file = "tokenizers-0.20.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:128c1110e950534426e2274837fc06b118ab5f2fa61c3436e60e0aada0ccfd67"},
+ {file = "tokenizers-0.20.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e2e2d47a819d2954f2c1cd0ad51bb58ffac6f53a872d5d82d65d79bf76b9896d"},
+ {file = "tokenizers-0.20.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bdd67a0e3503a9a7cf8bc5a4a49cdde5fa5bada09a51e4c7e1c73900297539bd"},
+ {file = "tokenizers-0.20.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:689b93d2e26d04da337ac407acec8b5d081d8d135e3e5066a88edd5bdb5aff89"},
+ {file = "tokenizers-0.20.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:0c6a796ddcd9a19ad13cf146997cd5895a421fe6aec8fd970d69f9117bddb45c"},
+ {file = "tokenizers-0.20.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3ea919687aa7001a8ff1ba36ac64f165c4e89035f57998fa6cedcfd877be619d"},
+ {file = "tokenizers-0.20.1-cp312-none-win32.whl", hash = "sha256:6d3ac5c1f48358ffe20086bf065e843c0d0a9fce0d7f0f45d5f2f9fba3609ca5"},
+ {file = "tokenizers-0.20.1-cp312-none-win_amd64.whl", hash = "sha256:b0874481aea54a178f2bccc45aa2d0c99cd3f79143a0948af6a9a21dcc49173b"},
+ {file = "tokenizers-0.20.1-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:96af92e833bd44760fb17f23f402e07a66339c1dcbe17d79a9b55bb0cc4f038e"},
+ {file = "tokenizers-0.20.1-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:65f34e5b731a262dfa562820818533c38ce32a45864437f3d9c82f26c139ca7f"},
+ {file = "tokenizers-0.20.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:17f98fccb5c12ab1ce1f471731a9cd86df5d4bd2cf2880c5a66b229802d96145"},
+ {file = "tokenizers-0.20.1-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b8c0fc3542cf9370bf92c932eb71bdeb33d2d4aeeb4126d9fd567b60bd04cb30"},
+ {file = "tokenizers-0.20.1-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4b39356df4575d37f9b187bb623aab5abb7b62c8cb702867a1768002f814800c"},
+ {file = "tokenizers-0.20.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bfdad27b0e50544f6b838895a373db6114b85112ba5c0cefadffa78d6daae563"},
+ {file = "tokenizers-0.20.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:094663dd0e85ee2e573126918747bdb40044a848fde388efb5b09d57bc74c680"},
+ {file = "tokenizers-0.20.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:14e4cf033a2aa207d7ac790e91adca598b679999710a632c4a494aab0fc3a1b2"},
+ {file = "tokenizers-0.20.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:9310951c92c9fb91660de0c19a923c432f110dbfad1a2d429fbc44fa956bf64f"},
+ {file = "tokenizers-0.20.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:05e41e302c315bd2ed86c02e917bf03a6cf7d2f652c9cee1a0eb0d0f1ca0d32c"},
+ {file = "tokenizers-0.20.1-cp37-none-win32.whl", hash = "sha256:212231ab7dfcdc879baf4892ca87c726259fa7c887e1688e3f3cead384d8c305"},
+ {file = "tokenizers-0.20.1-cp37-none-win_amd64.whl", hash = "sha256:896195eb9dfdc85c8c052e29947169c1fcbe75a254c4b5792cdbd451587bce85"},
+ {file = "tokenizers-0.20.1-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:741fb22788482d09d68e73ece1495cfc6d9b29a06c37b3df90564a9cfa688e6d"},
+ {file = "tokenizers-0.20.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:10be14ebd8082086a342d969e17fc2d6edc856c59dbdbddd25f158fa40eaf043"},
+ {file = "tokenizers-0.20.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:514cf279b22fa1ae0bc08e143458c74ad3b56cd078b319464959685a35c53d5e"},
+ {file = "tokenizers-0.20.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a647c5b7cb896d6430cf3e01b4e9a2d77f719c84cefcef825d404830c2071da2"},
+ {file = "tokenizers-0.20.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7cdf379219e1e1dd432091058dab325a2e6235ebb23e0aec8d0508567c90cd01"},
+ {file = "tokenizers-0.20.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ba72260449e16c4c2f6f3252823b059fbf2d31b32617e582003f2b18b415c39"},
+ {file = "tokenizers-0.20.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:910b96ed87316e4277b23c7bcaf667ce849c7cc379a453fa179e7e09290eeb25"},
+ {file = "tokenizers-0.20.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e53975a6694428a0586534cc1354b2408d4e010a3103117f617cbb550299797c"},
+ {file = "tokenizers-0.20.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:07c4b7be58da142b0730cc4e5fd66bb7bf6f57f4986ddda73833cd39efef8a01"},
+ {file = "tokenizers-0.20.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b605c540753e62199bf15cf69c333e934077ef2350262af2ccada46026f83d1c"},
+ {file = "tokenizers-0.20.1-cp38-none-win32.whl", hash = "sha256:88b3bc76ab4db1ab95ead623d49c95205411e26302cf9f74203e762ac7e85685"},
+ {file = "tokenizers-0.20.1-cp38-none-win_amd64.whl", hash = "sha256:d412a74cf5b3f68a90c615611a5aa4478bb303d1c65961d22db45001df68afcb"},
+ {file = "tokenizers-0.20.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:a25dcb2f41a0a6aac31999e6c96a75e9152fa0127af8ece46c2f784f23b8197a"},
+ {file = "tokenizers-0.20.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a12c3cebb8c92e9c35a23ab10d3852aee522f385c28d0b4fe48c0b7527d59762"},
+ {file = "tokenizers-0.20.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:02e18da58cf115b7c40de973609c35bde95856012ba42a41ee919c77935af251"},
+ {file = "tokenizers-0.20.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f326a1ac51ae909b9760e34671c26cd0dfe15662f447302a9d5bb2d872bab8ab"},
+ {file = "tokenizers-0.20.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0b4872647ea6f25224e2833b044b0b19084e39400e8ead3cfe751238b0802140"},
+ {file = "tokenizers-0.20.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ce6238a3311bb8e4c15b12600927d35c267b92a52c881ef5717a900ca14793f7"},
+ {file = "tokenizers-0.20.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:57b7a8880b208866508b06ce365dc631e7a2472a3faa24daa430d046fb56c885"},
+ {file = "tokenizers-0.20.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a908c69c2897a68f412aa05ba38bfa87a02980df70f5a72fa8490479308b1f2d"},
+ {file = "tokenizers-0.20.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:da1001aa46f4490099c82e2facc4fbc06a6a32bf7de3918ba798010954b775e0"},
+ {file = "tokenizers-0.20.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:42c097390e2f0ed0a5c5d569e6669dd4e9fff7b31c6a5ce6e9c66a61687197de"},
+ {file = "tokenizers-0.20.1-cp39-none-win32.whl", hash = "sha256:3d4d218573a3d8b121a1f8c801029d70444ffb6d8f129d4cca1c7b672ee4a24c"},
+ {file = "tokenizers-0.20.1-cp39-none-win_amd64.whl", hash = "sha256:37d1e6f616c84fceefa7c6484a01df05caf1e207669121c66213cb5b2911d653"},
+ {file = "tokenizers-0.20.1-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:48689da7a395df41114f516208d6550e3e905e1239cc5ad386686d9358e9cef0"},
+ {file = "tokenizers-0.20.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:712f90ea33f9bd2586b4a90d697c26d56d0a22fd3c91104c5858c4b5b6489a79"},
+ {file = "tokenizers-0.20.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:359eceb6a620c965988fc559cebc0a98db26713758ec4df43fb76d41486a8ed5"},
+ {file = "tokenizers-0.20.1-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0d3caf244ce89d24c87545aafc3448be15870096e796c703a0d68547187192e1"},
+ {file = "tokenizers-0.20.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:03b03cf8b9a32254b1bf8a305fb95c6daf1baae0c1f93b27f2b08c9759f41dee"},
+ {file = "tokenizers-0.20.1-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:218e5a3561561ea0f0ef1559c6d95b825308dbec23fb55b70b92589e7ff2e1e8"},
+ {file = "tokenizers-0.20.1-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f40df5e0294a95131cc5f0e0eb91fe86d88837abfbee46b9b3610b09860195a7"},
+ {file = "tokenizers-0.20.1-pp37-pypy37_pp73-macosx_10_12_x86_64.whl", hash = "sha256:08aaa0d72bb65058e8c4b0455f61b840b156c557e2aca57627056624c3a93976"},
+ {file = "tokenizers-0.20.1-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:998700177b45f70afeb206ad22c08d9e5f3a80639dae1032bf41e8cbc4dada4b"},
+ {file = "tokenizers-0.20.1-pp37-pypy37_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:62f7fbd3c2c38b179556d879edae442b45f68312019c3a6013e56c3947a4e648"},
+ {file = "tokenizers-0.20.1-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:31e87fca4f6bbf5cc67481b562147fe932f73d5602734de7dd18a8f2eee9c6dd"},
+ {file = "tokenizers-0.20.1-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:956f21d359ae29dd51ca5726d2c9a44ffafa041c623f5aa33749da87cfa809b9"},
+ {file = "tokenizers-0.20.1-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:1fbbaf17a393c78d8aedb6a334097c91cb4119a9ced4764ab8cfdc8d254dc9f9"},
+ {file = "tokenizers-0.20.1-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:ebe63e31f9c1a970c53866d814e35ec2ec26fda03097c486f82f3891cee60830"},
+ {file = "tokenizers-0.20.1-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:81970b80b8ac126910295f8aab2d7ef962009ea39e0d86d304769493f69aaa1e"},
+ {file = "tokenizers-0.20.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:130e35e76f9337ed6c31be386e75d4925ea807055acf18ca1a9b0eec03d8fe23"},
+ {file = "tokenizers-0.20.1-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cd28a8614f5c82a54ab2463554e84ad79526c5184cf4573bbac2efbbbcead457"},
+ {file = "tokenizers-0.20.1-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9041ee665d0fa7f5c4ccf0f81f5e6b7087f797f85b143c094126fc2611fec9d0"},
+ {file = "tokenizers-0.20.1-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:62eb9daea2a2c06bcd8113a5824af8ef8ee7405d3a71123ba4d52c79bb3d9f1a"},
+ {file = "tokenizers-0.20.1-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f861889707b54a9ab1204030b65fd6c22bdd4a95205deec7994dc22a8baa2ea4"},
+ {file = "tokenizers-0.20.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:89d5c337d74ea6e5e7dc8af124cf177be843bbb9ca6e58c01f75ea103c12c8a9"},
+ {file = "tokenizers-0.20.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:0b7f515c83397e73292accdbbbedc62264e070bae9682f06061e2ddce67cacaf"},
+ {file = "tokenizers-0.20.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e0305fc1ec6b1e5052d30d9c1d5c807081a7bd0cae46a33d03117082e91908c"},
+ {file = "tokenizers-0.20.1-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5dc611e6ac0fa00a41de19c3bf6391a05ea201d2d22b757d63f5491ec0e67faa"},
+ {file = "tokenizers-0.20.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c5ffe0d7f7bfcfa3b2585776ecf11da2e01c317027c8573c78ebcb8985279e23"},
+ {file = "tokenizers-0.20.1-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:e7edb8ec12c100d5458d15b1e47c0eb30ad606a05641f19af7563bc3d1608c14"},
+ {file = "tokenizers-0.20.1-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:de291633fb9303555793cc544d4a86e858da529b7d0b752bcaf721ae1d74b2c9"},
+ {file = "tokenizers-0.20.1.tar.gz", hash = "sha256:84edcc7cdeeee45ceedb65d518fffb77aec69311c9c8e30f77ad84da3025f002"},
]

[package.dependencies]
@@ -3233,28 +3248,31 @@ files = [

[[package]]
name = "torch"
-version = "2.5.0"
+version = "2.4.1"
description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
optional = false
python-versions = ">=3.8.0"
files = [
- {file = "torch-2.5.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:7f179373a047b947dec448243f4e6598a1c960fa3bb978a9a7eecd529fbc363f"},
- {file = "torch-2.5.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:15fbc95e38d330e5b0ef1593b7bc0a19f30e5bdad76895a5cffa1a6a044235e9"},
- {file = "torch-2.5.0-cp310-cp310-win_amd64.whl", hash = "sha256:f499212f1cffea5d587e5f06144630ed9aa9c399bba12ec8905798d833bd1404"},
- {file = "torch-2.5.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:c54db1fade17287aabbeed685d8e8ab3a56fea9dd8d46e71ced2da367f09a49f"},
- {file = "torch-2.5.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:499a68a756d3b30d10f7e0f6214dc3767b130b797265db3b1c02e9094e2a07be"},
- {file = "torch-2.5.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:9f3df8138a1126a851440b7d5a4869bfb7c9cc43563d64fd9d96d0465b581024"},
- {file = "torch-2.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:b81da3bdb58c9de29d0e1361e52f12fcf10a89673f17a11a5c6c7da1cb1a8376"},
- {file = "torch-2.5.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:ba135923295d564355326dc409b6b7f5bd6edc80f764cdaef1fb0a1b23ff2f9c"},
- {file = "torch-2.5.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:2dd40c885a05ef7fe29356cca81be1435a893096ceb984441d6e2c27aff8c6f4"},
- {file = "torch-2.5.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:bc52d603d87fe1da24439c0d5fdbbb14e0ae4874451d53f0120ffb1f6c192727"},
- {file = "torch-2.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:ea718746469246cc63b3353afd75698a288344adb55e29b7f814a5d3c0a7c78d"},
- {file = "torch-2.5.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:6de1fd253e27e7f01f05cd7c37929ae521ca23ca4620cfc7c485299941679112"},
- {file = "torch-2.5.0-cp313-cp313-manylinux1_x86_64.whl", hash = "sha256:83dcf518685db20912b71fc49cbddcc8849438cdb0e9dcc919b02a849e2cd9e8"},
- {file = "torch-2.5.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:65e0a60894435608334d68c8811e55fd8f73e5bf8ee6f9ccedb0064486a7b418"},
- {file = "torch-2.5.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:38c21ff1bd39f076d72ab06e3c88c2ea6874f2e6f235c9450816b6c8e7627094"},
- {file = "torch-2.5.0-cp39-cp39-win_amd64.whl", hash = "sha256:ce4baeba9804da5a346e210b3b70826f5811330c343e4fe1582200359ee77fe5"},
- {file = "torch-2.5.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:03e53f577a96e4d41aca472da8faa40e55df89d2273664af390ce1f570e885bd"},
+ {file = "torch-2.4.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:362f82e23a4cd46341daabb76fba08f04cd646df9bfaf5da50af97cb60ca4971"},
+ {file = "torch-2.4.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:e8ac1985c3ff0f60d85b991954cfc2cc25f79c84545aead422763148ed2759e3"},
+ {file = "torch-2.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:91e326e2ccfb1496e3bee58f70ef605aeb27bd26be07ba64f37dcaac3d070ada"},
+ {file = "torch-2.4.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:d36a8ef100f5bff3e9c3cea934b9e0d7ea277cb8210c7152d34a9a6c5830eadd"},
+ {file = "torch-2.4.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:0b5f88afdfa05a335d80351e3cea57d38e578c8689f751d35e0ff36bce872113"},
+ {file = "torch-2.4.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:ef503165f2341942bfdf2bd520152f19540d0c0e34961232f134dc59ad435be8"},
+ {file = "torch-2.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:092e7c2280c860eff762ac08c4bdcd53d701677851670695e0c22d6d345b269c"},
+ {file = "torch-2.4.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:ddddbd8b066e743934a4200b3d54267a46db02106876d21cf31f7da7a96f98ea"},
+ {file = "torch-2.4.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:fdc4fe11db3eb93c1115d3e973a27ac7c1a8318af8934ffa36b0370efe28e042"},
+ {file = "torch-2.4.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:18835374f599207a9e82c262153c20ddf42ea49bc76b6eadad8e5f49729f6e4d"},
+ {file = "torch-2.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:ebea70ff30544fc021d441ce6b219a88b67524f01170b1c538d7d3ebb5e7f56c"},
+ {file = "torch-2.4.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:72b484d5b6cec1a735bf3fa5a1c4883d01748698c5e9cfdbeb4ffab7c7987e0d"},
+ {file = "torch-2.4.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:c99e1db4bf0c5347107845d715b4aa1097e601bdc36343d758963055e9599d93"},
+ {file = "torch-2.4.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:b57f07e92858db78c5b72857b4f0b33a65b00dc5d68e7948a8494b0314efb880"},
+ {file = "torch-2.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:f18197f3f7c15cde2115892b64f17c80dbf01ed72b008020e7da339902742cf6"},
+ {file = "torch-2.4.1-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:5fc1d4d7ed265ef853579caf272686d1ed87cebdcd04f2a498f800ffc53dab71"},
+ {file = "torch-2.4.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:40f6d3fe3bae74efcf08cb7f8295eaddd8a838ce89e9d26929d4edd6d5e4329d"},
+ {file = "torch-2.4.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:c9299c16c9743001ecef515536ac45900247f4338ecdf70746f2461f9e4831db"},
+ {file = "torch-2.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:6bce130f2cd2d52ba4e2c6ada461808de7e5eccbac692525337cfb4c19421846"},
+ {file = "torch-2.4.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:a38de2803ee6050309aac032676536c3d3b6a9804248537e38e098d0e14817ec"},
]

[package.dependencies]
@@ -3262,26 +3280,25 @@ filelock = "*"
fsspec = "*"
jinja2 = "*"
networkx = "*"
-nvidia-cublas-cu12 = {version = "12.4.5.8", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
-nvidia-cuda-cupti-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
-nvidia-cuda-nvrtc-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
-nvidia-cuda-runtime-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-cublas-cu12 = {version = "12.1.3.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-cuda-cupti-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-cuda-nvrtc-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-cuda-runtime-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
nvidia-cudnn-cu12 = {version = "9.1.0.70", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
-nvidia-cufft-cu12 = {version = "11.2.1.3", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
-nvidia-curand-cu12 = {version = "10.3.5.147", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
-nvidia-cusolver-cu12 = {version = "11.6.1.9", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
-nvidia-cusparse-cu12 = {version = "12.3.1.170", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
-nvidia-nccl-cu12 = {version = "2.21.5", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
-nvidia-nvjitlink-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
-nvidia-nvtx-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
-setuptools = {version = "*", markers = "python_version >= \"3.12\""}
-sympy = {version = "1.13.1", markers = "python_version >= \"3.9\""}
-triton = {version = "3.1.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.13\""}
+nvidia-cufft-cu12 = {version = "11.0.2.54", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-curand-cu12 = {version = "10.3.2.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-cusolver-cu12 = {version = "11.4.5.107", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-cusparse-cu12 = {version = "12.1.0.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-nccl-cu12 = {version = "2.20.5", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-nvtx-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+setuptools = "*"
+sympy = "*"
+triton = {version = "3.0.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.13\""}
typing-extensions = ">=4.8.0"

[package.extras]
opt-einsum = ["opt-einsum (>=3.3)"]
-optree = ["optree (>=0.12.0)"]
+optree = ["optree (>=0.11.0)"]

[[package]]
name = "torch-optimi"
@@ -3325,35 +3342,35 @@ dev = ["bitsandbytes", "expecttest", "fire", "hypothesis", "matplotlib", "ninja"

[[package]]
name = "torchaudio"
-version = "2.5.0"
+version = "2.4.1"
description = "An audio package for PyTorch"
optional = false
python-versions = "*"
files = [
- {file = "torchaudio-2.5.0-1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:9dfeedcef3e43010f3ec2d804c8f62fe49ab09ef1c19e6736325939661a293bd"},
- {file = "torchaudio-2.5.0-1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:c35201efd28244152d6edbde92775c10f39f5a5d9346202f07b1554dc78d25a2"},
- {file = "torchaudio-2.5.0-1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:593258f33de1fa16ebe5718c9717daf3695460d48a0188194a5c574a710838cb"},
- {file = "torchaudio-2.5.0-1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:aabf8c4ce919c2e24ace49641ea429360018816371a3d470427fc02ab11156c5"},
- {file = "torchaudio-2.5.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ab69bdfeb4434159e168a4a2c1618d1d65a5a14a91d17d21256ea960f33405fd"},
- {file = "torchaudio-2.5.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:70fe426ecae408e9a7019cfcbcd4e81b6f084920ffffac2520f1d28a23e145fe"},
- {file = "torchaudio-2.5.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:7e2f596b0e8909924cdf46acc579481132f5c0341824957f1cca8385c61db5b5"},
- {file = "torchaudio-2.5.0-cp310-cp310-win_amd64.whl", hash = "sha256:763bf99b2def4681b1e760883849e0e85fa172eac4a12d1870380d5b7d1149c2"},
- {file = "torchaudio-2.5.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:982ec0494a27c7f3e7e68c91cc92e6e1ad6f86fedeeec627096051309632b149"},
- {file = "torchaudio-2.5.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:470dd171a2e44a4c1aa89c5cdd4a0ba9f99650b68228f3a63b20ceaafd553567"},
- {file = "torchaudio-2.5.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:ee844aa12fa25f521f64ec86c835acf925d194ed4fb66a9b442436f80b39e8da"},
- {file = "torchaudio-2.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:168d9d2a8216a5f1888713c13914edf410d2e28d39c6bfd9e1211baf6f2c76d8"},
- {file = "torchaudio-2.5.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c835da771701b06fbe8e19ce643d5e587fd385e5f4d8f140551ce04900b1b96b"},
- {file = "torchaudio-2.5.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:3009ca3e095825911917ab865b2ea48abbf3e66f948c4ffd64916fe6e476bfec"},
- {file = "torchaudio-2.5.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:1971302bfc321be5ea18ba60dbc3ffdc6ae53cb47bb6427db25d5b69e4e932ec"},
- {file = "torchaudio-2.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:e98ad102e0e574a397759078bc9809afc05de6e6f8ac0cb26d2245e0d3817adf"},
- {file = "torchaudio-2.5.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:85573f66dc09497d282bbf406e8f2b03e5929eb3bdc1f76a9d9bc46644d407b1"},
- {file = "torchaudio-2.5.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:877706453e7329844382d06ffee31cb11b602c6991afefb594086ecdd739a5cf"},
- {file = "torchaudio-2.5.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:74a3de13ec0a5024999aec75b3fafa97891d617ce5566818d3094857d1e0229d"},
- {file = "torchaudio-2.5.0-cp39-cp39-win_amd64.whl", hash = "sha256:242b3c9b1abf212b3a1f28eae9814db4daf4507f74b63c7ec7161d35b3c37147"},
+ {file = "torchaudio-2.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:661909751909340b24f637410dfec02a888867816c3db19ed4f4102ae105244a"},
+ {file = "torchaudio-2.4.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:bfc234cef1d03092ea27440fb79e486722ccb41cff94ebaf9d5a1082436395fe"},
+ {file = "torchaudio-2.4.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:54431179d9a9ccf3feeae98aace07d89fae9fd728e2bc8656efbd70e7edcc6f8"},
+ {file = "torchaudio-2.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:dec97872215c3122b7718ec47ac63e143565c3cced06444d0225e98bf4dd4b5f"},
+ {file = "torchaudio-2.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:60af1531815d22659e5412ea401bed552a16c389938c49664e446e4cfd5ddc06"},
+ {file = "torchaudio-2.4.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:95a0968569f7f4455bfd242bfcd489ec47ad37d2ba0f3d9f738cd1128a5f775c"},
+ {file = "torchaudio-2.4.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:7640aaffb2056e12f2906187b03a22228a0908c87d0295fddf4b0b92334a290b"},
+ {file = "torchaudio-2.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:3c08b42a0c296c8eeee6c533bcae5cfbc0ceae86a34f24fe6bbbb5faa7a7bea1"},
+ {file = "torchaudio-2.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:953946cf610ffd57bb3fdd228effa2112fa51c5dfe36a96611effc9074a3d3be"},
+ {file = "torchaudio-2.4.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:1796a8961decb522c47daab0fbe27c057d6d143ee22bb6ae0d5eb9b2a038c7b6"},
+ {file = "torchaudio-2.4.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:5b62fc7b16ed708b0c07d4393137797e92f63fc3bd5705607d97ba6a9a7cf3f0"},
+ {file = "torchaudio-2.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:d721b186aae7bd8752c9ad95213f5d650926597bb9060728dfe476986a1ff570"},
+ {file = "torchaudio-2.4.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4ea0fd00142fe795c75bcc20a303981b56f2327c7f7d321b42a8fef1d78aafa9"},
+ {file = "torchaudio-2.4.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:375d8740c8035a50faca7a5afe2fbdb712aa8733715b971b2af61b4003fa1c41"},
+ {file = "torchaudio-2.4.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:74d19cf9ca3dad394afcabb7e6f7ed9ab9f59f2540d502826c7ec3e33985251d"},
+ {file = "torchaudio-2.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:40e9fa8fdc8d328ea4aa90be65fd34c5ef975610dbd707545e3664393a8a2497"},
+ {file = "torchaudio-2.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3adce550850902b9aa6cd2378ccd720ac9ec8cf31e2eba9743ccc84ffcbe76d6"},
+ {file = "torchaudio-2.4.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:98d8e03703f96b13a8d172d1ccdc7badb338227fd762985fdcea6b30f6697bdb"},
+ {file = "torchaudio-2.4.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:36c7e7bc6b358cbf42b769c80206780fa1497d141a985c6b3e7768de44524e9a"},
+ {file = "torchaudio-2.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:f46e34ab3866ad8d8ace0673cd11e697c5cde6a3b7a4d8d789207d4d8badbb6e"},
]

[package.dependencies]
-torch = "2.5.0"
+torch = "2.4.1"

[[package]]
name = "torchmetrics"
@@ -3401,37 +3418,37 @@ trampoline = ">=0.1.2"

[[package]]
name = "torchvision"
-version = "0.20.0"
+version = "0.19.1"
description = "image and video datasets and models for torch deep learning"
optional = false
python-versions = ">=3.8"
files = [
- {file = "torchvision-0.20.0-1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:e084f50ecbdbe7a9cc2fc51ea0367ae35fde46e84a964bf4046cb1c7feb7e3e6"},
- {file = "torchvision-0.20.0-1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:55d7f43ef912ebc4da4bba73a0bbf387d38a6be9cd521679c0f4056f9564b698"},
- {file = "torchvision-0.20.0-1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:f8d0213489acfb138369f2455a6893880c194a8195e381c19f872b277f2654c3"},
- {file = "torchvision-0.20.0-1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:8d6cea8ab0bf72ecb71b07cd0fe836eacf5a5fa98f6629d2261212e90977b963"},
- {file = "torchvision-0.20.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f164d545965186ffd66014e34a966706d12c84198302dd46748cae45984609a4"},
- {file = "torchvision-0.20.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:9c18208575d60b96e7d53a09c453781afea4a81487c9ebc501dfc2bc88daa308"},
- {file = "torchvision-0.20.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:09080359be90314fc4fdd64b11a4d231c1999018f19d58bf7764f5e15f8e9fb3"},
- {file = "torchvision-0.20.0-cp310-cp310-win_amd64.whl", hash = "sha256:a7d46cf096007b7e8df1bddad7375427664a064bc05d9cbff5d506b73c1ab8ca"},
- {file = "torchvision-0.20.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a15de6266a36bcd10d89f6f3d7ba4e2dd567a7a0add616ebc6e65aea20790e5d"},
- {file = "torchvision-0.20.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:b64d9f83cf201ebda4f6b03533e4918fa0b4223b28b0ee3cbede15b8174c7cbd"},
- {file = "torchvision-0.20.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:d80eb740810804bac4b8e6b6411946ab286a1ee1d731db36af2f885333254802"},
- {file = "torchvision-0.20.0-cp311-cp311-win_amd64.whl", hash = "sha256:1fd045757335d34969d176fc5688b643d201860cb45b48ce8d5d8fb90868f746"},
- {file = "torchvision-0.20.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ac0edba534fb071b2b03a2fd5cbbf9b7c259896d17a1d0d830b3c5b7dfae0782"},
- {file = "torchvision-0.20.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:c8f3bc399d9c3e4ba05d74ca6dd5e63fed08ad5c5b302a946c8fcaa56216220f"},
- {file = "torchvision-0.20.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:a78c99ebe1a62857b68e97ff9417b92f299f2ee61f009491a114ddad050c493d"},
- {file = "torchvision-0.20.0-cp312-cp312-win_amd64.whl", hash = "sha256:bb0da0950d2034a0412c251a3a9117ff9612157f45177d37ba1b20b472c0864b"},
- {file = "torchvision-0.20.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6a70c81ea5068dd7b1e340ebeabb65364576d8b9819454cfdf812290cf03e45a"},
- {file = "torchvision-0.20.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:95d8c817681a4c2156f66ef83cafc4c5c4b97e4694956d54d7dc554804ee510d"},
- {file = "torchvision-0.20.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:1ab53244701eab897e5c65026ba178c0abbc5bd08629c3d20f737d618e9e5a37"},
- {file = "torchvision-0.20.0-cp39-cp39-win_amd64.whl", hash = "sha256:47d0751aeaa7057ee6a5973d35e7acad3ad7c17b8e57a2c4304d13e001e330ae"},
+ {file = "torchvision-0.19.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:54e8513099e6f586356c70f809d34f391af71ad182fe071cc328a28af2c40608"},
+ {file = "torchvision-0.19.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:20a1f5e02bfdad7714e55fa3fa698347c11d829fa65e11e5a84df07d93350eed"},
+ {file = "torchvision-0.19.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:7b063116164be52fc6deb4762de7f8c90bfa3a65f8d5caf17f8e2d5aadc75a04"},
+ {file = "torchvision-0.19.1-cp310-cp310-win_amd64.whl", hash = "sha256:f40b6acabfa886da1bc3768f47679c61feee6bde90deb979d9f300df8c8a0145"},
+ {file = "torchvision-0.19.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:40514282b4896d62765b8e26d7091c32e17c35817d00ec4be2362ea3ba3d1787"},
+ {file = "torchvision-0.19.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:5a91be061ae5d6d5b95e833b93e57ca4d3c56c5a57444dd15da2e3e7fba96050"},
+ {file = "torchvision-0.19.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:d71a6a6fe3a5281ca3487d4c56ad4aad20ff70f82f1d7c79bcb6e7b0c2af00c8"},
+ {file = "torchvision-0.19.1-cp311-cp311-win_amd64.whl", hash = "sha256:70dea324174f5e9981b68e4b7cd524512c106ba64aedef560a86a0bbf2fbf62c"},
+ {file = "torchvision-0.19.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:27ece277ff0f6cdc7fed0627279c632dcb2e58187da771eca24b0fbcf3f8590d"},
+ {file = "torchvision-0.19.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:c659ff92a61f188a1a7baef2850f3c0b6c85685447453c03d0e645ba8f1dcc1c"},
+ {file = "torchvision-0.19.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:c07bf43c2a145d792ecd9d0503d6c73577147ece508d45600d8aac77e4cdfcf9"},
+ {file = "torchvision-0.19.1-cp312-cp312-win_amd64.whl", hash = "sha256:b4283d283675556bb0eae31d29996f53861b17cbdcdf3509e6bc050414ac9289"},
+ {file = "torchvision-0.19.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4c4e4f5b24ea6b087b02ed492ab1e21bba3352c4577e2def14248cfc60732338"},
+ {file = "torchvision-0.19.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:9281d63ead929bb19143731154cd1d8bf0b5e9873dff8578a40e90a6bec3c6fa"},
+ {file = "torchvision-0.19.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:4d10bc9083c4d5fadd7edd7b729700a7be48dab4f62278df3bc73fa48e48a155"},
+ {file = "torchvision-0.19.1-cp38-cp38-win_amd64.whl", hash = "sha256:ccf085ef1824fb9e16f1901285bf89c298c62dfd93267a39e8ee42c71255242f"},
+ {file = "torchvision-0.19.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:731f434d91586769e255b5d70ed1a4457e0a1394a95f4aacf0e1e7e21f80c098"},
+ {file = "torchvision-0.19.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:febe4f14d4afcb47cc861d8be7760ab6a123cd0817f97faf5771488cb6aa90f4"},
+ {file = "torchvision-0.19.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:e328309b8670a2e889b2fe76a1c2744a099c11c984da9a822357bd9debd699a5"},
+ {file = "torchvision-0.19.1-cp39-cp39-win_amd64.whl", hash = "sha256:6616f12e00a22e7f3fedbd0fccb0804c05e8fe22871668f10eae65cf3f283614"},
]

[package.dependencies]
numpy = "*"
pillow = ">=5.3.0,<8.3.dev0 || >=8.4.dev0"
-torch = "2.5.0"
+torch = "2.4.1"
[package.extras]
gdown = ["gdown (>=4.7.3)"]
@@ -3469,13 +3486,13 @@ files = [

[[package]]
name = "transformers"
-version = "4.44.2"
+version = "4.46.1"
description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow"
optional = false
python-versions = ">=3.8.0"
files = [
- {file = "transformers-4.44.2-py3-none-any.whl", hash = "sha256:1c02c65e7bfa5e52a634aff3da52138b583fc6f263c1f28d547dc144ba3d412d"},
- {file = "transformers-4.44.2.tar.gz", hash = "sha256:36aa17cc92ee154058e426d951684a2dab48751b35b49437896f898931270826"},
+ {file = "transformers-4.46.1-py3-none-any.whl", hash = "sha256:f77b251a648fd32e3d14b5e7e27c913b7c29154940f519e4c8c3aa6061df0f05"},
+ {file = "transformers-4.46.1.tar.gz", hash = "sha256:16d79927d772edaf218820a96f9254e2211f9cd7fb3c308562d2d636c964a68c"},
]

[package.dependencies]
@@ -3487,21 +3504,21 @@ pyyaml = ">=5.1"
regex = "!=2019.12.17"
requests = "*"
safetensors = ">=0.4.1"
-tokenizers = ">=0.19,<0.20"
+tokenizers = ">=0.20,<0.21"
tqdm = ">=4.27"

[package.extras]
-accelerate = ["accelerate (>=0.21.0)"]
-agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch"]
-all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm (<=0.9.16)", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision"]
+accelerate = ["accelerate (>=0.26.0)"]
+agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch"]
+all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm (<=0.9.16)", "tokenizers (>=0.20,<0.21)", "torch", "torchaudio", "torchvision"]
audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
-benchmark = ["optimum-benchmark (>=0.2.0)"]
+benchmark = ["optimum-benchmark (>=0.3.0)"]
codecarbon = ["codecarbon (==1.2.0)"]
-deepspeed = ["accelerate (>=0.21.0)", "deepspeed (>=0.9.3)"]
-deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.21.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
-dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm (<=0.9.16)", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"]
-dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.19,<0.20)", "urllib3 (<2.0.0)"]
-dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm (<=0.9.16)", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"]
+deepspeed = ["accelerate (>=0.26.0)", "deepspeed (>=0.9.3)"]
+deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.26.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
+dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm (<=0.9.16)", "tokenizers (>=0.20,<0.21)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"]
+dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.20,<0.21)", "urllib3 (<2.0.0)"]
+dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "libcst", "librosa", "nltk (<=3.8.1)", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm (<=0.9.16)", "tokenizers (>=0.20,<0.21)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"]
flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)", "scipy (<1.13.0)"]
flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
ftfy = ["ftfy"]
@@ -3512,7 +3529,7 @@ natten = ["natten (>=0.14.6,<0.15.0)"]
onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"]
onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"]
optuna = ["optuna"]
-quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "isort (>=5.5.4)", "ruff (==0.5.1)", "urllib3 (<2.0.0)"]
+quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "isort (>=5.5.4)", "libcst", "rich", "ruff (==0.5.1)", "urllib3 (<2.0.0)"]
ray = ["ray[tune] (>=2.7.0)"]
retrieval = ["datasets (!=2.5.0)", "faiss-cpu"]
ruff = ["ruff (==0.5.1)"]
@@ -3522,31 +3539,32 @@ serving = ["fastapi", "pydantic", "starlette", "uvicorn"]
sigopt = ["sigopt"]
sklearn = ["scikit-learn"]
speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
-testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk", "parameterized", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
+testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "parameterized", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
tf = ["keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"]
tf-cpu = ["keras (>2.9,<2.16)", "keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow-cpu (>2.9,<2.16)", "tensorflow-probability (<0.24)", "tensorflow-text (<2.16)", "tf2onnx"]
tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
+tiktoken = ["blobfile", "tiktoken"]
timm = ["timm (<=0.9.16)"]
-tokenizers = ["tokenizers (>=0.19,<0.20)"]
-torch = ["accelerate (>=0.21.0)", "torch"]
+tokenizers = ["tokenizers (>=0.20,<0.21)"]
+torch = ["accelerate (>=0.26.0)", "torch"]
torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"]
-torchhub = ["filelock", "huggingface-hub (>=0.23.2,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.19,<0.20)", "torch", "tqdm (>=4.27)"]
-video = ["av (==9.2.0)", "decord (==0.6.0)"]
+torchhub = ["filelock", "huggingface-hub (>=0.23.2,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.20,<0.21)", "torch", "tqdm (>=4.27)"]
+video = ["av (==9.2.0)"]
vision = ["Pillow (>=10.0.1,<=15.0)"]

[[package]]
name = "triton"
-version = "3.1.0"
+version = "3.0.0"
description = "A language and compiler for custom Deep Learning operations"
optional = false
python-versions = "*"
files = [
- {file = "triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b0dd10a925263abbe9fa37dcde67a5e9b2383fc269fdf59f5657cac38c5d1d8"},
- {file = "triton-3.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f34f6e7885d1bf0eaaf7ba875a5f0ce6f3c13ba98f9503651c1e6dc6757ed5c"},
- {file = "triton-3.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c8182f42fd8080a7d39d666814fa36c5e30cc00ea7eeeb1a2983dbb4c99a0fdc"},
- {file = "triton-3.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6dadaca7fc24de34e180271b5cf864c16755702e9f63a16f62df714a8099126a"},
- {file = "triton-3.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aafa9a20cd0d9fee523cd4504aa7131807a864cd77dcf6efe7e981f18b8c6c11"},
+ {file = "triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a"},
+ {file = "triton-3.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5ce8520437c602fb633f1324cc3871c47bee3b67acf9756c1a66309b60e3216c"},
+ {file = "triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb"},
+ {file = "triton-3.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bcbf3b1c48af6a28011a5c40a5b3b9b5330530c3827716b5fbf6d7adcc1e53e9"},
+ {file = "triton-3.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6e5727202f7078c56f91ff13ad0c1abab14a0e7f2c87e91b12b6f64f3e8ae609"},
]

[package.dependencies]
@@ -4202,4 +4220,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<3.13"
-content-hash = "c24ef2004b2191f79e139e629ba8258fb92b6d6a15ad0d827b6042fa62440ebd"
+content-hash = "61cbffb8a7b0a9a5bac978652b3097c6fa10ae6b0677a2640613ac7736408e91"
diff --git a/install/apple/pyproject.toml b/install/apple/pyproject.toml
index 98e7c4a1..1034d915 100644
--- a/install/apple/pyproject.toml
+++ b/install/apple/pyproject.toml
@@ -9,8 +9,8 @@ package-mode = false

[tool.poetry.dependencies]
python = ">=3.10,<3.13"
-torch = "^2.5.0"
-torchvision = "^0.20.0"
+torch = "2.4.1"
+torchvision = "*"
diffusers = "^0.31.0"
transformers = "^4.44.2"
datasets = "^3.0.0"
@@ -44,7 +44,8 @@ fastapi = {extras = ["standard"], version = "^0.115.0"}
deepspeed = "^0.15.1"
sentencepiece = "^0.2.0"
torchao = "^0.5.0"
-torchaudio = "^2.5.0"
+torchaudio = "*"
+omnigen = {git = "https://github.com/bghira/omnigen", rev = "dependency-update/peft"}

[build-system]