diff --git a/configure.py b/configure.py index f14e13b9..83aa5564 100644 --- a/configure.py +++ b/configure.py @@ -18,12 +18,13 @@ "full": [ "flux", "sdxl", + "omnigen", "pixart_sigma", "kolors", "sd3", "legacy", ], - "lora": ["flux", "sdxl", "kolors", "sd3", "legacy"], + "lora": ["flux", "sdxl", "kolors", "sd3", "legacy", "omnigen"], "controlnet": ["sdxl", "legacy"], } @@ -34,6 +35,7 @@ "kolors": "kwai-kolors/kolors-diffusers", "terminus": "ptx0/terminus-xl-velocity-v2", "sd3": "stabilityai/stable-diffusion-3.5-large", + "omnigen": "Shitao/OmniGen-v1", "legacy": "stabilityai/stable-diffusion-2-1-base", } @@ -43,12 +45,14 @@ "pixart_sigma": 3.4, "kolors": 5.0, "terminus": 8.0, - "sd3": 5.0, + "omnigen": 3.0, + "sd3": 6.0, } model_labels = { "sd3": "Stable Diffusion 3", "flux": "FLUX", + "omnigen": "OmniGen", "pixart_sigma": "PixArt Sigma", "kolors": "Kwai Kolors", "terminus": "Terminus", diff --git a/helpers/caching/text_embeds.py b/helpers/caching/text_embeds.py index 3ee330db..78ef578d 100644 --- a/helpers/caching/text_embeds.py +++ b/helpers/caching/text_embeds.py @@ -208,12 +208,7 @@ def discover_all_files(self): def save_to_cache(self, filename, embeddings): """Add write requests to the queue instead of writing directly.""" - if not self.batch_write_thread.is_alive(): - logger.debug("Restarting background write thread.") - # Start the thread again. - self.process_write_batches = True - self.batch_write_thread = Thread(target=self.batch_write_embeddings) - self.batch_write_thread.start() + self.process_write_batches = True self.write_queue.put((embeddings, filename)) logger.debug( f"save_to_cache called for {filename}, write queue has {self.write_queue.qsize()} items, and the write thread's status: {self.batch_write_thread.is_alive()}" @@ -221,8 +216,6 @@ def save_to_cache(self, filename, embeddings): def batch_write_embeddings(self): """Process write requests in batches.""" - batch = [] - written_elements = 0 while True: try: # Block until an item is available or timeout occurs @@ -233,25 +226,14 @@ def batch_write_embeddings(self): while ( not self.write_queue.empty() and len(batch) < self.write_batch_size ): - logger.debug("Retrieving more items from the queue.") items = self.write_queue.get_nowait() batch.append(items) - logger.debug(f"Batch now contains {len(batch)} items.") self.process_write_batch(batch) self.write_thread_bar.update(len(batch)) - logger.debug("Processed batch write.") - written_elements += len(batch) except queue.Empty: # Timeout occurred, no items were ready - if not self.process_write_batches: - if len(batch) > 0: - self.process_write_batch(batch) - self.write_thread_bar.update(len(batch)) - logger.debug(f"Exiting batch write thread, no more work to do after writing {written_elements} elements") - break - logger.debug(f"Queue is empty. Retrieving new entries. Should retrieve? {self.process_write_batches}") pass except Exception: logger.exception("An error occurred while writing embeddings to disk.") @@ -260,7 +242,6 @@ def batch_write_embeddings(self): def process_write_batch(self, batch): """Write a batch of embeddings to the cache.""" logger.debug(f"Writing {len(batch)} items to disk") - logger.debug(f"Batch: {batch}") with ThreadPoolExecutor(max_workers=self.max_workers) as executor: futures = [ executor.submit(self.data_backend.torch_save, *args) for args in batch @@ -288,8 +269,8 @@ def encode_flux_prompt( text_encoders: List of text encoders. tokenizers: List of tokenizers. prompt: The prompt to encode. - num_images_per_prompt: The number of images to generate per prompt. is_validation: Whether the prompt is for validation. No-op for SD3. + zero_padding_tokens: Whether to zero out padding tokens. Returns: Tuple of (prompt_embeds, pooled_prompt_embeds). @@ -320,6 +301,31 @@ def encode_flux_prompt( return prompt_embeds, pooled_prompt_embeds, time_ids, masks + def encode_omnigen_prompt( + self, text_encoders, tokenizers, prompt: str, is_validation: bool = False + ): + """ + Encode a prompt for an OmniGen model. + + Args: + text_encoders: List of text encoders. + tokenizers: List of tokenizers. + prompt: The prompt to encode. + is_validation: Whether the prompt is for validation. No-op for OmniGen. + + Returns: + Dict of OmniGen inputs + """ + # it's not a text encoder, it's the MLLM preprocessor / tokeniser. + processed = text_encoders[0]( + instructions=prompt, + # use_img_cfg=False, + # separate_cfg_input=False, + # use_input_image_size_as_output=False, + ) + + return processed + # Adapted from pipelines.StableDiffusion3Pipeline.encode_prompt def encode_sd3_prompt( self, @@ -525,9 +531,7 @@ def encode_prompt(self, prompt: str, is_validation: bool = False): prompt, is_validation, zero_padding_tokens=( - True - if StateTracker.get_args().t5_padding == "zero" - else False + True if StateTracker.get_args().t5_padding == "zero" else False ), ) else: @@ -666,6 +670,12 @@ def compute_embeddings_for_prompts( return_concat=return_concat, load_from_cache=load_from_cache, ) + elif self.model_type == "omnigen": + output = self.compute_embeddings_for_omnigen_prompts( + raw_prompts, + return_concat=return_concat, + load_from_cache=load_from_cache, + ) else: raise ValueError( f"No such text encoding backend for model type '{self.model_type}'" @@ -1039,6 +1049,164 @@ def compute_embeddings_for_legacy_prompts( return prompt_embeds_all, attention_masks_all return prompt_embeds_all + def compute_embeddings_for_omnigen_prompts( + self, + prompts: list = None, + return_concat: bool = True, + is_validation: bool = False, + load_from_cache: bool = True, + ): + # print(f"Computing embeddings for Omnigen prompts") + # processed = self.text_encoders[0]( + # prompts, + # ) + + # # processed looks like: + # # {"input_ids": all_input_ids, "pixel_values": input_images, "image_sizes": img_inx} + # print(f"Processed: {processed.keys()}") + + # return processed + processed_all = [] + should_encode = not load_from_cache + args = StateTracker.get_args() + if should_encode: + local_caption_split = self.split_captions_between_processes( + prompts or self.prompts + ) + else: + local_caption_split = prompts or self.prompts + if ( + hasattr(args, "cache_clear_validation_prompts") + and args.cache_clear_validation_prompts + and is_validation + ): + # If --cache_clear_validation_prompts was provided, we will forcibly overwrite them. + load_from_cache = False + should_encode = True + + if self.webhook_handler is not None: + last_reported_index = 0 + self.send_progress_update( + type="init_cache_text_embeds_started", + progress=int(0 // len(local_caption_split)), + total=len(local_caption_split), + current=0, + ) + self.write_thread_bar = tqdm( + desc="Write embeds to disk", + leave=False, + ncols=125, + disable=return_concat, + total=len(local_caption_split), + position=get_rank(), + ) + with torch.no_grad(): + last_reported_index = 0 + for prompt in tqdm( + local_caption_split, + desc="Processing prompts", + disable=return_concat, + miniters=50, + leave=False, + ncols=125, + position=get_rank() + self.accelerator.num_processes + 1, + ): + filename = os.path.join(self.cache_dir, self.hash_prompt(prompt)) + debug_msg = f"Processing file: {filename}, prompt: {prompt}" + prompt = PromptHandler.filter_caption(self.data_backend, prompt) + debug_msg = f"{debug_msg}\n -> filtered prompt: {prompt}" + if prompt is None: + logger.error(f"Filename {filename} does not have a caption.") + continue + logger.debug(debug_msg) + if return_concat and load_from_cache: + try: + # We attempt to load. + _processed = self.load_from_cache(filename) + logger.debug(f"Cached OmniGen inputs: {_processed}") + except Exception as e: + # We failed to load. Now encode the prompt. + logger.error( + f"Failed retrieving prompt from cache:" + f"\n-> prompt: {prompt}" + f"\n-> filename: {filename}" + f"\n-> error: {e}" + f"\n-> id: {self.id}, data_backend id: {self.data_backend.id}" + ) + should_encode = True + raise Exception( + "Cache retrieval for text embed file failed. Ensure your dataloader config value for skip_file_discovery does not contain 'text', and that preserve_data_backend_cache is disabled or unset." + ) + if should_encode: + # If load_from_cache is True, should_encode would be False unless we failed to load. + self.debug_log(f"Encoding prompt: {prompt}") + _processed = self.encode_omnigen_prompt( + self.text_encoders, self.tokenizers, [prompt], is_validation + ) + logger.debug(f"OmniGen prompt embeds: {_processed}") + current_size = self.write_queue.qsize() + if current_size >= 2048: + log_msg = str( + f"[WARNING] Write queue size is {current_size}. This is quite large." + " Consider increasing the write batch size. Delaying encode so that writes can catch up." + ) + self.write_thread_bar.write(log_msg) + while self.write_queue.qsize() > 100: + time.sleep(0.1) + + self.debug_log(f"Adding embed to write queue: {filename}") + self.save_to_cache(filename, _processed) + if ( + self.webhook_handler is not None + and int( + self.write_thread_bar.n % self.webhook_progress_interval + ) + < 10 + ): + last_reported_index = int( + self.write_thread_bar.n % self.webhook_progress_interval + ) + self.send_progress_update( + type="init_cache_text_embeds_status_update", + progress=int( + self.write_thread_bar.n + // len(local_caption_split) + * 100 + ), + total=len(local_caption_split), + current=0, + ) + + if not return_concat: + del _processed + continue + + if return_concat: + processed_all.append(_processed) + + while self.write_queue.qsize() > 0: + time.sleep(0.1) # Sleep briefly to avoid busy-waiting + + if self.webhook_handler is not None: + self.send_progress_update( + type="init_cache_text_embeds_status_complete", + progress=100, + total=len(local_caption_split), + current=len(local_caption_split), + ) + + # Close the tqdm progress bar after the loop + self.write_thread_bar.close() + self.process_write_batches = False + + if not return_concat: + del processed_all + return + + logger.debug(f"Returning all prompt embeds: {processed_all}") + + return processed_all + def compute_embeddings_for_flux_prompts( self, prompts: list = None, @@ -1320,7 +1488,7 @@ def compute_embeddings_for_sd3_prompts( ) if should_encode: # If load_from_cache is True, should_encode would be False unless we failed to load. - self.debug_log(f"Encoding filename {filename} :: device {self.text_encoders[0].device} :: prompt {prompt}") + self.debug_log(f"Encoding prompt: {prompt}") prompt_embeds, pooled_prompt_embeds = self.encode_sd3_prompt( self.text_encoders, self.tokenizers, @@ -1333,7 +1501,7 @@ def compute_embeddings_for_sd3_prompts( ), ) logger.debug( - f"Filename {filename} SD3 prompt embeds: {prompt_embeds.shape}, {pooled_prompt_embeds.shape}" + f"SD3 prompt embeds: {prompt_embeds.shape}, {pooled_prompt_embeds.shape}" ) add_text_embeds = pooled_prompt_embeds # StabilityAI say not to zero them out. diff --git a/helpers/configuration/cmd_args.py b/helpers/configuration/cmd_args.py index ee3b4f79..13f3fa5e 100644 --- a/helpers/configuration/cmd_args.py +++ b/helpers/configuration/cmd_args.py @@ -85,7 +85,16 @@ def get_argument_parser(): ) parser.add_argument( "--model_family", - choices=["pixart_sigma", "kolors", "sd3", "flux", "smoldit", "sdxl", "legacy"], + choices=[ + "omnigen", + "pixart_sigma", + "kolors", + "sd3", + "flux", + "smoldit", + "sdxl", + "legacy", + ], default=None, required=True, help=("The model family to train. This option is required."), @@ -2079,7 +2088,7 @@ def parse_cmdline_args(input_args=None): if ( args.pretrained_vae_model_name_or_path is not None - and args.model_family in ["legacy", "flux", "sd3"] + and args.model_family in ["legacy", "flux", "sd3", "omnigen"] and "sdxl" in args.pretrained_vae_model_name_or_path and "deepfloyd" not in args.model_type ): @@ -2109,7 +2118,6 @@ def parse_cmdline_args(input_args=None): info_log( f"SD3 embeds for unconditional captions: t5={args.sd3_t5_uncond_behaviour}, clip={args.sd3_clip_uncond_behaviour}" ) - elif "deepfloyd" in args.model_type: deepfloyd_pixel_alignment = 8 if args.aspect_bucket_alignment != deepfloyd_pixel_alignment: diff --git a/helpers/models/omnigen/pipeline.py b/helpers/models/omnigen/pipeline.py new file mode 100644 index 00000000..88c61ec9 --- /dev/null +++ b/helpers/models/omnigen/pipeline.py @@ -0,0 +1,363 @@ +import os +import inspect +from typing import Any, Callable, Dict, List, Optional, Union +import gc + +from PIL import Image +import numpy as np +import torch +from huggingface_hub import snapshot_download +from peft import LoraConfig, PeftModel +from diffusers.models import AutoencoderKL +from diffusers.utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from safetensors.torch import load_file + +from OmniGen import OmniGen, OmniGenProcessor, OmniGenScheduler + + +logger = logging.get_logger(__name__) + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from OmniGen import OmniGenPipeline + >>> pipe = FluxControlNetPipeline.from_pretrained( + ... base_model + ... ) + >>> prompt = "A woman holds a bouquet of flowers and faces the camera" + >>> image = pipe( + ... prompt, + ... guidance_scale=2.5, + ... num_inference_steps=50, + ... ).images[0] + >>> image.save("t2i.png") + ``` +""" + + +90 + + +class OmniGenPipeline: + def __init__( + self, + vae: AutoencoderKL, + model: OmniGen, + processor: OmniGenProcessor, + device: Union[str, torch.device], + ): + self.vae = vae + self.model = model + self.processor = processor + self.device = device + + self.model.to(torch.bfloat16) + self.model.eval() + self.vae.eval() + + self.model_cpu_offload = False + + @classmethod + def from_pretrained( + cls, pretrained_model_name_or_path, vae_path: str = None, **kwargs + ): + if not os.path.exists(pretrained_model_name_or_path) or ( + not os.path.exists( + os.path.join(pretrained_model_name_or_path, "model.safetensors") + ) + and pretrained_model_name_or_path == "Shitao/OmniGen-v1" + ): + logger.info("Model not found, downloading...") + cache_folder = os.getenv("HF_HUB_CACHE") + pretrained_model_name_or_path = snapshot_download( + repo_id=pretrained_model_name_or_path, + cache_dir=cache_folder, + ignore_patterns=[ + "flax_model.msgpack", + "rust_model.ot", + "tf_model.h5", + "model.pt", + ], + ) + logger.info(f"Downloaded model to {pretrained_model_name_or_path}") + model = OmniGen.from_pretrained(pretrained_model_name_or_path) + processor = OmniGenProcessor.from_pretrained(pretrained_model_name_or_path) + + if os.path.exists(os.path.join(pretrained_model_name_or_path, "vae")): + vae = AutoencoderKL.from_pretrained( + os.path.join(pretrained_model_name_or_path, "vae") + ) + elif vae_path is not None: + vae = AutoencoderKL.from_pretrained(vae_path).to(device) + else: + logger.info( + f"No VAE found in {pretrained_model_name_or_path}, downloading stabilityai/sdxl-vae from HF" + ) + vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(device) + + print(f"OmniGenPipeline received unexpected arguments: {kwargs.keys()}") + + return cls(vae, model, processor) + + def merge_lora(self, lora_path: str): + model = PeftModel.from_pretrained(self.model, lora_path) + model.merge_and_unload() + + self.model = model + + def to(self, device: Union[str, torch.device]): + if isinstance(device, str): + device = torch.device(device) + self.model.to(device) + self.vae.to(device) + self.device = device + + def vae_encode(self, x, dtype): + if self.vae.config.shift_factor is not None: + x = self.vae.encode(x).latent_dist.sample() + x = (x - self.vae.config.shift_factor) * self.vae.config.scaling_factor + else: + x = ( + self.vae.encode(x) + .latent_dist.sample() + .mul_(self.vae.config.scaling_factor) + ) + x = x.to(dtype) + return x + + def move_to_device(self, data): + if isinstance(data, list): + return [x.to(self.device) for x in data] + return data.to(self.device) + + def enable_model_cpu_offload(self): + self.model_cpu_offload = True + self.model.to("cpu") + self.vae.to("cpu") + torch.cuda.empty_cache() # Clear VRAM + gc.collect() # Run garbage collection to free system RAM + + def disable_model_cpu_offload(self): + self.model_cpu_offload = False + self.model.to(self.device) + self.vae.to(self.device) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]], + input_images: Union[List[str], List[List[str]]] = None, + height: int = 1024, + width: int = 1024, + num_inference_steps: int = 50, + guidance_scale: float = 3, + use_img_guidance: bool = True, + img_guidance_scale: float = 1.6, + max_input_image_size: int = 1024, + separate_cfg_infer: bool = True, + offload_model: bool = False, + use_kv_cache: bool = True, + offload_kv_cache: bool = True, + use_input_image_size_as_output: bool = False, + dtype: torch.dtype = torch.bfloat16, + seed: int = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + input_images (`List[str]` or `List[List[str]]`, *optional*): + The list of input images. We will replace the "<|image_i|>" in prompt with the 1-th image in list. + height (`int`, *optional*, defaults to 1024): + The height in pixels of the generated image. The number must be a multiple of 16. + width (`int`, *optional*, defaults to 1024): + The width in pixels of the generated image. The number must be a multiple of 16. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + use_img_guidance (`bool`, *optional*, defaults to True): + Defined as equation 3 in [Instrucpix2pix](https://arxiv.org/pdf/2211.09800). + img_guidance_scale (`float`, *optional*, defaults to 1.6): + Defined as equation 3 in [Instrucpix2pix](https://arxiv.org/pdf/2211.09800). + max_input_image_size (`int`, *optional*, defaults to 1024): the maximum size of input image, which will be used to crop the input image to the maximum size + separate_cfg_infer (`bool`, *optional*, defaults to False): + Perform inference on images with different guidance separately; this can save memory when generating images of large size at the expense of slower inference. + use_kv_cache (`bool`, *optional*, defaults to True): enable kv cache to speed up the inference + offload_kv_cache (`bool`, *optional*, defaults to True): offload the cached key and value to cpu, which can save memory but slow down the generation silightly + offload_model (`bool`, *optional*, defaults to False): offload the model to cpu, which can save memory but slow down the generation + use_input_image_size_as_output (bool, defaults to False): whether to use the input image size as the output image size, which can be used for single-image input, e.g., image editing task + seed (`int`, *optional*): + A random seed for generating output. + dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`): + data type for the model + Examples: + + Returns: + A list with the generated images. + """ + # check inputs: + if use_input_image_size_as_output: + assert ( + isinstance(prompt, str) and len(input_images) == 1 + ), "if you want to make sure the output image have the same size as the input image, please only input one image instead of multiple input images" + else: + assert ( + height % 16 == 0 and width % 16 == 0 + ), "The height and width must be a multiple of 16." + if input_images is None: + use_img_guidance = False + if isinstance(prompt, str): + prompt = [prompt] + input_images = [input_images] if input_images is not None else None + + # set model and processor + if max_input_image_size != self.processor.max_image_size: + self.processor = OmniGenProcessor( + self.processor.text_tokenizer, max_image_size=max_input_image_size + ) + if offload_model: + self.enable_model_cpu_offload() + else: + self.disable_model_cpu_offload() + + input_data = self.processor( + prompt, + input_images, + height=height, + width=width, + use_img_cfg=use_img_guidance, + separate_cfg_input=separate_cfg_infer, + use_input_image_size_as_output=use_input_image_size_as_output, + ) + print(f"Input shapes: {input_data['attention_mask'][0].shape}") + + num_prompt = len(prompt) + num_cfg = 2 if use_img_guidance else 1 + if use_input_image_size_as_output: + if separate_cfg_infer: + height, width = input_data["input_pixel_values"][0][0].shape[-2:] + else: + height, width = input_data["input_pixel_values"][0].shape[-2:] + latent_size_h, latent_size_w = height // 8, width // 8 + + if seed is not None: + generator = torch.Generator(device=self.device).manual_seed(seed) + else: + generator = None + latents = torch.randn( + num_prompt, + 4, + latent_size_h, + latent_size_w, + device=self.device, + generator=generator, + ) + latents = torch.cat([latents] * (1 + num_cfg), 0).to(dtype) + + if input_images is not None and self.model_cpu_offload: + self.vae.to(self.device) + input_img_latents = [] + if separate_cfg_infer: + for temp_pixel_values in input_data["input_pixel_values"]: + temp_input_latents = [] + for img in temp_pixel_values: + img = self.vae_encode(img.to(self.device), dtype) + temp_input_latents.append(img) + input_img_latents.append(temp_input_latents) + else: + for img in input_data["input_pixel_values"]: + img = self.vae_encode(img.to(self.device), dtype) + input_img_latents.append(img) + if input_images is not None and self.model_cpu_offload: + self.vae.to("cpu") + torch.cuda.empty_cache() # Clear VRAM + gc.collect() # Run garbage collection to free system RAM + + model_kwargs = dict( + input_ids=self.move_to_device(input_data["input_ids"]), + input_img_latents=input_img_latents, + input_image_sizes=input_data["input_image_sizes"], + attention_mask=self.move_to_device(input_data["attention_mask"]), + position_ids=self.move_to_device(input_data["position_ids"]), + cfg_scale=guidance_scale, + img_cfg_scale=img_guidance_scale, + use_img_cfg=use_img_guidance, + use_kv_cache=use_kv_cache, + offload_model=offload_model, + ) + + if separate_cfg_infer: + func = self.model.forward_with_separate_cfg + else: + func = self.model.forward_with_cfg + self.model.to(dtype) + + if self.model_cpu_offload: + for name, param in self.model.named_parameters(): + if "layers" in name and "layers.0" not in name: + param.data = param.data.cpu() + else: + param.data = param.data.to(self.device) + for buffer_name, buffer in self.model.named_buffers(): + setattr(self.model, buffer_name, buffer.to(self.device)) + # else: + # self.model.to(self.device) + + scheduler = OmniGenScheduler(num_steps=num_inference_steps) + samples = scheduler( + latents, + func, + model_kwargs, + use_kv_cache=use_kv_cache, + offload_kv_cache=offload_kv_cache, + ) + samples = samples.chunk((1 + num_cfg), dim=0)[0] + + if self.model_cpu_offload: + self.model.to("cpu") + torch.cuda.empty_cache() + gc.collect() + + self.vae.to(self.device) + samples = samples.to(torch.float32) + if self.vae.config.shift_factor is not None: + samples = ( + samples / self.vae.config.scaling_factor + self.vae.config.shift_factor + ) + else: + samples = samples / self.vae.config.scaling_factor + samples = self.vae.decode( + samples.to(dtype=self.vae.dtype, device=self.vae.device) + ).sample + + if self.model_cpu_offload: + self.vae.to("cpu") + torch.cuda.empty_cache() + gc.collect() + + output_samples = (samples * 0.5 + 0.5).clamp(0, 1) * 255 + output_samples = ( + output_samples.permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy() + ) + output_images = [] + for i, sample in enumerate(output_samples): + output_images.append(Image.fromarray(sample)) + + torch.cuda.empty_cache() # Clear VRAM + gc.collect() # Run garbage collection to free system RAM + return output_images diff --git a/helpers/models/omnigen/processor.py b/helpers/models/omnigen/processor.py new file mode 100644 index 00000000..bc634d26 --- /dev/null +++ b/helpers/models/omnigen/processor.py @@ -0,0 +1,187 @@ +import os +import re +from typing import Dict, List +import json + +import torch +import numpy as np +import random +from PIL import Image +from torchvision import transforms +from transformers import AutoTokenizer +from huggingface_hub import snapshot_download + +from OmniGen.utils import ( + create_logger, + update_ema, + requires_grad, + center_crop_arr, + crop_arr, +) + + +class OmniGenCollator: + def __init__(self, pad_token_id=2): + self.pad_token_id = pad_token_id + + def __call__(self, features): + print(f"features: {features}") + input_ids = [f[0]["input_ids"] for f in features] + attention_masks = [] + max_length = max(len(ids) for ids in input_ids) + + # Pad input_ids and create attention masks + padded_input_ids = [] + for ids in input_ids: + pad_length = max_length - len(ids) + padded_ids = [self.pad_token_id] * pad_length + ids + attention_mask = [0] * pad_length + [1] * len(ids) + padded_input_ids.append(padded_ids) + attention_masks.append(attention_mask) + + padded_input_ids = torch.tensor(padded_input_ids) + attention_masks = torch.tensor(attention_masks) + + # Handle pixel values + pixel_values = [ + f[0]["pixel_values"] for f in features if f[0]["pixel_values"] is not None + ] + if pixel_values: + pixel_values = [pv for sublist in pixel_values for pv in sublist] + pixel_values = torch.stack(pixel_values) + else: + pixel_values = None + + return { + "input_ids": padded_input_ids, + "attention_mask": attention_masks, + "pixel_values": pixel_values, + # Include other necessary fields + } + + +class OmniGenTrainingProcessor: + def __init__(self, text_tokenizer, max_image_size: int = 1024): + self.text_tokenizer = text_tokenizer + self.max_image_size = max_image_size + + self.image_transform = transforms.Compose( + [ + transforms.Lambda( + lambda pil_image: crop_arr(pil_image, max_image_size) + ), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True + ), + ] + ) + + self.collator = OmniGenCollator() + + @classmethod + def from_pretrained(cls, model_name): + if not os.path.exists(model_name): + cache_folder = os.getenv("HF_HUB_CACHE") + model_name = snapshot_download( + repo_id=model_name, cache_dir=cache_folder, allow_patterns="*.json" + ) + text_tokenizer = AutoTokenizer.from_pretrained(model_name) + + return cls(text_tokenizer) + + def process_image(self, image): + image = Image.open(image).convert("RGB") + return self.image_transform(image) + + def process_multi_modal_prompt(self, text, input_images): + text = self.add_prefix_instruction(text) + if input_images is None or len(input_images) == 0: + model_inputs = self.text_tokenizer(text) + return { + "input_ids": model_inputs.input_ids, + "pixel_values": None, + "image_sizes": None, + } + + pattern = r"<\|image_\d+\|>" + prompt_chunks = [ + self.text_tokenizer(chunk).input_ids for chunk in re.split(pattern, text) + ] + + for i in range(1, len(prompt_chunks)): + if prompt_chunks[i][0] == 1: + prompt_chunks[i] = prompt_chunks[i][1:] + + image_tags = re.findall(pattern, text) + image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags] + + unique_image_ids = sorted(list(set(image_ids))) + assert unique_image_ids == list( + range(1, len(unique_image_ids) + 1) + ), f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}" + # total images must be the same as the number of image tags + assert len(unique_image_ids) == len( + input_images + ), f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images" + + input_images = [input_images[x - 1] for x in image_ids] + + all_input_ids = [] + img_inx = [] + idx = 0 + for i in range(len(prompt_chunks)): + all_input_ids.extend(prompt_chunks[i]) + if i != len(prompt_chunks) - 1: + start_inx = len(all_input_ids) + size = input_images[i].size(-2) * input_images[i].size(-1) // 16 // 16 + img_inx.append([start_inx, start_inx + size]) + all_input_ids.extend([0] * size) + + return { + "input_ids": all_input_ids, + "pixel_values": input_images, + "image_sizes": img_inx, + } + + def add_prefix_instruction(self, prompt): + user_prompt = "<|user|>\n" + generation_prompt = ( + "Generate an image according to the following instructions\n" + ) + assistant_prompt = "<|assistant|>\n<|diffusion|>" + prompt_suffix = "<|end|>\n" + prompt = ( + f"{user_prompt}{generation_prompt}{prompt}{prompt_suffix}{assistant_prompt}" + ) + return prompt + + def __call__( + self, + instructions: List[str], + input_images: List[List[str]] = None, + height: int = 1024, + width: int = 1024, + ) -> Dict: + + if isinstance(instructions, str): + instructions = [instructions] + input_images = [input_images] + + input_data = [] + for i in range(len(instructions)): + cur_instruction = instructions[i] + cur_input_images = None if input_images is None else input_images[i] + if cur_input_images is not None and len(cur_input_images) > 0: + cur_input_images = [self.process_image(x) for x in cur_input_images] + else: + cur_input_images = None + assert "<|image_1|>" not in cur_instruction + + mllm_input = self.process_multi_modal_prompt( + cur_instruction, cur_input_images + ) + + input_data.append((mllm_input, [height, width])) + + return self.collator(input_data) diff --git a/helpers/training/adapter.py b/helpers/training/adapter.py index 04b99069..06676d06 100644 --- a/helpers/training/adapter.py +++ b/helpers/training/adapter.py @@ -9,6 +9,9 @@ def determine_adapter_target_modules(args, unet, transformer): elif transformer is not None: target_modules = ["to_k", "to_q", "to_v", "to_out.0"] + if args.model_family.lower() == "omnigen": + target_modules = ["qkv_proj", "o_proj"] + if args.model_family.lower() == "flux" and args.flux_lora_target == "all": # target_modules = mmdit layers here target_modules = [ diff --git a/helpers/training/collate.py b/helpers/training/collate.py index c1549e3a..f9c9fb40 100644 --- a/helpers/training/collate.py +++ b/helpers/training/collate.py @@ -1,5 +1,6 @@ import torch import logging +import random import concurrent.futures import numpy as np from os import environ @@ -213,7 +214,12 @@ def compute_latents(filepaths, data_backend_id: str): def compute_single_embedding( - caption, text_embed_cache, is_sdxl, is_sd3: bool = False, is_flux: bool = False + caption, + text_embed_cache, + is_sdxl, + is_sd3: bool = False, + is_flux: bool = False, + is_omnigen: bool = False, ): """Worker function to compute embedding for a single caption.""" if caption == "" or not caption: @@ -246,6 +252,11 @@ def compute_single_embedding( time_ids[0], masks[0] if masks is not None else None, ) + elif is_omnigen: + processed = text_embed_cache.compute_embeddings_for_omnigen_prompts( + prompts=[caption] + ) + return processed else: prompt_embeds = text_embed_cache.compute_embeddings_for_legacy_prompts( [caption] @@ -284,6 +295,7 @@ def compute_prompt_embeddings(captions, text_embed_cache): is_pixart_sigma = text_embed_cache.model_type == "pixart_sigma" is_smoldit = text_embed_cache.model_type == "smoldit" is_flux = text_embed_cache.model_type == "flux" + is_omnigen = text_embed_cache.model_type == "omnigen" # Use a thread pool to compute embeddings concurrently with ThreadPoolExecutor() as executor: @@ -295,6 +307,7 @@ def compute_prompt_embeddings(captions, text_embed_cache): [is_sdxl] * len(captions), [is_sd3] * len(captions), [is_flux] * len(captions), + [is_omnigen] * len(captions), ) ) @@ -331,6 +344,9 @@ def compute_prompt_embeddings(captions, text_embed_cache): torch.stack(time_ids), torch.stack(masks) if None not in masks else None, ) + elif is_omnigen: + embeddings = [e[0] for e in embeddings] + return embeddings else: # Separate the tuples prompt_embeds = [t[0] for t in embeddings] @@ -426,8 +442,17 @@ def collate_fn(batch): "This trainer is not designed to handle multiple batches in a single collate." ) debug_log("Begin collate_fn on batch") - - # SDXL Dropout + ( + latent_batch, + prompt_embeds_all, + add_text_embeds_all, + input_ids, + batch_time_ids, + batch_luminance, + conditioning_pixel_values, + attn_mask, + conditioning_type, + ) = (None, None, None, None, None, None, None, None, None) dropout_probability = StateTracker.get_args().caption_dropout_probability batch = batch[0] examples = batch["training_samples"] @@ -528,11 +553,70 @@ def collate_fn(batch): attn_mask = None batch_time_ids = None + input_ids = None + extra_batch_inputs = {} if StateTracker.get_model_family() == "flux": debug_log("Compute and stack Flux time ids") prompt_embeds_all, add_text_embeds_all, batch_time_ids, attn_mask = ( compute_prompt_embeddings(captions, text_embed_cache) ) + elif StateTracker.get_model_family() == "omnigen": + # instruction, output_image = example['instruction'], example['input_images'], example['output_image'] + omnigen_processed_embeddings = compute_prompt_embeddings( + captions, text_embed_cache + ) + from OmniGen.processor import OmniGenCollator + + attn_mask = [e.get("attention_mask") for e in omnigen_processed_embeddings] + attn_mask_len = len(attn_mask[0][0]) + attn_mask = torch.stack(attn_mask, dim=0) + + # we can use the OmniGenCollator.create_position to make positional ids + num_tokens_for_output_images = [] + for img_size in [ + [ + latent_batch.shape[3] * 8, + latent_batch.shape[2] * 8, + ] + * len(latent_batch) + ]: + num_img_tokens = img_size[0] * img_size[1] // 16 // 16 + num_text_tokens = attn_mask_len + total_num_tokens = num_img_tokens - num_text_tokens + num_tokens_for_output_images.append(total_num_tokens) + position_ids = OmniGenCollator.create_position( + attn_mask, num_tokens_for_output_images + ) + # pad attn_mask to match the position_ids, eg. mask [1, 1, 1, 57] -> [1, 1, 1, 4097] + attn_mask = torch.cat( + [ + attn_mask, + torch.zeros( + ( + attn_mask.shape[0], + attn_mask.shape[1], + num_tokens_for_output_images[0] + 1, + ) + ), + ], + dim=-1, + ) + + # TODO: support "input images" for OmniGen which behave as conditioning images, eg. ControlNet Canny, Depth, etc. + # conditioning_pixel_values = torch.stack([e.get('input_pixel_values') for e in omnigen_processed_embeddings], dim=0) + # input_image_sizes = [e.get('input_image_size') for e in omnigen_processed_embeddings] + # extra_batch_inputs['conditioning_pixel_values'] = conditioning_pixel_values + # extra_batch_inputs['input_image_sizes'] = input_image_sizes + # input_ids = [e.get('input_ids') for e in omnigen_processed_embeddings] + # input_ids = torch.stack(input_ids, dim=0) + # TODO: Support instruction/conditioning image dropout for OmniGen. + # if random.random() < StateTracker.get_args().caption_dropout_probability: + # instruction = '' + # latent_batch = None + padding_images = [e.get("padding_image") for e in omnigen_processed_embeddings] + extra_batch_inputs["position_ids"] = position_ids + extra_batch_inputs["padding_images"] = padding_images + extra_batch_inputs["input_ids"] = input_ids else: prompt_embeds_all, add_text_embeds_all = compute_prompt_embeddings( captions, text_embed_cache @@ -566,4 +650,5 @@ def collate_fn(batch): "encoder_attention_mask": attn_mask, "is_regularisation_data": is_regularisation_data, "conditioning_type": conditioning_type, + **extra_batch_inputs, } diff --git a/helpers/training/diffusion_model.py b/helpers/training/diffusion_model.py index 78611dc8..6ef9c86b 100644 --- a/helpers/training/diffusion_model.py +++ b/helpers/training/diffusion_model.py @@ -34,7 +34,24 @@ def load_diffusion_model(args, weight_dtype): bnb_4bit_compute_dtype=weight_dtype, ) - if args.model_family == "sd3": + if args.model_family == "omnigen": + try: + from OmniGen import OmniGen + except ImportError: + logger.error( + "Could not import Omnigen. Please install Omnigen to use this model." + ) + raise + logger.info("Loading OmniGen model..") + transformer = OmniGen.from_pretrained( + args.pretrained_transformer_model_name_or_path + or args.pretrained_model_name_or_path + ) + transformer.llm.config.use_cache = False + logger.info(f"Enabling gradient checkpointing..") + transformer.llm.gradient_checkpointing_enable() + + elif args.model_family == "sd3": # Stable Diffusion 3 uses a Diffusion transformer. logger.info("Loading Stable Diffusion 3 diffusion transformer..") try: diff --git a/helpers/training/save_hooks.py b/helpers/training/save_hooks.py index 51fb85de..38d66d73 100644 --- a/helpers/training/save_hooks.py +++ b/helpers/training/save_hooks.py @@ -164,6 +164,12 @@ def __init__( elif args.model_family == "smoldit": self.denoiser_class = SmolDiT2DModel self.pipeline_class = SmolDiTPipeline + elif args.model_family == "omnigen": + from OmniGen import OmniGen + from helpers.models.omnigen.pipeline import OmniGenPipeline + + self.denoiser_class = OmniGen + self.pipeline_class = OmniGenPipeline self.denoiser_subdir = "transformer" if args.controlnet: diff --git a/helpers/training/schedulers.py b/helpers/training/schedulers.py index b7636189..83e37ccf 100644 --- a/helpers/training/schedulers.py +++ b/helpers/training/schedulers.py @@ -22,6 +22,10 @@ def load_scheduler_from_args(args): subfolder="scheduler", shift=1 if args.model_family == "sd3" else 3, ) + elif args.model_family == "omnigen": + from OmniGen import OmniGenScheduler + + noise_scheduler = OmniGenScheduler() else: if args.model_family == "legacy": args.rescale_betas_zero_snr = True diff --git a/helpers/training/state_tracker.py b/helpers/training/state_tracker.py index d069a2fb..51c9fb57 100644 --- a/helpers/training/state_tracker.py +++ b/helpers/training/state_tracker.py @@ -117,6 +117,7 @@ def set_model_family(cls, model_type: str): "legacy", "sdxl", "sd3", + "omnigen", "pixart_sigma", "kolors", "smoldit", diff --git a/helpers/training/text_encoding.py b/helpers/training/text_encoding.py index 5261b6fb..9c082482 100644 --- a/helpers/training/text_encoding.py +++ b/helpers/training/text_encoding.py @@ -15,6 +15,19 @@ def import_model_class_from_model_name_or_path( args, subfolder: str = "text_encoder", ): + if args.model_family.lower() == "omnigen": + try: + from helpers.models.omnigen.processor import ( + OmniGenTrainingProcessor as OmniGenProcessor, + ) + except ImportError: + logger.error( + "Could not import Omnigen. Please install omnigen to use this model." + ) + raise + + return OmniGenProcessor + if args.model_family.lower() == "smoldit": from transformers import AutoModelForSeq2SeqLM @@ -51,6 +64,9 @@ def import_model_class_from_model_name_or_path( def get_tokenizers(args): tokenizer_1, tokenizer_2, tokenizer_3 = None, None, None try: + if args.model_family.lower() == "omnigen": + return None, None, None + if args.model_family.lower() == "smoldit": from transformers import AutoTokenizer @@ -225,6 +241,11 @@ def load_tes( f"Loading ChatGLM language model from {text_encoder_path}/{text_encoder_subfolder}.." ) text_encoder_variant = "fp16" + elif args.model_family.lower() == "omnigen": + logger.info(f"Loading OmniGen processor from {text_encoder_path}..") + text_encoder_1 = text_encoder_cls_1.from_pretrained(text_encoder_path) + + return text_encoder_variant, text_encoder_1, text_encoder_2, text_encoder_3 else: logger.info( f"Loading CLIP text encoder from {text_encoder_path}/{text_encoder_subfolder}.." diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py index a5e70d44..92a844c4 100644 --- a/helpers/training/trainer.py +++ b/helpers/training/trainer.py @@ -505,12 +505,26 @@ def init_text_tokenizer(self): def init_text_encoder(self, move_to_accelerator: bool = True): self.init_text_tokenizer() + self.text_encoders = [] + self.tokenizers = [] self.text_encoder_1, self.text_encoder_2, self.text_encoder_3 = None, None, None self.text_encoder_cls_1, self.text_encoder_cls_2, self.text_encoder_cls_3 = ( None, None, None, ) + if self.config.model_family.lower() == "omnigen": + # omnigen uses a preprocessor w/ a simple tokeniser, not a text encoder. + from helpers.models.omnigen.processor import ( + OmniGenTrainingProcessor as OmniGenProcessor, + ) + + self.text_encoder_1 = OmniGenProcessor.from_pretrained( + self.config.pretrained_transformer_model_name_or_path + or self.config.pretrained_model_name_or_path + ) + self.text_encoders.append(self.text_encoder_1) + self.tokenizers.append(self.text_encoder_1.text_tokenizer) if self.tokenizer_1 is not None: self.text_encoder_cls_1 = import_model_class_from_model_name_or_path( self.config.text_encoder_path, @@ -555,8 +569,6 @@ def init_text_encoder(self, move_to_accelerator: bool = True): if not move_to_accelerator: logger.debug("Not moving text encoders to accelerator.") return - self.text_encoders = [] - self.tokenizers = [] if self.tokenizer_1 is not None: logger.info("Moving text encoder to GPU.") self.text_encoder_1.to( @@ -579,6 +591,9 @@ def init_text_encoder(self, move_to_accelerator: bool = True): self.tokenizers.append(self.tokenizer_3) self.text_encoders.append(self.text_encoder_3) + if not any(self.text_encoders): + logger.warning("No text encoders loaded. This may cause issues.") + def init_freeze_models(self): # Freeze vae and text_encoders if self.vae is not None: @@ -676,6 +691,10 @@ def init_data_backend(self): self.accelerator.wait_for_everyone() def init_validation_prompts(self): + self.validation_prompts = None + self.validation_shortnames = None + self.validation_negative_prompt_embeds = None + self.validation_negative_pooled_embeds = None if self.accelerator.is_main_process: if self.config.model_family == "flux": ( @@ -688,6 +707,14 @@ def init_validation_prompts(self): args=self.config, embed_cache=StateTracker.get_default_text_embed_cache(), ) + elif self.config.model_family == "omnigen": + ( + self.validation_prompts, + self.validation_shortnames, + ) = prepare_validation_prompt_list( + args=self.config, + embed_cache=StateTracker.get_default_text_embed_cache(), + ) else: ( self.validation_prompts, @@ -698,11 +725,6 @@ def init_validation_prompts(self): args=self.config, embed_cache=StateTracker.get_default_text_embed_cache(), ) - else: - self.validation_prompts = None - self.validation_shortnames = None - self.validation_negative_prompt_embeds = None - self.validation_negative_pooled_embeds = None self.accelerator.wait_for_everyone() def stats_memory_used(self): @@ -869,6 +891,9 @@ def init_trainable_peft_adapter(self): target_modules=target_modules, use_dora=self.config.use_dora, ) + if self.config.model_family == "omnigen": + self.transformer.llm.enable_input_require_grads() + self.transformer.add_adapter(transformer_lora_config) if self.config.init_lora: addkeys, misskeys = load_lora_weights( @@ -974,7 +999,11 @@ def init_post_load_freeze(self): unwrap_model( self.accelerator, self.unet ).enable_gradient_checkpointing() - if self.transformer is not None and self.config.model_family != "smoldit": + if ( + self.transformer is not None + and self.config.model_family != "smoldit" + and hasattr(self.transformer, "enable_gradient_checkpointing") + ): unwrap_model( self.accelerator, self.transformer ).enable_gradient_checkpointing() @@ -1996,6 +2025,29 @@ def model_predict( ), } model_pred = self.transformer(**inputs).sample + elif self.config.model_family == "omnigen": + inputs = { + "x": noisy_latents, + "timestep": timesteps, + "input_ids": ( + batch.get("input_ids").to(self.accelerator.device) + if batch.get("input_ids") is not None + else None + ), + "input_img_latents": ( + batch.get("input_img_latents").to(self.accelerator.device) + if batch.get("input_img_latents") is not None + else None + ), + "input_image_sizes": batch.get("input_image_sizes"), + "attention_mask": batch.get("encoder_attention_mask").to( + self.accelerator.device + ), + "position_ids": batch.get("position_ids").to( + self.accelerator.device + ), + } + model_pred = self.transformer(**inputs)[0] elif self.unet is not None: if self.config.model_family == "legacy": # SD 1.5 or 2.x @@ -2182,7 +2234,26 @@ def train(self): f"Received {bsz} latents, but expected {self.config.train_batch_size}. Processing short batch." ) training_logger.debug(f"Working on batch size: {bsz}") - if self.config.flow_matching: + if self.config.model_family == "omnigen": + # x1 corresponds to your latents + x1 = latents + + # Sample x0 from a standard normal distribution with the same shape as latents + x0 = torch.randn_like(latents) + + # Sample t for each sample in the batch using the specified distribution + u = torch.randn(bsz, device=latents.device) + t = 1 / (1 + torch.exp(-u)) # t ∈ (0, 1) + t = t.to(latents.device, dtype=latents.dtype) + + # Convert t to timesteps compatible with the model (scaled appropriately) + timesteps = t * 999 + timesteps = timesteps.to( + self.accelerator.device, dtype=latents.dtype + ) + timesteps = timesteps.long() + + elif self.config.flow_matching: if ( not self.config.flux_fast_schedule and not self.config.flux_use_beta_schedule @@ -2281,6 +2352,14 @@ def train(self): if self.config.flow_matching: noisy_latents = (1 - sigmas) * latents + sigmas * input_noise + elif self.config.model_family == "omnigen": + # Reshape t to match the dimensions of latents for broadcasting + dims = [1] * (latents.dim() - 1) + t_reshaped = t.view(-1, *dims) + + # Compute noisy_latents (xt) using the Omnigen sampling formula + noisy_latents = t_reshaped * x1 + (1 - t_reshaped) * x0 + else: # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) @@ -2291,19 +2370,25 @@ def train(self): dtype=self.config.weight_dtype, ) - encoder_hidden_states = batch["prompt_embeds"].to( - dtype=self.config.weight_dtype, device=self.accelerator.device - ) - training_logger.debug( - f"Encoder hidden states: {encoder_hidden_states.shape}" - ) + encoder_hidden_states = None + if hasattr(batch["prompt_embeds"], "to"): + encoder_hidden_states = batch["prompt_embeds"].to( + dtype=self.config.weight_dtype, + device=self.accelerator.device, + ) + training_logger.debug( + f"Encoder hidden states: {encoder_hidden_states.shape}" + ) add_text_embeds = batch["add_text_embeds"] training_logger.debug( f"Pooled embeds: {add_text_embeds.shape if add_text_embeds is not None else None}" ) # Get the target for loss depending on the prediction type - if self.config.flow_matching: + if ( + self.config.flow_matching + or self.config.model_family == "omnigen" + ): # This is the flow-matching target for vanilla SD3. # If self.config.flow_matching_loss == "diffusion", we will instead use v_prediction (see below) if self.config.flow_matching_loss == "diffusers": @@ -2311,9 +2396,10 @@ def train(self): elif self.config.flow_matching_loss == "compatible": target = noise - latents elif self.config.flow_matching_loss == "sd35": - sigma_reshaped = sigmas.view(-1, 1, 1, 1) # Ensure sigma has the correct shape + sigma_reshaped = sigmas.view( + -1, 1, 1, 1 + ) # Ensure sigma has the correct shape target = (noisy_latents - latents) / sigma_reshaped - elif self.noise_scheduler.config.prediction_type == "epsilon": target = noise elif ( @@ -2420,7 +2506,10 @@ def train(self): parent_loss = None # Compute the per-pixel loss without reducing over spatial dimensions - if self.config.flow_matching: + if ( + self.config.flow_matching + or self.config.model_family == "omnigen" + ): # For flow matching, compute the per-pixel squared differences loss = ( model_pred.float() - target.float() diff --git a/helpers/training/validation.py b/helpers/training/validation.py index b9406a77..65c17817 100644 --- a/helpers/training/validation.py +++ b/helpers/training/validation.py @@ -280,6 +280,8 @@ def prepare_validation_prompt_list(args, embed_cache): validation_negative_pooled_embeds, validation_negative_time_ids, ) + elif model_type == "omnigen": + return (validation_prompts, validation_shortnames) else: raise ValueError(f"Unknown model type '{model_type}'") @@ -590,6 +592,10 @@ def _pipeline_cls(self): from helpers.models.smoldit import SmolDiTPipeline return SmolDiTPipeline + elif model_type == "omnigen": + from helpers.models.omnigen.pipeline import OmniGenPipeline + + return OmniGenPipeline else: raise NotImplementedError( f"Model type {model_type} not implemented for validation." @@ -598,6 +604,9 @@ def _pipeline_cls(self): def _gather_prompt_embeds(self, validation_prompt: str): prompt_embeds = {} current_validation_prompt_mask = None + current_validation_prompt_embeds = None + current_validation_pooled_embeds = None + current_validation_time_ids = None if ( StateTracker.get_model_family() == "sdxl" or StateTracker.get_model_family() == "sd3" @@ -680,25 +689,33 @@ def _gather_prompt_embeds(self, validation_prompt: str): # logger.debug( # f"Dtypes: {current_validation_prompt_embeds.dtype}, {self.validation_negative_prompt_embeds.dtype}" # ) + elif StateTracker.get_model_family() == "omnigen": + # no special treatment needed here. + pass else: raise NotImplementedError( f"Model type {StateTracker.get_model_family()} not implemented for validation." ) - current_validation_prompt_embeds = current_validation_prompt_embeds.to( - device=self.inference_device, dtype=self.weight_dtype - ) - self.validation_negative_prompt_embeds = ( - self.validation_negative_prompt_embeds.to( + if current_validation_prompt_embeds is not None: + current_validation_prompt_embeds = current_validation_prompt_embeds.to( device=self.inference_device, dtype=self.weight_dtype ) - ) + if self.validation_negative_prompt_embeds is not None: + self.validation_negative_prompt_embeds = ( + self.validation_negative_prompt_embeds.to( + device=self.inference_device, dtype=self.weight_dtype + ) + ) # when sampling unconditional guidance, you should only zero one or the other prompt, and not both. # we'll assume that the user has a negative prompt, so that the unconditional sampling works. # the positive prompt embed is zeroed out for SDXL at the time of it being placed into the cache. # the embeds are not zeroed out for any other model, including Stable Diffusion 3. - prompt_embeds["prompt_embeds"] = current_validation_prompt_embeds - prompt_embeds["negative_prompt_embeds"] = self.validation_negative_prompt_embeds + if current_validation_prompt_embeds is not None: + prompt_embeds["prompt_embeds"] = current_validation_prompt_embeds + prompt_embeds["negative_prompt_embeds"] = ( + self.validation_negative_prompt_embeds + ) if ( StateTracker.get_model_family() == "pixart_sigma" or StateTracker.get_model_family() == "smoldit" @@ -1063,6 +1080,19 @@ def setup_pipeline(self, validation_type, enable_ema_model: bool = True): text_encoder=self.text_encoder_1, scheduler=self.setup_scheduler(), ) + elif self.args.model_family == "omnigen": + # we use upstream processor to get negative prompting. + from OmniGen import OmniGenProcessor + + self.pipeline = pipeline_cls( + vae=self.vae, + processor=OmniGenProcessor.from_pretrained( + self.args.pretrained_transformer_model_name_or_path + or self.args.pretrained_model_name_or_path + ), + model=self.transformer, + device=self.accelerator.device, + ) else: self.pipeline = pipeline_cls.from_pretrained(**pipeline_kwargs) except Exception as e: @@ -1261,7 +1291,7 @@ def validate_prompt( for key, value in pipeline_kwargs.items(): if hasattr(value, "device"): logger.debug(f"Device for {key}: {value.device}") - for key, value in self.pipeline.components.items(): + for key, value in getattr(self.pipeline, "components", {}).items(): if hasattr(value, "device"): logger.debug(f"Device for {key}: {value.device}") if StateTracker.get_model_family() == "flux": @@ -1281,8 +1311,26 @@ def validate_prompt( pipeline_kwargs["negative_prompt_attention_mask"] = torch.unsqueeze( pipeline_kwargs.pop("negative_mask")[0], dim=0 ).to(device=self.inference_device, dtype=self.weight_dtype) + if StateTracker.get_model_family() == "omnigen": + pipeline_kwargs["prompt"] = prompt + del pipeline_kwargs["negative_prompt"] + del pipeline_kwargs["num_images_per_prompt"] + del pipeline_kwargs["generator"] + del pipeline_kwargs["guidance_rescale"] + if "image" in pipeline_kwargs: + pipeline_kwargs["input_image"] = pipeline_kwargs.pop("image") + pipeline_kwargs["seed"] = self.args.validation_seed + pipeline_kwargs["use_kv_cache"] = ( + False if torch.backends.mps.is_available() else True + ) + pipeline_kwargs["offload_kv_cache"] = ( + False if torch.backends.mps.is_available() else True + ) + logger.debug(f"OmniGen pipeline kwargs: {pipeline_kwargs}") - validation_image_results = self.pipeline(**pipeline_kwargs).images + validation_image_results = self.pipeline(**pipeline_kwargs) + if hasattr(validation_image_results, "images"): + validation_image_results = validation_image_results.images if self.args.controlnet: validation_image_results = self.stitch_conditioning_images( validation_image_results, extra_validation_kwargs["image"] diff --git a/install/apple/poetry.lock b/install/apple/poetry.lock index dde26187..fb299253 100644 --- a/install/apple/poetry.lock +++ b/install/apple/poetry.lock @@ -1559,50 +1559,46 @@ files = [ [[package]] name = "nvidia-cublas-cu12" -version = "12.4.5.8" +version = "12.1.3.1" description = "CUBLAS native runtime libraries" optional = false python-versions = ">=3" files = [ - {file = "nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0f8aa1706812e00b9f19dfe0cdb3999b092ccb8ca168c0db5b8ea712456fd9b3"}, - {file = "nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:2fc8da60df463fdefa81e323eef2e36489e1c94335b5358bcb38360adf75ac9b"}, - {file = "nvidia_cublas_cu12-12.4.5.8-py3-none-win_amd64.whl", hash = "sha256:5a796786da89203a0657eda402bcdcec6180254a8ac22d72213abc42069522dc"}, + {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728"}, + {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-win_amd64.whl", hash = "sha256:2b964d60e8cf11b5e1073d179d85fa340c120e99b3067558f3cf98dd69d02906"}, ] [[package]] name = "nvidia-cuda-cupti-cu12" -version = "12.4.127" +version = "12.1.105" description = "CUDA profiling tools runtime libs." optional = false python-versions = ">=3" files = [ - {file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:79279b35cf6f91da114182a5ce1864997fd52294a87a16179ce275773799458a"}, - {file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9dec60f5ac126f7bb551c055072b69d85392b13311fcc1bcda2202d172df30fb"}, - {file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:5688d203301ab051449a2b1cb6690fbe90d2b372f411521c86018b950f3d7922"}, + {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e"}, + {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:bea8236d13a0ac7190bd2919c3e8e6ce1e402104276e6f9694479e48bb0eb2a4"}, ] [[package]] name = "nvidia-cuda-nvrtc-cu12" -version = "12.4.127" +version = "12.1.105" description = "NVRTC native runtime libraries" optional = false python-versions = ">=3" files = [ - {file = "nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0eedf14185e04b76aa05b1fea04133e59f465b6f960c0cbf4e37c3cb6b0ea198"}, - {file = "nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a178759ebb095827bd30ef56598ec182b85547f1508941a3d560eb7ea1fbf338"}, - {file = "nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:a961b2f1d5f17b14867c619ceb99ef6fcec12e46612711bcec78eb05068a60ec"}, + {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2"}, + {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:0a98a522d9ff138b96c010a65e145dc1b4850e9ecb75a0172371793752fd46ed"}, ] [[package]] name = "nvidia-cuda-runtime-cu12" -version = "12.4.127" +version = "12.1.105" description = "CUDA Runtime native Libraries" optional = false python-versions = ">=3" files = [ - {file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:961fe0e2e716a2a1d967aab7caee97512f71767f852f67432d572e36cb3a11f3"}, - {file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64403288fa2136ee8e467cdc9c9427e0434110899d07c779f25b5c068934faa5"}, - {file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:09c2e35f48359752dfa822c09918211844a3d93c100a715d79b59591130c5e1e"}, + {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40"}, + {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:dfb46ef84d73fababab44cf03e3b83f80700d27ca300e537f85f636fac474344"}, ] [[package]] @@ -1621,41 +1617,35 @@ nvidia-cublas-cu12 = "*" [[package]] name = "nvidia-cufft-cu12" -version = "11.2.1.3" +version = "11.0.2.54" description = "CUFFT native runtime libraries" optional = false python-versions = ">=3" files = [ - {file = "nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_aarch64.whl", hash = "sha256:5dad8008fc7f92f5ddfa2101430917ce2ffacd86824914c82e28990ad7f00399"}, - {file = "nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f083fc24912aa410be21fa16d157fed2055dab1cc4b6934a0e03cba69eb242b9"}, - {file = "nvidia_cufft_cu12-11.2.1.3-py3-none-win_amd64.whl", hash = "sha256:d802f4954291101186078ccbe22fc285a902136f974d369540fd4a5333d1440b"}, + {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56"}, + {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-win_amd64.whl", hash = "sha256:d9ac353f78ff89951da4af698f80870b1534ed69993f10a4cf1d96f21357e253"}, ] -[package.dependencies] -nvidia-nvjitlink-cu12 = "*" - [[package]] name = "nvidia-curand-cu12" -version = "10.3.5.147" +version = "10.3.2.106" description = "CURAND native runtime libraries" optional = false python-versions = ">=3" files = [ - {file = "nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1f173f09e3e3c76ab084aba0de819c49e56614feae5c12f69883f4ae9bb5fad9"}, - {file = "nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a88f583d4e0bb643c49743469964103aa59f7f708d862c3ddb0fc07f851e3b8b"}, - {file = "nvidia_curand_cu12-10.3.5.147-py3-none-win_amd64.whl", hash = "sha256:f307cc191f96efe9e8f05a87096abc20d08845a841889ef78cb06924437f6771"}, + {file = "nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0"}, + {file = "nvidia_curand_cu12-10.3.2.106-py3-none-win_amd64.whl", hash = "sha256:75b6b0c574c0037839121317e17fd01f8a69fd2ef8e25853d826fec30bdba74a"}, ] [[package]] name = "nvidia-cusolver-cu12" -version = "11.6.1.9" +version = "11.4.5.107" description = "CUDA solver native runtime libraries" optional = false python-versions = ">=3" files = [ - {file = "nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_aarch64.whl", hash = "sha256:d338f155f174f90724bbde3758b7ac375a70ce8e706d70b018dd3375545fc84e"}, - {file = "nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:19e33fa442bcfd085b3086c4ebf7e8debc07cfe01e11513cc6d332fd918ac260"}, - {file = "nvidia_cusolver_cu12-11.6.1.9-py3-none-win_amd64.whl", hash = "sha256:e77314c9d7b694fcebc84f58989f3aa4fb4cb442f12ca1a9bde50f5e8f6d1b9c"}, + {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd"}, + {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-win_amd64.whl", hash = "sha256:74e0c3a24c78612192a74fcd90dd117f1cf21dea4822e66d89e8ea80e3cd2da5"}, ] [package.dependencies] @@ -1665,14 +1655,13 @@ nvidia-nvjitlink-cu12 = "*" [[package]] name = "nvidia-cusparse-cu12" -version = "12.3.1.170" +version = "12.1.0.106" description = "CUSPARSE native runtime libraries" optional = false python-versions = ">=3" files = [ - {file = "nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_aarch64.whl", hash = "sha256:9d32f62896231ebe0480efd8a7f702e143c98cfaa0e8a76df3386c1ba2b54df3"}, - {file = "nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1"}, - {file = "nvidia_cusparse_cu12-12.3.1.170-py3-none-win_amd64.whl", hash = "sha256:9bc90fb087bc7b4c15641521f31c0371e9a612fc2ba12c338d3ae032e6b6797f"}, + {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c"}, + {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-win_amd64.whl", hash = "sha256:b798237e81b9719373e8fae8d4f091b70a0cf09d9d85c95a557e11df2d8e9a5a"}, ] [package.dependencies] @@ -1680,12 +1669,13 @@ nvidia-nvjitlink-cu12 = "*" [[package]] name = "nvidia-nccl-cu12" -version = "2.21.5" +version = "2.20.5" description = "NVIDIA Collective Communication Library (NCCL) Runtime" optional = false python-versions = ">=3" files = [ - {file = "nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:8579076d30a8c24988834445f8d633c697d42397e92ffc3f63fa26766d25e0a0"}, + {file = "nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1fc150d5c3250b170b29410ba682384b14581db722b2531b0d8d33c595f33d01"}, + {file = "nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:057f6bf9685f75215d0c53bf3ac4a10b3e6578351de307abad9e18a99182af56"}, ] [[package]] @@ -1702,16 +1692,41 @@ files = [ [[package]] name = "nvidia-nvtx-cu12" -version = "12.4.127" +version = "12.1.105" description = "NVIDIA Tools Extension" optional = false python-versions = ">=3" files = [ - {file = "nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7959ad635db13edf4fc65c06a6e9f9e55fc2f92596db928d169c0bb031e88ef3"}, - {file = "nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:781e950d9b9f60d8241ccea575b32f5105a5baf4c2351cab5256a24869f12a1a"}, - {file = "nvidia_nvtx_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:641dccaaa1139f3ffb0d3164b4b84f9d253397e38246a4f2f36728b48566d485"}, + {file = "nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5"}, + {file = "nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82"}, ] +[[package]] +name = "OmniGen" +version = "1.0.3" +description = "OmniGen" +optional = false +python-versions = "*" +files = [] +develop = false + +[package.dependencies] +accelerate = ">=0.26.1" +datasets = "*" +diffusers = ">=0.30.3" +peft = ">=0.9.0" +safetensors = "*" +setuptools = "*" +timm = "*" +torch = "<2.5" +transformers = ">=4.45.2" + +[package.source] +type = "git" +url = "https://github.com/bghira/omnigen" +reference = "dependency-update/peft" +resolved_reference = "93c149a1a5f6526a98bca5e6cff764e0a4790782" + [[package]] name = "open-clip-torch" version = "2.26.1" @@ -3105,111 +3120,111 @@ torchvision = "*" [[package]] name = "tokenizers" -version = "0.19.1" +version = "0.20.1" description = "" optional = false python-versions = ">=3.7" files = [ - {file = "tokenizers-0.19.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:952078130b3d101e05ecfc7fc3640282d74ed26bcf691400f872563fca15ac97"}, - {file = "tokenizers-0.19.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:82c8b8063de6c0468f08e82c4e198763e7b97aabfe573fd4cf7b33930ca4df77"}, - {file = "tokenizers-0.19.1-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:f03727225feaf340ceeb7e00604825addef622d551cbd46b7b775ac834c1e1c4"}, - {file = "tokenizers-0.19.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:453e4422efdfc9c6b6bf2eae00d5e323f263fff62b29a8c9cd526c5003f3f642"}, - {file = "tokenizers-0.19.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:02e81bf089ebf0e7f4df34fa0207519f07e66d8491d963618252f2e0729e0b46"}, - {file = "tokenizers-0.19.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b07c538ba956843833fee1190cf769c60dc62e1cf934ed50d77d5502194d63b1"}, - {file = "tokenizers-0.19.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e28cab1582e0eec38b1f38c1c1fb2e56bce5dc180acb1724574fc5f47da2a4fe"}, - {file = "tokenizers-0.19.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b01afb7193d47439f091cd8f070a1ced347ad0f9144952a30a41836902fe09e"}, - {file = "tokenizers-0.19.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:7fb297edec6c6841ab2e4e8f357209519188e4a59b557ea4fafcf4691d1b4c98"}, - {file = "tokenizers-0.19.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:2e8a3dd055e515df7054378dc9d6fa8c8c34e1f32777fb9a01fea81496b3f9d3"}, - {file = "tokenizers-0.19.1-cp310-none-win32.whl", hash = "sha256:7ff898780a155ea053f5d934925f3902be2ed1f4d916461e1a93019cc7250837"}, - {file = "tokenizers-0.19.1-cp310-none-win_amd64.whl", hash = "sha256:bea6f9947e9419c2fda21ae6c32871e3d398cba549b93f4a65a2d369662d9403"}, - {file = "tokenizers-0.19.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:5c88d1481f1882c2e53e6bb06491e474e420d9ac7bdff172610c4f9ad3898059"}, - {file = "tokenizers-0.19.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ddf672ed719b4ed82b51499100f5417d7d9f6fb05a65e232249268f35de5ed14"}, - {file = "tokenizers-0.19.1-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:dadc509cc8a9fe460bd274c0e16ac4184d0958117cf026e0ea8b32b438171594"}, - {file = "tokenizers-0.19.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dfedf31824ca4915b511b03441784ff640378191918264268e6923da48104acc"}, - {file = "tokenizers-0.19.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ac11016d0a04aa6487b1513a3a36e7bee7eec0e5d30057c9c0408067345c48d2"}, - {file = "tokenizers-0.19.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:76951121890fea8330d3a0df9a954b3f2a37e3ec20e5b0530e9a0044ca2e11fe"}, - {file = "tokenizers-0.19.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b342d2ce8fc8d00f376af068e3274e2e8649562e3bc6ae4a67784ded6b99428d"}, - {file = "tokenizers-0.19.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d16ff18907f4909dca9b076b9c2d899114dd6abceeb074eca0c93e2353f943aa"}, - {file = "tokenizers-0.19.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:706a37cc5332f85f26efbe2bdc9ef8a9b372b77e4645331a405073e4b3a8c1c6"}, - {file = "tokenizers-0.19.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:16baac68651701364b0289979ecec728546133e8e8fe38f66fe48ad07996b88b"}, - {file = "tokenizers-0.19.1-cp311-none-win32.whl", hash = "sha256:9ed240c56b4403e22b9584ee37d87b8bfa14865134e3e1c3fb4b2c42fafd3256"}, - {file = "tokenizers-0.19.1-cp311-none-win_amd64.whl", hash = "sha256:ad57d59341710b94a7d9dbea13f5c1e7d76fd8d9bcd944a7a6ab0b0da6e0cc66"}, - {file = "tokenizers-0.19.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:621d670e1b1c281a1c9698ed89451395d318802ff88d1fc1accff0867a06f153"}, - {file = "tokenizers-0.19.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d924204a3dbe50b75630bd16f821ebda6a5f729928df30f582fb5aade90c818a"}, - {file = "tokenizers-0.19.1-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:4f3fefdc0446b1a1e6d81cd4c07088ac015665d2e812f6dbba4a06267d1a2c95"}, - {file = "tokenizers-0.19.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9620b78e0b2d52ef07b0d428323fb34e8ea1219c5eac98c2596311f20f1f9266"}, - {file = "tokenizers-0.19.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:04ce49e82d100594715ac1b2ce87d1a36e61891a91de774755f743babcd0dd52"}, - {file = "tokenizers-0.19.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c5c2ff13d157afe413bf7e25789879dd463e5a4abfb529a2d8f8473d8042e28f"}, - {file = "tokenizers-0.19.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3174c76efd9d08f836bfccaca7cfec3f4d1c0a4cf3acbc7236ad577cc423c840"}, - {file = "tokenizers-0.19.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7c9d5b6c0e7a1e979bec10ff960fae925e947aab95619a6fdb4c1d8ff3708ce3"}, - {file = "tokenizers-0.19.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a179856d1caee06577220ebcfa332af046d576fb73454b8f4d4b0ba8324423ea"}, - {file = "tokenizers-0.19.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:952b80dac1a6492170f8c2429bd11fcaa14377e097d12a1dbe0ef2fb2241e16c"}, - {file = "tokenizers-0.19.1-cp312-none-win32.whl", hash = "sha256:01d62812454c188306755c94755465505836fd616f75067abcae529c35edeb57"}, - {file = "tokenizers-0.19.1-cp312-none-win_amd64.whl", hash = "sha256:b70bfbe3a82d3e3fb2a5e9b22a39f8d1740c96c68b6ace0086b39074f08ab89a"}, - {file = "tokenizers-0.19.1-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:bb9dfe7dae85bc6119d705a76dc068c062b8b575abe3595e3c6276480e67e3f1"}, - {file = "tokenizers-0.19.1-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:1f0360cbea28ea99944ac089c00de7b2e3e1c58f479fb8613b6d8d511ce98267"}, - {file = "tokenizers-0.19.1-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:71e3ec71f0e78780851fef28c2a9babe20270404c921b756d7c532d280349214"}, - {file = "tokenizers-0.19.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b82931fa619dbad979c0ee8e54dd5278acc418209cc897e42fac041f5366d626"}, - {file = "tokenizers-0.19.1-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e8ff5b90eabdcdaa19af697885f70fe0b714ce16709cf43d4952f1f85299e73a"}, - {file = "tokenizers-0.19.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e742d76ad84acbdb1a8e4694f915fe59ff6edc381c97d6dfdd054954e3478ad4"}, - {file = "tokenizers-0.19.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d8c5d59d7b59885eab559d5bc082b2985555a54cda04dda4c65528d90ad252ad"}, - {file = "tokenizers-0.19.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b2da5c32ed869bebd990c9420df49813709e953674c0722ff471a116d97b22d"}, - {file = "tokenizers-0.19.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:638e43936cc8b2cbb9f9d8dde0fe5e7e30766a3318d2342999ae27f68fdc9bd6"}, - {file = "tokenizers-0.19.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:78e769eb3b2c79687d9cb0f89ef77223e8e279b75c0a968e637ca7043a84463f"}, - {file = "tokenizers-0.19.1-cp37-none-win32.whl", hash = "sha256:72791f9bb1ca78e3ae525d4782e85272c63faaef9940d92142aa3eb79f3407a3"}, - {file = "tokenizers-0.19.1-cp37-none-win_amd64.whl", hash = "sha256:f3bbb7a0c5fcb692950b041ae11067ac54826204318922da754f908d95619fbc"}, - {file = "tokenizers-0.19.1-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:07f9295349bbbcedae8cefdbcfa7f686aa420be8aca5d4f7d1ae6016c128c0c5"}, - {file = "tokenizers-0.19.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:10a707cc6c4b6b183ec5dbfc5c34f3064e18cf62b4a938cb41699e33a99e03c1"}, - {file = "tokenizers-0.19.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:6309271f57b397aa0aff0cbbe632ca9d70430839ca3178bf0f06f825924eca22"}, - {file = "tokenizers-0.19.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ad23d37d68cf00d54af184586d79b84075ada495e7c5c0f601f051b162112dc"}, - {file = "tokenizers-0.19.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:427c4f0f3df9109314d4f75b8d1f65d9477033e67ffaec4bca53293d3aca286d"}, - {file = "tokenizers-0.19.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e83a31c9cf181a0a3ef0abad2b5f6b43399faf5da7e696196ddd110d332519ee"}, - {file = "tokenizers-0.19.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c27b99889bd58b7e301468c0838c5ed75e60c66df0d4db80c08f43462f82e0d3"}, - {file = "tokenizers-0.19.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bac0b0eb952412b0b196ca7a40e7dce4ed6f6926489313414010f2e6b9ec2adf"}, - {file = "tokenizers-0.19.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:8a6298bde623725ca31c9035a04bf2ef63208d266acd2bed8c2cb7d2b7d53ce6"}, - {file = "tokenizers-0.19.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:08a44864e42fa6d7d76d7be4bec62c9982f6f6248b4aa42f7302aa01e0abfd26"}, - {file = "tokenizers-0.19.1-cp38-none-win32.whl", hash = "sha256:1de5bc8652252d9357a666e609cb1453d4f8e160eb1fb2830ee369dd658e8975"}, - {file = "tokenizers-0.19.1-cp38-none-win_amd64.whl", hash = "sha256:0bcce02bf1ad9882345b34d5bd25ed4949a480cf0e656bbd468f4d8986f7a3f1"}, - {file = "tokenizers-0.19.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:0b9394bd204842a2a1fd37fe29935353742be4a3460b6ccbaefa93f58a8df43d"}, - {file = "tokenizers-0.19.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4692ab92f91b87769d950ca14dbb61f8a9ef36a62f94bad6c82cc84a51f76f6a"}, - {file = "tokenizers-0.19.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:6258c2ef6f06259f70a682491c78561d492e885adeaf9f64f5389f78aa49a051"}, - {file = "tokenizers-0.19.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c85cf76561fbd01e0d9ea2d1cbe711a65400092bc52b5242b16cfd22e51f0c58"}, - {file = "tokenizers-0.19.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:670b802d4d82bbbb832ddb0d41df7015b3e549714c0e77f9bed3e74d42400fbe"}, - {file = "tokenizers-0.19.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:85aa3ab4b03d5e99fdd31660872249df5e855334b6c333e0bc13032ff4469c4a"}, - {file = "tokenizers-0.19.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cbf001afbbed111a79ca47d75941e9e5361297a87d186cbfc11ed45e30b5daba"}, - {file = "tokenizers-0.19.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b4c89aa46c269e4e70c4d4f9d6bc644fcc39bb409cb2a81227923404dd6f5227"}, - {file = "tokenizers-0.19.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:39c1ec76ea1027438fafe16ecb0fb84795e62e9d643444c1090179e63808c69d"}, - {file = "tokenizers-0.19.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c2a0d47a89b48d7daa241e004e71fb5a50533718897a4cd6235cb846d511a478"}, - {file = "tokenizers-0.19.1-cp39-none-win32.whl", hash = "sha256:61b7fe8886f2e104d4caf9218b157b106207e0f2a4905c9c7ac98890688aabeb"}, - {file = "tokenizers-0.19.1-cp39-none-win_amd64.whl", hash = "sha256:f97660f6c43efd3e0bfd3f2e3e5615bf215680bad6ee3d469df6454b8c6e8256"}, - {file = "tokenizers-0.19.1-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:3b11853f17b54c2fe47742c56d8a33bf49ce31caf531e87ac0d7d13d327c9334"}, - {file = "tokenizers-0.19.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d26194ef6c13302f446d39972aaa36a1dda6450bc8949f5eb4c27f51191375bd"}, - {file = "tokenizers-0.19.1-pp310-pypy310_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:e8d1ed93beda54bbd6131a2cb363a576eac746d5c26ba5b7556bc6f964425594"}, - {file = "tokenizers-0.19.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca407133536f19bdec44b3da117ef0d12e43f6d4b56ac4c765f37eca501c7bda"}, - {file = "tokenizers-0.19.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce05fde79d2bc2e46ac08aacbc142bead21614d937aac950be88dc79f9db9022"}, - {file = "tokenizers-0.19.1-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:35583cd46d16f07c054efd18b5d46af4a2f070a2dd0a47914e66f3ff5efb2b1e"}, - {file = "tokenizers-0.19.1-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:43350270bfc16b06ad3f6f07eab21f089adb835544417afda0f83256a8bf8b75"}, - {file = "tokenizers-0.19.1-pp37-pypy37_pp73-macosx_10_12_x86_64.whl", hash = "sha256:b4399b59d1af5645bcee2072a463318114c39b8547437a7c2d6a186a1b5a0e2d"}, - {file = "tokenizers-0.19.1-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:6852c5b2a853b8b0ddc5993cd4f33bfffdca4fcc5d52f89dd4b8eada99379285"}, - {file = "tokenizers-0.19.1-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bcd266ae85c3d39df2f7e7d0e07f6c41a55e9a3123bb11f854412952deacd828"}, - {file = "tokenizers-0.19.1-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ecb2651956eea2aa0a2d099434134b1b68f1c31f9a5084d6d53f08ed43d45ff2"}, - {file = "tokenizers-0.19.1-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:b279ab506ec4445166ac476fb4d3cc383accde1ea152998509a94d82547c8e2a"}, - {file = "tokenizers-0.19.1-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:89183e55fb86e61d848ff83753f64cded119f5d6e1f553d14ffee3700d0a4a49"}, - {file = "tokenizers-0.19.1-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:b2edbc75744235eea94d595a8b70fe279dd42f3296f76d5a86dde1d46e35f574"}, - {file = "tokenizers-0.19.1-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:0e64bfde9a723274e9a71630c3e9494ed7b4c0f76a1faacf7fe294cd26f7ae7c"}, - {file = "tokenizers-0.19.1-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:0b5ca92bfa717759c052e345770792d02d1f43b06f9e790ca0a1db62838816f3"}, - {file = "tokenizers-0.19.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6f8a20266e695ec9d7a946a019c1d5ca4eddb6613d4f466888eee04f16eedb85"}, - {file = "tokenizers-0.19.1-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63c38f45d8f2a2ec0f3a20073cccb335b9f99f73b3c69483cd52ebc75369d8a1"}, - {file = "tokenizers-0.19.1-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:dd26e3afe8a7b61422df3176e06664503d3f5973b94f45d5c45987e1cb711876"}, - {file = "tokenizers-0.19.1-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:eddd5783a4a6309ce23432353cdb36220e25cbb779bfa9122320666508b44b88"}, - {file = "tokenizers-0.19.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:56ae39d4036b753994476a1b935584071093b55c7a72e3b8288e68c313ca26e7"}, - {file = "tokenizers-0.19.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:f9939ca7e58c2758c01b40324a59c034ce0cebad18e0d4563a9b1beab3018243"}, - {file = "tokenizers-0.19.1-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:6c330c0eb815d212893c67a032e9dc1b38a803eccb32f3e8172c19cc69fbb439"}, - {file = "tokenizers-0.19.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ec11802450a2487cdf0e634b750a04cbdc1c4d066b97d94ce7dd2cb51ebb325b"}, - {file = "tokenizers-0.19.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2b718f316b596f36e1dae097a7d5b91fc5b85e90bf08b01ff139bd8953b25af"}, - {file = "tokenizers-0.19.1-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:ed69af290c2b65169f0ba9034d1dc39a5db9459b32f1dd8b5f3f32a3fcf06eab"}, - {file = "tokenizers-0.19.1-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f8a9c828277133af13f3859d1b6bf1c3cb6e9e1637df0e45312e6b7c2e622b1f"}, - {file = "tokenizers-0.19.1.tar.gz", hash = "sha256:ee59e6680ed0fdbe6b724cf38bd70400a0c1dd623b07ac729087270caeac88e3"}, + {file = "tokenizers-0.20.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:439261da7c0a5c88bda97acb284d49fbdaf67e9d3b623c0bfd107512d22787a9"}, + {file = "tokenizers-0.20.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:03dae629d99068b1ea5416d50de0fea13008f04129cc79af77a2a6392792d93c"}, + {file = "tokenizers-0.20.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b61f561f329ffe4b28367798b89d60c4abf3f815d37413b6352bc6412a359867"}, + {file = "tokenizers-0.20.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ec870fce1ee5248a10be69f7a8408a234d6f2109f8ea827b4f7ecdbf08c9fd15"}, + {file = "tokenizers-0.20.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d388d1ea8b7447da784e32e3b86a75cce55887e3b22b31c19d0b186b1c677800"}, + {file = "tokenizers-0.20.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:299c85c1d21135bc01542237979bf25c32efa0d66595dd0069ae259b97fb2dbe"}, + {file = "tokenizers-0.20.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e96f6c14c9752bb82145636b614d5a78e9cde95edfbe0a85dad0dd5ddd6ec95c"}, + {file = "tokenizers-0.20.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc9e95ad49c932b80abfbfeaf63b155761e695ad9f8a58c52a47d962d76e310f"}, + {file = "tokenizers-0.20.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:f22dee205329a636148c325921c73cf3e412e87d31f4d9c3153b302a0200057b"}, + {file = "tokenizers-0.20.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a2ffd9a8895575ac636d44500c66dffaef133823b6b25067604fa73bbc5ec09d"}, + {file = "tokenizers-0.20.1-cp310-none-win32.whl", hash = "sha256:2847843c53f445e0f19ea842a4e48b89dd0db4e62ba6e1e47a2749d6ec11f50d"}, + {file = "tokenizers-0.20.1-cp310-none-win_amd64.whl", hash = "sha256:f9aa93eacd865f2798b9e62f7ce4533cfff4f5fbd50c02926a78e81c74e432cd"}, + {file = "tokenizers-0.20.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:4a717dcb08f2dabbf27ae4b6b20cbbb2ad7ed78ce05a829fae100ff4b3c7ff15"}, + {file = "tokenizers-0.20.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3f84dad1ff1863c648d80628b1b55353d16303431283e4efbb6ab1af56a75832"}, + {file = "tokenizers-0.20.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:929c8f3afa16a5130a81ab5079c589226273ec618949cce79b46d96e59a84f61"}, + {file = "tokenizers-0.20.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d10766473954397e2d370f215ebed1cc46dcf6fd3906a2a116aa1d6219bfedc3"}, + {file = "tokenizers-0.20.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9300fac73ddc7e4b0330acbdda4efaabf74929a4a61e119a32a181f534a11b47"}, + {file = "tokenizers-0.20.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0ecaf7b0e39caeb1aa6dd6e0975c405716c82c1312b55ac4f716ef563a906969"}, + {file = "tokenizers-0.20.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5170be9ec942f3d1d317817ced8d749b3e1202670865e4fd465e35d8c259de83"}, + {file = "tokenizers-0.20.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ef3f1ae08fa9aea5891cbd69df29913e11d3841798e0bfb1ff78b78e4e7ea0a4"}, + {file = "tokenizers-0.20.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ee86d4095d3542d73579e953c2e5e07d9321af2ffea6ecc097d16d538a2dea16"}, + {file = "tokenizers-0.20.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:86dcd08da163912e17b27bbaba5efdc71b4fbffb841530fdb74c5707f3c49216"}, + {file = "tokenizers-0.20.1-cp311-none-win32.whl", hash = "sha256:9af2dc4ee97d037bc6b05fa4429ddc87532c706316c5e11ce2f0596dfcfa77af"}, + {file = "tokenizers-0.20.1-cp311-none-win_amd64.whl", hash = "sha256:899152a78b095559c287b4c6d0099469573bb2055347bb8154db106651296f39"}, + {file = "tokenizers-0.20.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:407ab666b38e02228fa785e81f7cf79ef929f104bcccf68a64525a54a93ceac9"}, + {file = "tokenizers-0.20.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2f13a2d16032ebc8bd812eb8099b035ac65887d8f0c207261472803b9633cf3e"}, + {file = "tokenizers-0.20.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e98eee4dca22849fbb56a80acaa899eec5b72055d79637dd6aa15d5e4b8628c9"}, + {file = "tokenizers-0.20.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:47c1bcdd61e61136087459cb9e0b069ff23b5568b008265e5cbc927eae3387ce"}, + {file = "tokenizers-0.20.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:128c1110e950534426e2274837fc06b118ab5f2fa61c3436e60e0aada0ccfd67"}, + {file = "tokenizers-0.20.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e2e2d47a819d2954f2c1cd0ad51bb58ffac6f53a872d5d82d65d79bf76b9896d"}, + {file = "tokenizers-0.20.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bdd67a0e3503a9a7cf8bc5a4a49cdde5fa5bada09a51e4c7e1c73900297539bd"}, + {file = "tokenizers-0.20.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:689b93d2e26d04da337ac407acec8b5d081d8d135e3e5066a88edd5bdb5aff89"}, + {file = "tokenizers-0.20.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:0c6a796ddcd9a19ad13cf146997cd5895a421fe6aec8fd970d69f9117bddb45c"}, + {file = "tokenizers-0.20.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3ea919687aa7001a8ff1ba36ac64f165c4e89035f57998fa6cedcfd877be619d"}, + {file = "tokenizers-0.20.1-cp312-none-win32.whl", hash = "sha256:6d3ac5c1f48358ffe20086bf065e843c0d0a9fce0d7f0f45d5f2f9fba3609ca5"}, + {file = "tokenizers-0.20.1-cp312-none-win_amd64.whl", hash = "sha256:b0874481aea54a178f2bccc45aa2d0c99cd3f79143a0948af6a9a21dcc49173b"}, + {file = "tokenizers-0.20.1-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:96af92e833bd44760fb17f23f402e07a66339c1dcbe17d79a9b55bb0cc4f038e"}, + {file = "tokenizers-0.20.1-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:65f34e5b731a262dfa562820818533c38ce32a45864437f3d9c82f26c139ca7f"}, + {file = "tokenizers-0.20.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:17f98fccb5c12ab1ce1f471731a9cd86df5d4bd2cf2880c5a66b229802d96145"}, + {file = "tokenizers-0.20.1-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b8c0fc3542cf9370bf92c932eb71bdeb33d2d4aeeb4126d9fd567b60bd04cb30"}, + {file = "tokenizers-0.20.1-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4b39356df4575d37f9b187bb623aab5abb7b62c8cb702867a1768002f814800c"}, + {file = "tokenizers-0.20.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bfdad27b0e50544f6b838895a373db6114b85112ba5c0cefadffa78d6daae563"}, + {file = "tokenizers-0.20.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:094663dd0e85ee2e573126918747bdb40044a848fde388efb5b09d57bc74c680"}, + {file = "tokenizers-0.20.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:14e4cf033a2aa207d7ac790e91adca598b679999710a632c4a494aab0fc3a1b2"}, + {file = "tokenizers-0.20.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:9310951c92c9fb91660de0c19a923c432f110dbfad1a2d429fbc44fa956bf64f"}, + {file = "tokenizers-0.20.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:05e41e302c315bd2ed86c02e917bf03a6cf7d2f652c9cee1a0eb0d0f1ca0d32c"}, + {file = "tokenizers-0.20.1-cp37-none-win32.whl", hash = "sha256:212231ab7dfcdc879baf4892ca87c726259fa7c887e1688e3f3cead384d8c305"}, + {file = "tokenizers-0.20.1-cp37-none-win_amd64.whl", hash = "sha256:896195eb9dfdc85c8c052e29947169c1fcbe75a254c4b5792cdbd451587bce85"}, + {file = "tokenizers-0.20.1-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:741fb22788482d09d68e73ece1495cfc6d9b29a06c37b3df90564a9cfa688e6d"}, + {file = "tokenizers-0.20.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:10be14ebd8082086a342d969e17fc2d6edc856c59dbdbddd25f158fa40eaf043"}, + {file = "tokenizers-0.20.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:514cf279b22fa1ae0bc08e143458c74ad3b56cd078b319464959685a35c53d5e"}, + {file = "tokenizers-0.20.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a647c5b7cb896d6430cf3e01b4e9a2d77f719c84cefcef825d404830c2071da2"}, + {file = "tokenizers-0.20.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7cdf379219e1e1dd432091058dab325a2e6235ebb23e0aec8d0508567c90cd01"}, + {file = "tokenizers-0.20.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ba72260449e16c4c2f6f3252823b059fbf2d31b32617e582003f2b18b415c39"}, + {file = "tokenizers-0.20.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:910b96ed87316e4277b23c7bcaf667ce849c7cc379a453fa179e7e09290eeb25"}, + {file = "tokenizers-0.20.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e53975a6694428a0586534cc1354b2408d4e010a3103117f617cbb550299797c"}, + {file = "tokenizers-0.20.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:07c4b7be58da142b0730cc4e5fd66bb7bf6f57f4986ddda73833cd39efef8a01"}, + {file = "tokenizers-0.20.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b605c540753e62199bf15cf69c333e934077ef2350262af2ccada46026f83d1c"}, + {file = "tokenizers-0.20.1-cp38-none-win32.whl", hash = "sha256:88b3bc76ab4db1ab95ead623d49c95205411e26302cf9f74203e762ac7e85685"}, + {file = "tokenizers-0.20.1-cp38-none-win_amd64.whl", hash = "sha256:d412a74cf5b3f68a90c615611a5aa4478bb303d1c65961d22db45001df68afcb"}, + {file = "tokenizers-0.20.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:a25dcb2f41a0a6aac31999e6c96a75e9152fa0127af8ece46c2f784f23b8197a"}, + {file = "tokenizers-0.20.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a12c3cebb8c92e9c35a23ab10d3852aee522f385c28d0b4fe48c0b7527d59762"}, + {file = "tokenizers-0.20.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:02e18da58cf115b7c40de973609c35bde95856012ba42a41ee919c77935af251"}, + {file = "tokenizers-0.20.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f326a1ac51ae909b9760e34671c26cd0dfe15662f447302a9d5bb2d872bab8ab"}, + {file = "tokenizers-0.20.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0b4872647ea6f25224e2833b044b0b19084e39400e8ead3cfe751238b0802140"}, + {file = "tokenizers-0.20.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ce6238a3311bb8e4c15b12600927d35c267b92a52c881ef5717a900ca14793f7"}, + {file = "tokenizers-0.20.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:57b7a8880b208866508b06ce365dc631e7a2472a3faa24daa430d046fb56c885"}, + {file = "tokenizers-0.20.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a908c69c2897a68f412aa05ba38bfa87a02980df70f5a72fa8490479308b1f2d"}, + {file = "tokenizers-0.20.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:da1001aa46f4490099c82e2facc4fbc06a6a32bf7de3918ba798010954b775e0"}, + {file = "tokenizers-0.20.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:42c097390e2f0ed0a5c5d569e6669dd4e9fff7b31c6a5ce6e9c66a61687197de"}, + {file = "tokenizers-0.20.1-cp39-none-win32.whl", hash = "sha256:3d4d218573a3d8b121a1f8c801029d70444ffb6d8f129d4cca1c7b672ee4a24c"}, + {file = "tokenizers-0.20.1-cp39-none-win_amd64.whl", hash = "sha256:37d1e6f616c84fceefa7c6484a01df05caf1e207669121c66213cb5b2911d653"}, + {file = "tokenizers-0.20.1-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:48689da7a395df41114f516208d6550e3e905e1239cc5ad386686d9358e9cef0"}, + {file = "tokenizers-0.20.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:712f90ea33f9bd2586b4a90d697c26d56d0a22fd3c91104c5858c4b5b6489a79"}, + {file = "tokenizers-0.20.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:359eceb6a620c965988fc559cebc0a98db26713758ec4df43fb76d41486a8ed5"}, + {file = "tokenizers-0.20.1-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0d3caf244ce89d24c87545aafc3448be15870096e796c703a0d68547187192e1"}, + {file = "tokenizers-0.20.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:03b03cf8b9a32254b1bf8a305fb95c6daf1baae0c1f93b27f2b08c9759f41dee"}, + {file = "tokenizers-0.20.1-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:218e5a3561561ea0f0ef1559c6d95b825308dbec23fb55b70b92589e7ff2e1e8"}, + {file = "tokenizers-0.20.1-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f40df5e0294a95131cc5f0e0eb91fe86d88837abfbee46b9b3610b09860195a7"}, + {file = "tokenizers-0.20.1-pp37-pypy37_pp73-macosx_10_12_x86_64.whl", hash = "sha256:08aaa0d72bb65058e8c4b0455f61b840b156c557e2aca57627056624c3a93976"}, + {file = "tokenizers-0.20.1-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:998700177b45f70afeb206ad22c08d9e5f3a80639dae1032bf41e8cbc4dada4b"}, + {file = "tokenizers-0.20.1-pp37-pypy37_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:62f7fbd3c2c38b179556d879edae442b45f68312019c3a6013e56c3947a4e648"}, + {file = "tokenizers-0.20.1-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:31e87fca4f6bbf5cc67481b562147fe932f73d5602734de7dd18a8f2eee9c6dd"}, + {file = "tokenizers-0.20.1-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:956f21d359ae29dd51ca5726d2c9a44ffafa041c623f5aa33749da87cfa809b9"}, + {file = "tokenizers-0.20.1-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:1fbbaf17a393c78d8aedb6a334097c91cb4119a9ced4764ab8cfdc8d254dc9f9"}, + {file = "tokenizers-0.20.1-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:ebe63e31f9c1a970c53866d814e35ec2ec26fda03097c486f82f3891cee60830"}, + {file = "tokenizers-0.20.1-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:81970b80b8ac126910295f8aab2d7ef962009ea39e0d86d304769493f69aaa1e"}, + {file = "tokenizers-0.20.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:130e35e76f9337ed6c31be386e75d4925ea807055acf18ca1a9b0eec03d8fe23"}, + {file = "tokenizers-0.20.1-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cd28a8614f5c82a54ab2463554e84ad79526c5184cf4573bbac2efbbbcead457"}, + {file = "tokenizers-0.20.1-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9041ee665d0fa7f5c4ccf0f81f5e6b7087f797f85b143c094126fc2611fec9d0"}, + {file = "tokenizers-0.20.1-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:62eb9daea2a2c06bcd8113a5824af8ef8ee7405d3a71123ba4d52c79bb3d9f1a"}, + {file = "tokenizers-0.20.1-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f861889707b54a9ab1204030b65fd6c22bdd4a95205deec7994dc22a8baa2ea4"}, + {file = "tokenizers-0.20.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:89d5c337d74ea6e5e7dc8af124cf177be843bbb9ca6e58c01f75ea103c12c8a9"}, + {file = "tokenizers-0.20.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:0b7f515c83397e73292accdbbbedc62264e070bae9682f06061e2ddce67cacaf"}, + {file = "tokenizers-0.20.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e0305fc1ec6b1e5052d30d9c1d5c807081a7bd0cae46a33d03117082e91908c"}, + {file = "tokenizers-0.20.1-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5dc611e6ac0fa00a41de19c3bf6391a05ea201d2d22b757d63f5491ec0e67faa"}, + {file = "tokenizers-0.20.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c5ffe0d7f7bfcfa3b2585776ecf11da2e01c317027c8573c78ebcb8985279e23"}, + {file = "tokenizers-0.20.1-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:e7edb8ec12c100d5458d15b1e47c0eb30ad606a05641f19af7563bc3d1608c14"}, + {file = "tokenizers-0.20.1-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:de291633fb9303555793cc544d4a86e858da529b7d0b752bcaf721ae1d74b2c9"}, + {file = "tokenizers-0.20.1.tar.gz", hash = "sha256:84edcc7cdeeee45ceedb65d518fffb77aec69311c9c8e30f77ad84da3025f002"}, ] [package.dependencies] @@ -3233,28 +3248,31 @@ files = [ [[package]] name = "torch" -version = "2.5.0" +version = "2.4.1" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" optional = false python-versions = ">=3.8.0" files = [ - {file = "torch-2.5.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:7f179373a047b947dec448243f4e6598a1c960fa3bb978a9a7eecd529fbc363f"}, - {file = "torch-2.5.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:15fbc95e38d330e5b0ef1593b7bc0a19f30e5bdad76895a5cffa1a6a044235e9"}, - {file = "torch-2.5.0-cp310-cp310-win_amd64.whl", hash = "sha256:f499212f1cffea5d587e5f06144630ed9aa9c399bba12ec8905798d833bd1404"}, - {file = "torch-2.5.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:c54db1fade17287aabbeed685d8e8ab3a56fea9dd8d46e71ced2da367f09a49f"}, - {file = "torch-2.5.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:499a68a756d3b30d10f7e0f6214dc3767b130b797265db3b1c02e9094e2a07be"}, - {file = "torch-2.5.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:9f3df8138a1126a851440b7d5a4869bfb7c9cc43563d64fd9d96d0465b581024"}, - {file = "torch-2.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:b81da3bdb58c9de29d0e1361e52f12fcf10a89673f17a11a5c6c7da1cb1a8376"}, - {file = "torch-2.5.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:ba135923295d564355326dc409b6b7f5bd6edc80f764cdaef1fb0a1b23ff2f9c"}, - {file = "torch-2.5.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:2dd40c885a05ef7fe29356cca81be1435a893096ceb984441d6e2c27aff8c6f4"}, - {file = "torch-2.5.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:bc52d603d87fe1da24439c0d5fdbbb14e0ae4874451d53f0120ffb1f6c192727"}, - {file = "torch-2.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:ea718746469246cc63b3353afd75698a288344adb55e29b7f814a5d3c0a7c78d"}, - {file = "torch-2.5.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:6de1fd253e27e7f01f05cd7c37929ae521ca23ca4620cfc7c485299941679112"}, - {file = "torch-2.5.0-cp313-cp313-manylinux1_x86_64.whl", hash = "sha256:83dcf518685db20912b71fc49cbddcc8849438cdb0e9dcc919b02a849e2cd9e8"}, - {file = "torch-2.5.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:65e0a60894435608334d68c8811e55fd8f73e5bf8ee6f9ccedb0064486a7b418"}, - {file = "torch-2.5.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:38c21ff1bd39f076d72ab06e3c88c2ea6874f2e6f235c9450816b6c8e7627094"}, - {file = "torch-2.5.0-cp39-cp39-win_amd64.whl", hash = "sha256:ce4baeba9804da5a346e210b3b70826f5811330c343e4fe1582200359ee77fe5"}, - {file = "torch-2.5.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:03e53f577a96e4d41aca472da8faa40e55df89d2273664af390ce1f570e885bd"}, + {file = "torch-2.4.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:362f82e23a4cd46341daabb76fba08f04cd646df9bfaf5da50af97cb60ca4971"}, + {file = "torch-2.4.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:e8ac1985c3ff0f60d85b991954cfc2cc25f79c84545aead422763148ed2759e3"}, + {file = "torch-2.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:91e326e2ccfb1496e3bee58f70ef605aeb27bd26be07ba64f37dcaac3d070ada"}, + {file = "torch-2.4.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:d36a8ef100f5bff3e9c3cea934b9e0d7ea277cb8210c7152d34a9a6c5830eadd"}, + {file = "torch-2.4.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:0b5f88afdfa05a335d80351e3cea57d38e578c8689f751d35e0ff36bce872113"}, + {file = "torch-2.4.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:ef503165f2341942bfdf2bd520152f19540d0c0e34961232f134dc59ad435be8"}, + {file = "torch-2.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:092e7c2280c860eff762ac08c4bdcd53d701677851670695e0c22d6d345b269c"}, + {file = "torch-2.4.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:ddddbd8b066e743934a4200b3d54267a46db02106876d21cf31f7da7a96f98ea"}, + {file = "torch-2.4.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:fdc4fe11db3eb93c1115d3e973a27ac7c1a8318af8934ffa36b0370efe28e042"}, + {file = "torch-2.4.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:18835374f599207a9e82c262153c20ddf42ea49bc76b6eadad8e5f49729f6e4d"}, + {file = "torch-2.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:ebea70ff30544fc021d441ce6b219a88b67524f01170b1c538d7d3ebb5e7f56c"}, + {file = "torch-2.4.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:72b484d5b6cec1a735bf3fa5a1c4883d01748698c5e9cfdbeb4ffab7c7987e0d"}, + {file = "torch-2.4.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:c99e1db4bf0c5347107845d715b4aa1097e601bdc36343d758963055e9599d93"}, + {file = "torch-2.4.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:b57f07e92858db78c5b72857b4f0b33a65b00dc5d68e7948a8494b0314efb880"}, + {file = "torch-2.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:f18197f3f7c15cde2115892b64f17c80dbf01ed72b008020e7da339902742cf6"}, + {file = "torch-2.4.1-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:5fc1d4d7ed265ef853579caf272686d1ed87cebdcd04f2a498f800ffc53dab71"}, + {file = "torch-2.4.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:40f6d3fe3bae74efcf08cb7f8295eaddd8a838ce89e9d26929d4edd6d5e4329d"}, + {file = "torch-2.4.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:c9299c16c9743001ecef515536ac45900247f4338ecdf70746f2461f9e4831db"}, + {file = "torch-2.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:6bce130f2cd2d52ba4e2c6ada461808de7e5eccbac692525337cfb4c19421846"}, + {file = "torch-2.4.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:a38de2803ee6050309aac032676536c3d3b6a9804248537e38e098d0e14817ec"}, ] [package.dependencies] @@ -3262,26 +3280,25 @@ filelock = "*" fsspec = "*" jinja2 = "*" networkx = "*" -nvidia-cublas-cu12 = {version = "12.4.5.8", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cuda-cupti-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cuda-nvrtc-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cuda-runtime-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cublas-cu12 = {version = "12.1.3.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-cupti-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-nvrtc-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-runtime-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} nvidia-cudnn-cu12 = {version = "9.1.0.70", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cufft-cu12 = {version = "11.2.1.3", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-curand-cu12 = {version = "10.3.5.147", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cusolver-cu12 = {version = "11.6.1.9", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cusparse-cu12 = {version = "12.3.1.170", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-nccl-cu12 = {version = "2.21.5", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-nvjitlink-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-nvtx-cu12 = {version = "12.4.127", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -setuptools = {version = "*", markers = "python_version >= \"3.12\""} -sympy = {version = "1.13.1", markers = "python_version >= \"3.9\""} -triton = {version = "3.1.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.13\""} +nvidia-cufft-cu12 = {version = "11.0.2.54", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-curand-cu12 = {version = "10.3.2.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cusolver-cu12 = {version = "11.4.5.107", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cusparse-cu12 = {version = "12.1.0.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-nccl-cu12 = {version = "2.20.5", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-nvtx-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +setuptools = "*" +sympy = "*" +triton = {version = "3.0.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.13\""} typing-extensions = ">=4.8.0" [package.extras] opt-einsum = ["opt-einsum (>=3.3)"] -optree = ["optree (>=0.12.0)"] +optree = ["optree (>=0.11.0)"] [[package]] name = "torch-optimi" @@ -3325,35 +3342,35 @@ dev = ["bitsandbytes", "expecttest", "fire", "hypothesis", "matplotlib", "ninja" [[package]] name = "torchaudio" -version = "2.5.0" +version = "2.4.1" description = "An audio package for PyTorch" optional = false python-versions = "*" files = [ - {file = "torchaudio-2.5.0-1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:9dfeedcef3e43010f3ec2d804c8f62fe49ab09ef1c19e6736325939661a293bd"}, - {file = "torchaudio-2.5.0-1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:c35201efd28244152d6edbde92775c10f39f5a5d9346202f07b1554dc78d25a2"}, - {file = "torchaudio-2.5.0-1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:593258f33de1fa16ebe5718c9717daf3695460d48a0188194a5c574a710838cb"}, - {file = "torchaudio-2.5.0-1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:aabf8c4ce919c2e24ace49641ea429360018816371a3d470427fc02ab11156c5"}, - {file = "torchaudio-2.5.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ab69bdfeb4434159e168a4a2c1618d1d65a5a14a91d17d21256ea960f33405fd"}, - {file = "torchaudio-2.5.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:70fe426ecae408e9a7019cfcbcd4e81b6f084920ffffac2520f1d28a23e145fe"}, - {file = "torchaudio-2.5.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:7e2f596b0e8909924cdf46acc579481132f5c0341824957f1cca8385c61db5b5"}, - {file = "torchaudio-2.5.0-cp310-cp310-win_amd64.whl", hash = "sha256:763bf99b2def4681b1e760883849e0e85fa172eac4a12d1870380d5b7d1149c2"}, - {file = "torchaudio-2.5.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:982ec0494a27c7f3e7e68c91cc92e6e1ad6f86fedeeec627096051309632b149"}, - {file = "torchaudio-2.5.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:470dd171a2e44a4c1aa89c5cdd4a0ba9f99650b68228f3a63b20ceaafd553567"}, - {file = "torchaudio-2.5.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:ee844aa12fa25f521f64ec86c835acf925d194ed4fb66a9b442436f80b39e8da"}, - {file = "torchaudio-2.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:168d9d2a8216a5f1888713c13914edf410d2e28d39c6bfd9e1211baf6f2c76d8"}, - {file = "torchaudio-2.5.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c835da771701b06fbe8e19ce643d5e587fd385e5f4d8f140551ce04900b1b96b"}, - {file = "torchaudio-2.5.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:3009ca3e095825911917ab865b2ea48abbf3e66f948c4ffd64916fe6e476bfec"}, - {file = "torchaudio-2.5.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:1971302bfc321be5ea18ba60dbc3ffdc6ae53cb47bb6427db25d5b69e4e932ec"}, - {file = "torchaudio-2.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:e98ad102e0e574a397759078bc9809afc05de6e6f8ac0cb26d2245e0d3817adf"}, - {file = "torchaudio-2.5.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:85573f66dc09497d282bbf406e8f2b03e5929eb3bdc1f76a9d9bc46644d407b1"}, - {file = "torchaudio-2.5.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:877706453e7329844382d06ffee31cb11b602c6991afefb594086ecdd739a5cf"}, - {file = "torchaudio-2.5.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:74a3de13ec0a5024999aec75b3fafa97891d617ce5566818d3094857d1e0229d"}, - {file = "torchaudio-2.5.0-cp39-cp39-win_amd64.whl", hash = "sha256:242b3c9b1abf212b3a1f28eae9814db4daf4507f74b63c7ec7161d35b3c37147"}, + {file = "torchaudio-2.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:661909751909340b24f637410dfec02a888867816c3db19ed4f4102ae105244a"}, + {file = "torchaudio-2.4.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:bfc234cef1d03092ea27440fb79e486722ccb41cff94ebaf9d5a1082436395fe"}, + {file = "torchaudio-2.4.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:54431179d9a9ccf3feeae98aace07d89fae9fd728e2bc8656efbd70e7edcc6f8"}, + {file = "torchaudio-2.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:dec97872215c3122b7718ec47ac63e143565c3cced06444d0225e98bf4dd4b5f"}, + {file = "torchaudio-2.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:60af1531815d22659e5412ea401bed552a16c389938c49664e446e4cfd5ddc06"}, + {file = "torchaudio-2.4.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:95a0968569f7f4455bfd242bfcd489ec47ad37d2ba0f3d9f738cd1128a5f775c"}, + {file = "torchaudio-2.4.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:7640aaffb2056e12f2906187b03a22228a0908c87d0295fddf4b0b92334a290b"}, + {file = "torchaudio-2.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:3c08b42a0c296c8eeee6c533bcae5cfbc0ceae86a34f24fe6bbbb5faa7a7bea1"}, + {file = "torchaudio-2.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:953946cf610ffd57bb3fdd228effa2112fa51c5dfe36a96611effc9074a3d3be"}, + {file = "torchaudio-2.4.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:1796a8961decb522c47daab0fbe27c057d6d143ee22bb6ae0d5eb9b2a038c7b6"}, + {file = "torchaudio-2.4.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:5b62fc7b16ed708b0c07d4393137797e92f63fc3bd5705607d97ba6a9a7cf3f0"}, + {file = "torchaudio-2.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:d721b186aae7bd8752c9ad95213f5d650926597bb9060728dfe476986a1ff570"}, + {file = "torchaudio-2.4.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4ea0fd00142fe795c75bcc20a303981b56f2327c7f7d321b42a8fef1d78aafa9"}, + {file = "torchaudio-2.4.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:375d8740c8035a50faca7a5afe2fbdb712aa8733715b971b2af61b4003fa1c41"}, + {file = "torchaudio-2.4.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:74d19cf9ca3dad394afcabb7e6f7ed9ab9f59f2540d502826c7ec3e33985251d"}, + {file = "torchaudio-2.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:40e9fa8fdc8d328ea4aa90be65fd34c5ef975610dbd707545e3664393a8a2497"}, + {file = "torchaudio-2.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3adce550850902b9aa6cd2378ccd720ac9ec8cf31e2eba9743ccc84ffcbe76d6"}, + {file = "torchaudio-2.4.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:98d8e03703f96b13a8d172d1ccdc7badb338227fd762985fdcea6b30f6697bdb"}, + {file = "torchaudio-2.4.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:36c7e7bc6b358cbf42b769c80206780fa1497d141a985c6b3e7768de44524e9a"}, + {file = "torchaudio-2.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:f46e34ab3866ad8d8ace0673cd11e697c5cde6a3b7a4d8d789207d4d8badbb6e"}, ] [package.dependencies] -torch = "2.5.0" +torch = "2.4.1" [[package]] name = "torchmetrics" @@ -3401,37 +3418,37 @@ trampoline = ">=0.1.2" [[package]] name = "torchvision" -version = "0.20.0" +version = "0.19.1" description = "image and video datasets and models for torch deep learning" optional = false python-versions = ">=3.8" files = [ - {file = "torchvision-0.20.0-1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:e084f50ecbdbe7a9cc2fc51ea0367ae35fde46e84a964bf4046cb1c7feb7e3e6"}, - {file = "torchvision-0.20.0-1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:55d7f43ef912ebc4da4bba73a0bbf387d38a6be9cd521679c0f4056f9564b698"}, - {file = "torchvision-0.20.0-1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:f8d0213489acfb138369f2455a6893880c194a8195e381c19f872b277f2654c3"}, - {file = "torchvision-0.20.0-1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:8d6cea8ab0bf72ecb71b07cd0fe836eacf5a5fa98f6629d2261212e90977b963"}, - {file = "torchvision-0.20.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f164d545965186ffd66014e34a966706d12c84198302dd46748cae45984609a4"}, - {file = "torchvision-0.20.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:9c18208575d60b96e7d53a09c453781afea4a81487c9ebc501dfc2bc88daa308"}, - {file = "torchvision-0.20.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:09080359be90314fc4fdd64b11a4d231c1999018f19d58bf7764f5e15f8e9fb3"}, - {file = "torchvision-0.20.0-cp310-cp310-win_amd64.whl", hash = "sha256:a7d46cf096007b7e8df1bddad7375427664a064bc05d9cbff5d506b73c1ab8ca"}, - {file = "torchvision-0.20.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a15de6266a36bcd10d89f6f3d7ba4e2dd567a7a0add616ebc6e65aea20790e5d"}, - {file = "torchvision-0.20.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:b64d9f83cf201ebda4f6b03533e4918fa0b4223b28b0ee3cbede15b8174c7cbd"}, - {file = "torchvision-0.20.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:d80eb740810804bac4b8e6b6411946ab286a1ee1d731db36af2f885333254802"}, - {file = "torchvision-0.20.0-cp311-cp311-win_amd64.whl", hash = "sha256:1fd045757335d34969d176fc5688b643d201860cb45b48ce8d5d8fb90868f746"}, - {file = "torchvision-0.20.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ac0edba534fb071b2b03a2fd5cbbf9b7c259896d17a1d0d830b3c5b7dfae0782"}, - {file = "torchvision-0.20.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:c8f3bc399d9c3e4ba05d74ca6dd5e63fed08ad5c5b302a946c8fcaa56216220f"}, - {file = "torchvision-0.20.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:a78c99ebe1a62857b68e97ff9417b92f299f2ee61f009491a114ddad050c493d"}, - {file = "torchvision-0.20.0-cp312-cp312-win_amd64.whl", hash = "sha256:bb0da0950d2034a0412c251a3a9117ff9612157f45177d37ba1b20b472c0864b"}, - {file = "torchvision-0.20.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6a70c81ea5068dd7b1e340ebeabb65364576d8b9819454cfdf812290cf03e45a"}, - {file = "torchvision-0.20.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:95d8c817681a4c2156f66ef83cafc4c5c4b97e4694956d54d7dc554804ee510d"}, - {file = "torchvision-0.20.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:1ab53244701eab897e5c65026ba178c0abbc5bd08629c3d20f737d618e9e5a37"}, - {file = "torchvision-0.20.0-cp39-cp39-win_amd64.whl", hash = "sha256:47d0751aeaa7057ee6a5973d35e7acad3ad7c17b8e57a2c4304d13e001e330ae"}, + {file = "torchvision-0.19.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:54e8513099e6f586356c70f809d34f391af71ad182fe071cc328a28af2c40608"}, + {file = "torchvision-0.19.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:20a1f5e02bfdad7714e55fa3fa698347c11d829fa65e11e5a84df07d93350eed"}, + {file = "torchvision-0.19.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:7b063116164be52fc6deb4762de7f8c90bfa3a65f8d5caf17f8e2d5aadc75a04"}, + {file = "torchvision-0.19.1-cp310-cp310-win_amd64.whl", hash = "sha256:f40b6acabfa886da1bc3768f47679c61feee6bde90deb979d9f300df8c8a0145"}, + {file = "torchvision-0.19.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:40514282b4896d62765b8e26d7091c32e17c35817d00ec4be2362ea3ba3d1787"}, + {file = "torchvision-0.19.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:5a91be061ae5d6d5b95e833b93e57ca4d3c56c5a57444dd15da2e3e7fba96050"}, + {file = "torchvision-0.19.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:d71a6a6fe3a5281ca3487d4c56ad4aad20ff70f82f1d7c79bcb6e7b0c2af00c8"}, + {file = "torchvision-0.19.1-cp311-cp311-win_amd64.whl", hash = "sha256:70dea324174f5e9981b68e4b7cd524512c106ba64aedef560a86a0bbf2fbf62c"}, + {file = "torchvision-0.19.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:27ece277ff0f6cdc7fed0627279c632dcb2e58187da771eca24b0fbcf3f8590d"}, + {file = "torchvision-0.19.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:c659ff92a61f188a1a7baef2850f3c0b6c85685447453c03d0e645ba8f1dcc1c"}, + {file = "torchvision-0.19.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:c07bf43c2a145d792ecd9d0503d6c73577147ece508d45600d8aac77e4cdfcf9"}, + {file = "torchvision-0.19.1-cp312-cp312-win_amd64.whl", hash = "sha256:b4283d283675556bb0eae31d29996f53861b17cbdcdf3509e6bc050414ac9289"}, + {file = "torchvision-0.19.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4c4e4f5b24ea6b087b02ed492ab1e21bba3352c4577e2def14248cfc60732338"}, + {file = "torchvision-0.19.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:9281d63ead929bb19143731154cd1d8bf0b5e9873dff8578a40e90a6bec3c6fa"}, + {file = "torchvision-0.19.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:4d10bc9083c4d5fadd7edd7b729700a7be48dab4f62278df3bc73fa48e48a155"}, + {file = "torchvision-0.19.1-cp38-cp38-win_amd64.whl", hash = "sha256:ccf085ef1824fb9e16f1901285bf89c298c62dfd93267a39e8ee42c71255242f"}, + {file = "torchvision-0.19.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:731f434d91586769e255b5d70ed1a4457e0a1394a95f4aacf0e1e7e21f80c098"}, + {file = "torchvision-0.19.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:febe4f14d4afcb47cc861d8be7760ab6a123cd0817f97faf5771488cb6aa90f4"}, + {file = "torchvision-0.19.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:e328309b8670a2e889b2fe76a1c2744a099c11c984da9a822357bd9debd699a5"}, + {file = "torchvision-0.19.1-cp39-cp39-win_amd64.whl", hash = "sha256:6616f12e00a22e7f3fedbd0fccb0804c05e8fe22871668f10eae65cf3f283614"}, ] [package.dependencies] numpy = "*" pillow = ">=5.3.0,<8.3.dev0 || >=8.4.dev0" -torch = "2.5.0" +torch = "2.4.1" [package.extras] gdown = ["gdown (>=4.7.3)"] @@ -3469,13 +3486,13 @@ files = [ [[package]] name = "transformers" -version = "4.44.2" +version = "4.46.1" description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" optional = false python-versions = ">=3.8.0" files = [ - {file = "transformers-4.44.2-py3-none-any.whl", hash = "sha256:1c02c65e7bfa5e52a634aff3da52138b583fc6f263c1f28d547dc144ba3d412d"}, - {file = "transformers-4.44.2.tar.gz", hash = "sha256:36aa17cc92ee154058e426d951684a2dab48751b35b49437896f898931270826"}, + {file = "transformers-4.46.1-py3-none-any.whl", hash = "sha256:f77b251a648fd32e3d14b5e7e27c913b7c29154940f519e4c8c3aa6061df0f05"}, + {file = "transformers-4.46.1.tar.gz", hash = "sha256:16d79927d772edaf218820a96f9254e2211f9cd7fb3c308562d2d636c964a68c"}, ] [package.dependencies] @@ -3487,21 +3504,21 @@ pyyaml = ">=5.1" regex = "!=2019.12.17" requests = "*" safetensors = ">=0.4.1" -tokenizers = ">=0.19,<0.20" +tokenizers = ">=0.20,<0.21" tqdm = ">=4.27" [package.extras] -accelerate = ["accelerate (>=0.21.0)"] -agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch"] -all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm (<=0.9.16)", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision"] +accelerate = ["accelerate (>=0.26.0)"] +agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch"] +all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm (<=0.9.16)", "tokenizers (>=0.20,<0.21)", "torch", "torchaudio", "torchvision"] audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] -benchmark = ["optimum-benchmark (>=0.2.0)"] +benchmark = ["optimum-benchmark (>=0.3.0)"] codecarbon = ["codecarbon (==1.2.0)"] -deepspeed = ["accelerate (>=0.21.0)", "deepspeed (>=0.9.3)"] -deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.21.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] -dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm (<=0.9.16)", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] -dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.19,<0.20)", "urllib3 (<2.0.0)"] -dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm (<=0.9.16)", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +deepspeed = ["accelerate (>=0.26.0)", "deepspeed (>=0.9.3)"] +deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.26.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] +dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm (<=0.9.16)", "tokenizers (>=0.20,<0.21)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.20,<0.21)", "urllib3 (<2.0.0)"] +dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "libcst", "librosa", "nltk (<=3.8.1)", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm (<=0.9.16)", "tokenizers (>=0.20,<0.21)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)", "scipy (<1.13.0)"] flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] ftfy = ["ftfy"] @@ -3512,7 +3529,7 @@ natten = ["natten (>=0.14.6,<0.15.0)"] onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"] onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"] optuna = ["optuna"] -quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "isort (>=5.5.4)", "ruff (==0.5.1)", "urllib3 (<2.0.0)"] +quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "isort (>=5.5.4)", "libcst", "rich", "ruff (==0.5.1)", "urllib3 (<2.0.0)"] ray = ["ray[tune] (>=2.7.0)"] retrieval = ["datasets (!=2.5.0)", "faiss-cpu"] ruff = ["ruff (==0.5.1)"] @@ -3522,31 +3539,32 @@ serving = ["fastapi", "pydantic", "starlette", "uvicorn"] sigopt = ["sigopt"] sklearn = ["scikit-learn"] speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] -testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk", "parameterized", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] +testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "parameterized", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] tf = ["keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] tf-cpu = ["keras (>2.9,<2.16)", "keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow-cpu (>2.9,<2.16)", "tensorflow-probability (<0.24)", "tensorflow-text (<2.16)", "tf2onnx"] tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] +tiktoken = ["blobfile", "tiktoken"] timm = ["timm (<=0.9.16)"] -tokenizers = ["tokenizers (>=0.19,<0.20)"] -torch = ["accelerate (>=0.21.0)", "torch"] +tokenizers = ["tokenizers (>=0.20,<0.21)"] +torch = ["accelerate (>=0.26.0)", "torch"] torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"] -torchhub = ["filelock", "huggingface-hub (>=0.23.2,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.19,<0.20)", "torch", "tqdm (>=4.27)"] -video = ["av (==9.2.0)", "decord (==0.6.0)"] +torchhub = ["filelock", "huggingface-hub (>=0.23.2,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.20,<0.21)", "torch", "tqdm (>=4.27)"] +video = ["av (==9.2.0)"] vision = ["Pillow (>=10.0.1,<=15.0)"] [[package]] name = "triton" -version = "3.1.0" +version = "3.0.0" description = "A language and compiler for custom Deep Learning operations" optional = false python-versions = "*" files = [ - {file = "triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b0dd10a925263abbe9fa37dcde67a5e9b2383fc269fdf59f5657cac38c5d1d8"}, - {file = "triton-3.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f34f6e7885d1bf0eaaf7ba875a5f0ce6f3c13ba98f9503651c1e6dc6757ed5c"}, - {file = "triton-3.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c8182f42fd8080a7d39d666814fa36c5e30cc00ea7eeeb1a2983dbb4c99a0fdc"}, - {file = "triton-3.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6dadaca7fc24de34e180271b5cf864c16755702e9f63a16f62df714a8099126a"}, - {file = "triton-3.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aafa9a20cd0d9fee523cd4504aa7131807a864cd77dcf6efe7e981f18b8c6c11"}, + {file = "triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a"}, + {file = "triton-3.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5ce8520437c602fb633f1324cc3871c47bee3b67acf9756c1a66309b60e3216c"}, + {file = "triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb"}, + {file = "triton-3.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bcbf3b1c48af6a28011a5c40a5b3b9b5330530c3827716b5fbf6d7adcc1e53e9"}, + {file = "triton-3.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6e5727202f7078c56f91ff13ad0c1abab14a0e7f2c87e91b12b6f64f3e8ae609"}, ] [package.dependencies] @@ -4202,4 +4220,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "c24ef2004b2191f79e139e629ba8258fb92b6d6a15ad0d827b6042fa62440ebd" +content-hash = "61cbffb8a7b0a9a5bac978652b3097c6fa10ae6b0677a2640613ac7736408e91" diff --git a/install/apple/pyproject.toml b/install/apple/pyproject.toml index 98e7c4a1..1034d915 100644 --- a/install/apple/pyproject.toml +++ b/install/apple/pyproject.toml @@ -9,8 +9,8 @@ package-mode = false [tool.poetry.dependencies] python = ">=3.10,<3.13" -torch = "^2.5.0" -torchvision = "^0.20.0" +torch = "2.4.1" +torchvision = "*" diffusers = "^0.31.0" transformers = "^4.44.2" datasets = "^3.0.0" @@ -44,7 +44,8 @@ fastapi = {extras = ["standard"], version = "^0.115.0"} deepspeed = "^0.15.1" sentencepiece = "^0.2.0" torchao = "^0.5.0" -torchaudio = "^2.5.0" +torchaudio = "*" +omnigen = {git = "https://github.com/bghira/omnigen", rev = "dependency-update/peft"} [build-system]