From 21f149938bbf0f57b0119944e3865ebe5feace07 Mon Sep 17 00:00:00 2001 From: Jimmy <39@🇺🇸.com> Date: Tue, 29 Oct 2024 20:36:20 -0400 Subject: [PATCH] Fix multi-caption parquets crashing in multiple locations (Closes #1092) --- helpers/data_backend/factory.py | 99 +++++++++++++++++++++++----- helpers/metadata/backends/parquet.py | 12 ++-- helpers/prompts.py | 30 +++++---- tests/test_dataset.py | 59 +++++++++++++++++ 4 files changed, 165 insertions(+), 35 deletions(-) diff --git a/helpers/data_backend/factory.py b/helpers/data_backend/factory.py index dba8d7c4..8b3fbf2f 100644 --- a/helpers/data_backend/factory.py +++ b/helpers/data_backend/factory.py @@ -24,6 +24,8 @@ from tqdm import tqdm import queue from math import sqrt +import pandas as pd +import numpy as np logger = logging.getLogger("DataBackendFactory") if should_log(): @@ -48,6 +50,68 @@ def info_log(message): logger.info(message) +def check_column_values(column_data, column_name, parquet_path, fallback_caption_column=False): + # Determine if the column contains arrays or scalar values + non_null_values = column_data.dropna() + if non_null_values.empty: + # All values are null + raise ValueError( + f"Parquet file {parquet_path} contains only null values in the '{column_name}' column." + ) + + first_non_null = non_null_values.iloc[0] + if isinstance(first_non_null, (list, tuple, np.ndarray, pd.Series)): + # Column contains arrays + # Check for null arrays + if column_data.isnull().any() and not fallback_caption_column: + raise ValueError( + f"Parquet file {parquet_path} contains null arrays in the '{column_name}' column." + ) + + # Check for empty arrays + empty_arrays = column_data.apply(lambda x: len(x) == 0) + if empty_arrays.any() and not fallback_caption_column: + raise ValueError( + f"Parquet file {parquet_path} contains empty arrays in the '{column_name}' column." + ) + + # Check for null elements within arrays + null_elements_in_arrays = column_data.apply( + lambda arr: any(pd.isnull(s) for s in arr) + ) + if null_elements_in_arrays.any() and not fallback_caption_column: + raise ValueError( + f"Parquet file {parquet_path} contains null values within arrays in the '{column_name}' column." + ) + + # Check for empty strings within arrays + empty_strings_in_arrays = column_data.apply( + lambda arr: any(s == "" for s in arr) + ) + if empty_strings_in_arrays.all() and not fallback_caption_column: + raise ValueError( + f"Parquet file {parquet_path} contains only empty strings within arrays in the '{column_name}' column." + ) + + elif isinstance(first_non_null, str): + # Column contains scalar strings + # Check for null values + if column_data.isnull().any() and not fallback_caption_column: + raise ValueError( + f"Parquet file {parquet_path} contains null values in the '{column_name}' column." + ) + + # Check for empty strings + if (column_data == "").any() and not fallback_caption_column: + raise ValueError( + f"Parquet file {parquet_path} contains empty strings in the '{column_name}' column." + ) + else: + raise TypeError( + f"Unsupported data type in column '{column_name}'. Expected strings or arrays of strings." + ) + + def init_backend_config(backend: dict, args: dict, accelerator) -> dict: output = {"id": backend["id"], "config": {}} if backend.get("dataset_type", None) == "text_embeds": @@ -292,24 +356,23 @@ def configure_parquet_database(backend: dict, args, data_backend: BaseDataBacken raise ValueError( f"Parquet file {parquet_path} does not contain a column named '{filename_column}'." 
) - # Check for null values - if df[caption_column].isnull().values.any() and not fallback_caption_column: - raise ValueError( - f"Parquet file {parquet_path} contains null values in the '{caption_column}' column, but no fallback_caption_column was set." - ) - if df[filename_column].isnull().values.any(): - raise ValueError( - f"Parquet file {parquet_path} contains null values in the '{filename_column}' column." - ) - # Check for empty strings - if (df[caption_column] == "").sum() > 0 and not fallback_caption_column: - raise ValueError( - f"Parquet file {parquet_path} contains empty strings in the '{caption_column}' column." - ) - if (df[filename_column] == "").sum() > 0: - raise ValueError( - f"Parquet file {parquet_path} contains empty strings in the '{filename_column}' column." - ) + + # Apply the function to the caption_column. + check_column_values( + df[caption_column], + caption_column, + parquet_path, + fallback_caption_column=fallback_caption_column + ) + + # Apply the function to the filename_column. + check_column_values( + df[filename_column], + filename_column, + parquet_path, + fallback_caption_column=False # Always check filename_column + ) + # Store the database in StateTracker StateTracker.set_parquet_database( backend["id"], diff --git a/helpers/metadata/backends/parquet.py b/helpers/metadata/backends/parquet.py index 584cd9c0..2850f986 100644 --- a/helpers/metadata/backends/parquet.py +++ b/helpers/metadata/backends/parquet.py @@ -150,11 +150,13 @@ def _extract_captions_to_fast_list(self): if len(caption_column) > 0: caption = [row[c] for c in caption_column] else: - caption = row[caption_column] + caption = row.get(caption_column) + if isinstance(caption, (numpy.ndarray, pd.Series)): + caption = [str(item) for item in caption if item is not None] - if not caption and fallback_caption_column: - caption = row[fallback_caption_column] - if not caption: + if caption is None and fallback_caption_column: + caption = row.get(fallback_caption_column, None) + if caption is None or caption == "" or caption == []: raise ValueError( f"Could not locate caption for image {filename} in sampler_backend {self.id} with filename column {filename_column}, caption column {caption_column}, and a parquet database with {len(self.parquet_database)} entries." 
) @@ -162,7 +164,7 @@ def _extract_captions_to_fast_list(self): caption = caption.decode("utf-8") elif type(caption) == list: caption = [c.strip() for c in caption if c.strip()] - if caption: + elif type(caption) == str: caption = caption.strip() captions[filename] = caption return captions diff --git a/helpers/prompts.py b/helpers/prompts.py index bac03d77..448f0e79 100644 --- a/helpers/prompts.py +++ b/helpers/prompts.py @@ -5,6 +5,13 @@ from helpers.training.multi_process import _get_rank as get_rank from helpers.training import image_file_extensions +import numpy + +try: + import pandas as pd +except ImportError: + raise ImportError("Pandas is required for the ParquetMetadataBackend.") + prompts = { "alien_landscape": "Alien planet, strange rock formations, glowing plants, bizarre creatures, surreal atmosphere", "alien_market": "Alien marketplace, bizarre creatures, exotic goods, vibrant colors, otherworldly atmosphere", @@ -256,8 +263,10 @@ def prepare_instance_prompt_from_parquet( ) if type(image_caption) == bytes: image_caption = image_caption.decode("utf-8") - if image_caption: + if type(image_caption) == str: image_caption = image_caption.strip() + if type(image_caption) in (list, tuple, numpy.ndarray, pd.Series): + image_caption = [str(item).strip() for item in image_caption if item is not None] if prepend_instance_prompt: if type(image_caption) == list: image_caption = [instance_prompt + " " + x for x in image_caption] @@ -436,17 +445,14 @@ def get_all_captions( data_backend=data_backend, ) elif caption_strategy == "parquet": - try: - caption = PromptHandler.prepare_instance_prompt_from_parquet( - image_path, - use_captions=use_captions, - prepend_instance_prompt=prepend_instance_prompt, - instance_prompt=instance_prompt, - data_backend=data_backend, - sampler_backend_id=data_backend.id, - ) - except: - continue + caption = PromptHandler.prepare_instance_prompt_from_parquet( + image_path, + use_captions=use_captions, + prepend_instance_prompt=prepend_instance_prompt, + instance_prompt=instance_prompt, + data_backend=data_backend, + sampler_backend_id=data_backend.id, + ) elif caption_strategy == "instanceprompt": return [instance_prompt] elif caption_strategy == "csv": diff --git a/tests/test_dataset.py b/tests/test_dataset.py index aa6d5d26..7ae433b5 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1,10 +1,12 @@ import unittest +import pandas as pd from unittest.mock import patch, Mock, MagicMock from PIL import Image from pathlib import Path from helpers.multiaspect.dataset import MultiAspectDataset from helpers.metadata.backends.discovery import DiscoveryMetadataBackend from helpers.data_backend.base import BaseDataBackend +from helpers.data_backend.factory import check_column_values class TestMultiAspectDataset(unittest.TestCase): @@ -82,5 +84,62 @@ def test_getitem_invalid_image(self): self.dataset.__getitem__(self.image_metadata) +class TestDataBackendFactory(unittest.TestCase): + def test_all_null(self): + column_data = pd.Series([None, None, None]) + with self.assertRaises(ValueError) as context: + check_column_values(column_data, "test_column", "test_file.parquet") + self.assertIn("contains only null values", str(context.exception)) + + def test_arrays_with_nulls(self): + column_data = pd.Series([[1, 2], None, [3, 4]]) + with self.assertRaises(ValueError) as context: + check_column_values(column_data, "test_column", "test_file.parquet") + self.assertIn("contains null arrays", str(context.exception)) + + def test_empty_arrays(self): + column_data = 
pd.Series([[1, 2], [], [3, 4]]) + with self.assertRaises(ValueError) as context: + check_column_values(column_data, "test_column", "test_file.parquet") + self.assertIn("contains empty arrays", str(context.exception)) + + def test_null_elements_in_arrays(self): + column_data = pd.Series([[1, None], [2, 3], [3, 4]]) + with self.assertRaises(ValueError) as context: + check_column_values(column_data, "test_column", "test_file.parquet") + self.assertIn("contains null values within arrays", str(context.exception)) + + def test_empty_strings_in_arrays(self): + column_data = pd.Series([["", ""], ["", ""], ["", ""]]) + with self.assertRaises(ValueError) as context: + check_column_values(column_data, "test_column", "test_file.parquet") + self.assertIn("contains only empty strings within arrays", str(context.exception)) + + def test_scalar_strings_with_nulls(self): + column_data = pd.Series(["a", None, "b"]) + with self.assertRaises(ValueError) as context: + check_column_values(column_data, "test_column", "test_file.parquet") + self.assertIn("contains null values", str(context.exception)) + + def test_scalar_strings_with_empty(self): + column_data = pd.Series(["a", "", "b"]) + with self.assertRaises(ValueError) as context: + check_column_values(column_data, "test_column", "test_file.parquet") + self.assertIn("contains empty strings", str(context.exception)) + + def test_with_fallback_caption(self): + column_data = pd.Series([None, "", [None], [""]]) + try: + check_column_values(column_data, "test_column", "test_file.parquet", fallback_caption_column=True) + except ValueError: + self.fail("check_column_values() raised ValueError unexpectedly with fallback_caption_column=True") + + def test_invalid_data_type(self): + column_data = pd.Series([1, 2, 3]) + with self.assertRaises(TypeError) as context: + check_column_values(column_data, "test_column", "test_file.parquet") + self.assertIn("Unsupported data type in column", str(context.exception)) + + if __name__ == "__main__": unittest.main()
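
For context on how the new validation helper behaves at configure time, below is a minimal usage sketch. It is not part of the patch: it assumes the repository environment so that helpers.data_backend.factory and its own dependencies import cleanly (the unit tests above import check_column_values the same way), and the DataFrame contents, column names, and parquet path are illustrative only.

    import pandas as pd

    from helpers.data_backend.factory import check_column_values

    # Illustrative multi-caption dataset: each row carries a list of captions.
    df = pd.DataFrame(
        {
            "filename": ["a.png", "b.png"],
            "caption": [["a red car", "red car, studio shot"], ["a blue bike"]],
        }
    )

    # Passes silently: every caption array is non-empty and contains no nulls
    # or empty strings.
    check_column_values(df["caption"], "caption", "example.parquet")

    # Raises ValueError ("contains empty arrays"): one row has an empty caption
    # list and no fallback_caption_column is set.
    bad = pd.Series([["a red car"], []])
    try:
        check_column_values(bad, "caption", "example.parquet")
    except ValueError as err:
        print(err)

Note that the patch always validates the filename column strictly (fallback_caption_column=False in configure_parquet_database), so null or empty filenames still raise even when a fallback caption column is configured; only the caption column is allowed to relax its checks.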