From fe54c211512ae90f0d596bf2ed1e40d3826a7c7e Mon Sep 17 00:00:00 2001 From: Jack McCluskey <34928439+jrmccluskey@users.noreply.github.com> Date: Fri, 7 Jun 2024 11:21:02 -0400 Subject: [PATCH] Implement Hugging Face Image Embedding MLTransform (#31536) * Implement Hugging Face Image Embedding MLTransform * correct imports * Simplify to original sentence transformer class --- .../ml/transforms/embeddings/huggingface.py | 19 +++++-- .../transforms/embeddings/huggingface_test.py | 52 ++++++++++++++++++- 2 files changed, 67 insertions(+), 4 deletions(-) diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/huggingface.py b/sdks/python/apache_beam/ml/transforms/embeddings/huggingface.py index 7fcaa9c9a5df..46b4ef9cf7d6 100644 --- a/sdks/python/apache_beam/ml/transforms/embeddings/huggingface.py +++ b/sdks/python/apache_beam/ml/transforms/embeddings/huggingface.py @@ -32,6 +32,7 @@ from apache_beam.ml.inference.base import ModelHandler from apache_beam.ml.inference.base import RunInference from apache_beam.ml.transforms.base import EmbeddingsManager +from apache_beam.ml.transforms.base import _ImageEmbeddingHandler from apache_beam.ml.transforms.base import _TextEmbeddingHandler try: @@ -114,6 +115,7 @@ def __init__( model_name: str, columns: List[str], max_seq_length: Optional[int] = None, + image_model: bool = False, **kwargs): """ Embedding config for sentence-transformers. This config can be used with @@ -122,9 +124,13 @@ def __init__( Args: model_name: Name of the model to use. The model should be hosted on - HuggingFace Hub or compatible with sentence_transformers. + HuggingFace Hub or compatible with sentence_transformers. For image + embedding models, see + https://www.sbert.net/docs/sentence_transformer/pretrained_models.html#image-text-models # pylint: disable=line-too-long + for a list of available sentence_transformers models. columns: List of columns to be embedded. max_seq_length: Max sequence length to use for the model if applicable. + image_model: Whether the model is generating image embeddings. min_batch_size: The minimum batch size to be used for inference. max_batch_size: The maximum batch size to be used for inference. large_model: Whether to share the model across processes. @@ -132,6 +138,7 @@ def __init__( super().__init__(columns, **kwargs) self.model_name = model_name self.max_seq_length = max_seq_length + self.image_model = image_model def get_model_handler(self): return _SentenceTransformerModelHandler( @@ -144,8 +151,14 @@ def get_model_handler(self): large_model=self.large_model) def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: - # wrap the model handler in a _TextEmbeddingHandler since - # the SentenceTransformerEmbeddings works on text input data. + # wrap the model handler in an appropriate embedding handler to provide + # some type checking. + if self.image_model: + return ( + RunInference( + model_handler=_ImageEmbeddingHandler(self), + inference_args=self.inference_args, + )) return ( RunInference( model_handler=_TextEmbeddingHandler(self), diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/huggingface_test.py b/sdks/python/apache_beam/ml/transforms/embeddings/huggingface_test.py index f94e747c5edd..d09a573b6766 100644 --- a/sdks/python/apache_beam/ml/transforms/embeddings/huggingface_test.py +++ b/sdks/python/apache_beam/ml/transforms/embeddings/huggingface_test.py @@ -35,6 +35,7 @@ try: from apache_beam.ml.transforms.embeddings.huggingface import SentenceTransformerEmbeddings from apache_beam.ml.transforms.embeddings.huggingface import InferenceAPIEmbeddings + from PIL import Image import torch except ImportError: SentenceTransformerEmbeddings = None # type: ignore @@ -46,10 +47,17 @@ except ImportError: tft = None +# pylint: disable=ungrouped-imports +try: + from PIL import Image +except ImportError: + Image = None + _HF_TOKEN = os.environ.get('HF_INFERENCE_TOKEN') test_query = "This is a test" test_query_column = "feature_1" DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2" +IMAGE_MODEL_NAME = "clip-ViT-B-32" _parameterized_inputs = [ ([{ test_query_column: 'That is a happy person' @@ -85,7 +93,7 @@ @unittest.skipIf( SentenceTransformerEmbeddings is None, 'sentence-transformers is not installed.') -class SentenceTrasformerEmbeddingsTest(unittest.TestCase): +class SentenceTransformerEmbeddingsTest(unittest.TestCase): def setUp(self) -> None: self.artifact_location = tempfile.mkdtemp(prefix='sentence_transformers_') # this bucket has TTL and will be deleted periodically @@ -277,6 +285,48 @@ def test_mltransform_to_ptransform_with_sentence_transformer(self): self.assertEqual( ptransform_list[i]._model_handler._underlying.model_name, model_name) + def generateRandomImage(self, size: int): + imarray = np.random.rand(size, size, 3) * 255 + return Image.fromarray(imarray.astype('uint8')).convert('RGBA') + + @unittest.skipIf(Image is None, 'Pillow is not installed.') + def test_sentence_transformer_image_embeddings(self): + embedding_config = SentenceTransformerEmbeddings( + model_name=IMAGE_MODEL_NAME, + columns=[test_query_column], + image_model=True) + img = self.generateRandomImage(256) + with beam.Pipeline() as pipeline: + result_pcoll = ( + pipeline + | "CreateData" >> beam.Create([{ + test_query_column: img + }]) + | "MLTransform" >> MLTransform( + write_artifact_location=self.artifact_location).with_transform( + embedding_config)) + + def assert_element(element): + assert len(element[test_query_column]) == 512 + + _ = (result_pcoll | beam.Map(assert_element)) + + def test_sentence_transformer_images_with_str_data_types(self): + embedding_config = SentenceTransformerEmbeddings( + model_name=IMAGE_MODEL_NAME, + columns=[test_query_column], + image_model=True) + with self.assertRaises(TypeError): + with beam.Pipeline() as pipeline: + _ = ( + pipeline + | "CreateData" >> beam.Create([{ + test_query_column: "image.jpg" + }]) + | "MLTransform" >> MLTransform( + write_artifact_location=self.artifact_location).with_transform( + embedding_config)) + @unittest.skipIf(_HF_TOKEN is None, 'HF_TOKEN environment variable not set.') class HuggingfaceInferenceAPITest(unittest.TestCase):