diff --git a/sdks/python/apache_beam/ml/transforms/base.py b/sdks/python/apache_beam/ml/transforms/base.py
index b3a30bb5f125..4ecbeacb8e67 100644
--- a/sdks/python/apache_beam/ml/transforms/base.py
+++ b/sdks/python/apache_beam/ml/transforms/base.py
@@ -14,20 +14,42 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# pytype: skip-file
-
 import abc
+import collections
+import logging
+import os
+import tempfile
+import uuid
+from typing import Any
 from typing import Dict
 from typing import Generic
 from typing import List
+from typing import Mapping
 from typing import Optional
 from typing import Sequence
 from typing import TypeVar
+from typing import Union
+
+import jsonpickle
+import numpy as np
 
 import apache_beam as beam
+from apache_beam.io.filesystems import FileSystems
 from apache_beam.metrics.metric import Metrics
+from apache_beam.ml.inference.base import ModelHandler
+from apache_beam.ml.inference.base import ModelT
+from apache_beam.options.pipeline_options import PipelineOptions
 
-__all__ = ['MLTransform', 'ProcessHandler', 'BaseOperation']
+_LOGGER = logging.getLogger(__name__)
+_ATTRIBUTE_FILE_NAME = 'attributes.json'
+
+__all__ = [
+    'MLTransform',
+    'ProcessHandler',
+    'PTransformProvider',
+    'BaseOperation',
+    'EmbeddingsManager'
+]
 
 TransformedDatasetT = TypeVar('TransformedDatasetT')
 TransformedMetadataT = TypeVar('TransformedMetadataT')
@@ -42,12 +64,68 @@
 OperationOutputT = TypeVar('OperationOutputT')
 
 
+def _convert_list_of_dicts_to_dict_of_lists(
+    list_of_dicts: Sequence[Dict[str, Any]]) -> Dict[str, List[Any]]:
+  keys_to_element_list = collections.defaultdict(list)
+  for d in list_of_dicts:
+    for key, value in d.items():
+      keys_to_element_list[key].append(value)
+  return keys_to_element_list
+
+
+def _convert_dict_of_lists_to_lists_of_dict(
+    dict_of_lists: Dict[str, List[Any]],
+    batch_length: int) -> List[Dict[str, Any]]:
+  result: List[Dict[str, Any]] = [{} for _ in range(batch_length)]
+  for key, values in dict_of_lists.items():
+    for i in range(len(values)):
+      result[i][key] = values[i]
+  return result
+
+
 class ArtifactMode(object):
   PRODUCE = 'produce'
   CONSUME = 'consume'
 
 
-class BaseOperation(Generic[OperationInputT, OperationOutputT], abc.ABC):
+class PTransformProvider:
+  """
+  Data processing transforms that are intended to be used with MLTransform
+  should subclass PTransformProvider and implement the following methods:
+  1. get_ptransform_for_processing()
+  2. requires_chaining()
+
+  The get_ptransform_for_processing() method should return a PTransform that
+  can be used to process the data.
+
+  The requires_chaining() method should return True if the data processing
+  transform needs to be chained sequentially with compatible data processing
+  transforms.
+  """
+  @abc.abstractmethod
+  def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform:
+    """
+    Returns a PTransform that can be used to process the data.
+    """
+
+  @abc.abstractmethod
+  def requires_chaining(self):
+    """
+    Returns True if the data processing transform needs to be chained
+    sequentially with compatible data processing transforms.
+    """
+
+  def get_counter(self):
+    """
+    Returns the counter name for the data processing transform.
+ """ + counter_name = self.__class__.__name__ + return Metrics.counter(MLTransform, f'BeamML_{counter_name}') + + +class BaseOperation(Generic[OperationInputT, OperationOutputT], + PTransformProvider, + abc.ABC): def __init__(self, columns: List[str]) -> None: """ Base Opertation class data processing transformations. @@ -76,33 +154,55 @@ def __call__(self, data: OperationInputT, transformed_data = self.apply_transform(data, output_column_name) return transformed_data - def get_counter(self): - """ - Returns the counter name for the operation. - """ - counter_name = self.__class__.__name__ - return Metrics.counter(MLTransform, f'BeamML_{counter_name}') - -class ProcessHandler(Generic[ExampleT, MLTransformOutputT], abc.ABC): +class ProcessHandler(beam.PTransform[beam.PCollection[ExampleT], + beam.PCollection[MLTransformOutputT]], + abc.ABC): """ Only for internal use. No backwards compatibility guarantees. """ @abc.abstractmethod - def process_data( - self, pcoll: beam.PCollection[ExampleT] - ) -> beam.PCollection[MLTransformOutputT]: + def append_transform(self, transform: BaseOperation): """ - Logic to process the data. This will be the entrypoint in - beam.MLTransform to process incoming data. + Append transforms to the ProcessHandler. """ + +# TODO: Add support for inference_fn +class EmbeddingsManager(PTransformProvider): + def __init__( + self, + columns: List[str], + *, + # common args for all ModelHandlers. + load_model_args: Optional[Dict[str, Any]] = None, + min_batch_size: Optional[int] = None, + max_batch_size: Optional[int] = None, + large_model: bool = False, + **kwargs): + self.load_model_args = load_model_args or {} + self.min_batch_size = min_batch_size + self.max_batch_size = max_batch_size + self.large_model = large_model + self.columns = columns + + if kwargs: + _LOGGER.warning("Ignoring the following arguments: %s", kwargs.keys()) + + # TODO: Add set_model_handler method. @abc.abstractmethod - def append_transform(self, transform: BaseOperation): + def get_model_handler(self) -> ModelHandler: """ - Append transforms to the ProcessHandler. + Return framework specific model handler. """ + def requires_chaining(self): + # each embedding config requires a separate PTransform. so no chaining. + return False + + def get_columns_to_apply(self): + return self.columns + class MLTransform(beam.PTransform[beam.PCollection[ExampleT], beam.PCollection[MLTransformOutputT]], @@ -112,7 +212,8 @@ def __init__( *, write_artifact_location: Optional[str] = None, read_artifact_location: Optional[str] = None, - transforms: Optional[Sequence[BaseOperation]] = None): + transforms: Optional[List[Union[BaseOperation, + EmbeddingsManager]]] = None): """ MLTransform is a Beam PTransform that can be used to apply transformations to the data. MLTransform is used to wrap the @@ -157,9 +258,6 @@ def __init__( i-th transform is the output of the (i-1)-th transform. Multi-input transforms are not supported yet. 
""" - if transforms: - _ = [self._validate_transform(transform) for transform in transforms] - if read_artifact_location and write_artifact_location: raise ValueError( 'Only one of read_artifact_location or write_artifact_location can ' @@ -177,19 +275,10 @@ def __init__( artifact_location = write_artifact_location # type: ignore[assignment] artifact_mode = ArtifactMode.PRODUCE - # avoid circular import - # pylint: disable=wrong-import-order, wrong-import-position - from apache_beam.ml.transforms.handlers import TFTProcessHandler - # TODO: When new ProcessHandlers(eg: JaxProcessHandler) are introduced, - # create a mapping between transforms and ProcessHandler since - # ProcessHandler is not exposed to the user. - process_handler: ProcessHandler = TFTProcessHandler( - artifact_location=artifact_location, - artifact_mode=artifact_mode, - transforms=transforms) # type: ignore[arg-type] - - self._process_handler = process_handler - self.transforms = transforms + self._parent_artifact_location = artifact_location + + self._artifact_mode = artifact_mode + self.transforms = transforms or [] self._counter = Metrics.counter( MLTransform, f'BeamML_{self.__class__.__name__}') @@ -209,10 +298,33 @@ def expand( Returns: A PCollection of MLTransformOutputT type """ + _ = [self._validate_transform(transform) for transform in self.transforms] + if self._artifact_mode == ArtifactMode.PRODUCE: + ptransform_partitioner = _MLTransformToPTransformMapper( + transforms=self.transforms, + artifact_location=self._parent_artifact_location, + artifact_mode=self._artifact_mode, + pipeline_options=pcoll.pipeline.options) + ptransform_list = ptransform_partitioner.create_and_save_ptransform_list() + else: + ptransform_list = ( + _MLTransformToPTransformMapper.load_transforms_from_artifact_location( + self._parent_artifact_location)) + + # the saved transforms has artifact mode set to PRODUCE. + # set the artifact mode to CONSUME. + if self._artifact_mode == ArtifactMode.CONSUME: + for i in range(len(ptransform_list)): + if hasattr(ptransform_list[i], 'artifact_mode'): + ptransform_list[i].artifact_mode = self._artifact_mode + + for ptransform in ptransform_list: + pcoll = pcoll | ptransform + _ = ( pcoll.pipeline | "MLTransformMetricsUsage" >> MLTransformMetricsUsage(self)) - return self._process_handler.process_data(pcoll) + return pcoll # type: ignore[return-value] def with_transform(self, transform: BaseOperation): """ @@ -222,14 +334,21 @@ def with_transform(self, transform: BaseOperation): Returns: A MLTransform instance. """ - self._validate_transform(transform) - self._process_handler.append_transform(transform) + # self._validate_transform(transform) + # avoid circular import + # pylint: disable=wrong-import-order, wrong-import-position + self.transforms.append(transform) return self def _validate_transform(self, transform): - if not isinstance(transform, BaseOperation): + # every data processing transform should subclass PTransformProvider. Raise + # an error if the transform does not subclass PTransformProvider since the + # downstream code expects the transform to be a subclass of + # PTransformProvider. + if not isinstance(transform, PTransformProvider): raise TypeError( - 'transform must be a subclass of BaseOperation. ' + 'transform must be a subclass of PTransformProvider and implement ' + 'get_ptransform_for_processing() method.' 'Got: %s instead.' % type(transform)) @@ -243,9 +362,7 @@ def _increment_counters(): # increment for MLTransform. 
        self._ml_transform._counter.inc()
        # increment if data processing transforms are passed.
-        transforms = (
-            self._ml_transform.transforms or
-            self._ml_transform._process_handler.transforms)
+        transforms = self._ml_transform.transforms
        if transforms:
          for transform in transforms:
            transform.get_counter().inc()
@@ -254,3 +371,243 @@ def _increment_counters():
         pipeline
         | beam.Create([None])
         | beam.Map(lambda _: _increment_counters()))
+
+
+class _TransformAttributeManager:
+  """
+  Base class used for saving and loading the attributes.
+  """
+  @staticmethod
+  def save_attributes(artifact_location):
+    """
+    Save the attributes to a json file.
+    """
+    raise NotImplementedError
+
+  @staticmethod
+  def load_attributes(artifact_location):
+    """
+    Load the attributes from the json file.
+    """
+    raise NotImplementedError
+
+
+class _JsonPickleTransformAttributeManager(_TransformAttributeManager):
+  """
+  Uses jsonpickle to save and load the attributes. Here the attributes refer
+  to the list of PTransforms that are used to process the data.
+
+  jsonpickle is used to serialize the PTransforms to a json file in a form
+  that is compatible across Python versions.
+  """
+  @staticmethod
+  def _is_remote_path(path):
+    is_gcs = path.find('gs://') != -1
+    # TODO: Add support for other remote paths.
+    if not is_gcs and path.find('://') != -1:
+      raise RuntimeError(
+          "Artifact locations are currently supported only for "
+          "local paths and GCS paths. Got: %s" % path)
+    return is_gcs
+
+  @staticmethod
+  def save_attributes(
+      ptransform_list,
+      artifact_location,
+      **kwargs,
+  ):
+    if _JsonPickleTransformAttributeManager._is_remote_path(artifact_location):
+      # dict.get never raises KeyError, so check for a missing value
+      # explicitly instead of wrapping the lookup in try/except.
+      options = kwargs.get('options')
+      if options is None:
+        raise RuntimeError(
+            'pipeline options are required to save the attributes '
+            'in the artifact location %s' % artifact_location)
+
+      temp_dir = tempfile.mkdtemp()
+      temp_json_file = os.path.join(temp_dir, _ATTRIBUTE_FILE_NAME)
+      with open(temp_json_file, 'w+') as f:
+        f.write(jsonpickle.encode(ptransform_list))
+      with open(temp_json_file, 'rb') as f:
+        from apache_beam.runners.dataflow.internal import apiclient
+        _LOGGER.info('Creating artifact location: %s', artifact_location)
+        apiclient.DataflowApplicationClient(options=options).stage_file(
+            gcs_or_local_path=artifact_location,
+            file_name=_ATTRIBUTE_FILE_NAME,
+            stream=f,
+            mime_type='application/json')
+    else:
+      if not FileSystems.exists(artifact_location):
+        FileSystems.mkdirs(artifact_location)
+      # FileSystems.open() fails if the file does not exist.
+      with open(os.path.join(artifact_location, _ATTRIBUTE_FILE_NAME),
+                'w+') as f:
+        f.write(jsonpickle.encode(ptransform_list))
+
+  @staticmethod
+  def load_attributes(artifact_location):
+    with FileSystems.open(os.path.join(artifact_location, _ATTRIBUTE_FILE_NAME),
+                          'rb') as f:
+      return jsonpickle.decode(f.read())
+
+
+_transform_attribute_manager = _JsonPickleTransformAttributeManager
+
+
+class _MLTransformToPTransformMapper:
+  """
+  This class takes in a list of data processing transforms compatible with
+  MLTransform and returns a list of PTransforms that are used to run the
+  data processing transforms.
+
+  The _MLTransformToPTransformMapper is responsible for loading and saving the
+  PTransforms or attributes of PTransforms to the artifact location to bridge
+  the gap between the training and inference pipelines.
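The save/load contract above boils down to a jsonpickle round trip; a minimal sketch of the local-path branch (dummy stand-ins for the PTransform list, illustrative path):

    import os

    import jsonpickle

    artifact_location = '/tmp/ml_artifacts'
    os.makedirs(artifact_location, exist_ok=True)
    # Any jsonpickle-serializable objects stand in for the PTransforms here.
    ptransform_list = [{'transform': 'ScaleTo01'}, {'transform': 'RunInference'}]

    path = os.path.join(artifact_location, 'attributes.json')
    with open(path, 'w+') as f:
      f.write(jsonpickle.encode(ptransform_list))
    with open(path, 'rb') as f:
      assert jsonpickle.decode(f.read()) == ptransform_list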
+ """ + def __init__( + self, + transforms: List[Union[BaseOperation, EmbeddingsManager]], + artifact_location: str, + artifact_mode: str, + pipeline_options: Optional[PipelineOptions] = None, + ): + self.transforms = transforms + self._parent_artifact_location = artifact_location + self.artifact_mode = artifact_mode + self.pipeline_options = pipeline_options + + def create_and_save_ptransform_list(self): + ptransform_list = self.create_ptransform_list() + self.save_transforms_in_artifact_location(ptransform_list) + return ptransform_list + + def create_ptransform_list(self): + previous_ptransform_type = None + current_ptransform = None + ptransform_list = [] + for transform in self.transforms: + if not isinstance(transform, PTransformProvider): + raise RuntimeError( + 'Transforms must be instances of PTransformProvider and ' + 'implement get_ptransform_for_processing() method.') + # for each instance of PTransform, create a new artifact location + current_ptransform = transform.get_ptransform_for_processing( + artifact_location=os.path.join( + self._parent_artifact_location, uuid.uuid4().hex[:6]), + artifact_mode=self.artifact_mode) + # Determine if a new ptransform should be added to the list + is_different_type = (type(current_ptransform) != previous_ptransform_type) + if is_different_type or not transform.requires_chaining(): + ptransform_list.append(current_ptransform) + previous_ptransform_type = type(current_ptransform) + + if hasattr(ptransform_list[-1], 'append_transform'): + ptransform_list[-1].append_transform(transform) + + return ptransform_list + + def save_transforms_in_artifact_location(self, ptransform_list): + """ + Save the ptransform references to json file. + """ + _transform_attribute_manager.save_attributes( + ptransform_list=ptransform_list, + artifact_location=self._parent_artifact_location, + options=self.pipeline_options) + + @staticmethod + def load_transforms_from_artifact_location(artifact_location): + return _transform_attribute_manager.load_attributes(artifact_location) + + +class _TextEmbeddingHandler(ModelHandler): + """ + A ModelHandler intended to be work on list[dict[str, str]] inputs. + + The inputs to the model handler are expected to be a list of dicts. + + For example, if the original mode is used with RunInference to take a + PCollection[E] to a PCollection[P], this ModelHandler would take a + PCollection[Dict[str, E]] to a PCollection[Dict[str, P]]. + + _TextEmbeddingHandler will accept an EmbeddingsManager instance, which + contains the details of the model to be loaded and the inference_fn to be + used. The purpose of _TextEmbeddingHandler is to generate embeddings for + text inputs using the EmbeddingsManager instance. + + If the input is not a text column, a RuntimeError will be raised. + + This is an internal class and offers no backwards compatibility guarantees. + + Args: + embeddings_manager: An EmbeddingsManager instance. 
+ """ + def __init__(self, embeddings_manager: EmbeddingsManager): + self.embedding_config = embeddings_manager + self._underlying = self.embedding_config.get_model_handler() + self.columns = self.embedding_config.get_columns_to_apply() + + def load_model(self): + model = self._underlying.load_model() + return model + + def _validate_column_data(self, batch): + if not isinstance(batch[0], (str, bytes)): + raise TypeError('Embeddings can only be generated on text columns.') + + def _validate_batch(self, batch: Sequence[Dict[str, List[str]]]): + if not batch or not isinstance(batch[0], dict): + raise TypeError( + 'Expected data to be dicts, got ' + f'{type(batch[0])} instead.') + + def _process_batch( + self, + dict_batch: Dict[str, List[Any]], + model: ModelT, + inference_args: Optional[Dict[str, Any]]) -> Dict[str, List[Any]]: + result: Dict[str, List[Any]] = collections.defaultdict(list) + for key, batch in dict_batch.items(): + if key in self.columns: + self._validate_column_data(batch) + prediction = self._underlying.run_inference( + batch, model, inference_args) + if isinstance(prediction, np.ndarray): + prediction = prediction.tolist() + result[key] = prediction # type: ignore[assignment] + else: + result[key] = prediction # type: ignore[assignment] + else: + result[key] = batch + return result + + def run_inference( + self, + batch: Sequence[Dict[str, List[str]]], + model: ModelT, + inference_args: Optional[Dict[str, Any]] = None, + ) -> List[Dict[str, Union[List[float], List[str]]]]: + """ + Runs inference on a batch of text inputs. The inputs are expected to be + a list of dicts. Each dict should have the same keys, and the shape + should be of the same size for a single key across the batch. + """ + self._validate_batch(batch) + batch_len = len(batch) + dict_batch = _convert_list_of_dicts_to_dict_of_lists(list_of_dicts=batch) + transformed_batch = self._process_batch(dict_batch, model, inference_args) + return _convert_dict_of_lists_to_lists_of_dict( + dict_of_lists=transformed_batch, batch_length=batch_len) + + def get_metrics_namespace(self) -> str: + return ( + self._underlying.get_metrics_namespace() or + 'BeamML_TextEmbeddingHandler') + + def batch_elements_kwargs(self) -> Mapping[str, Any]: + batch_sizes_map = {} + if self.embedding_config.max_batch_size: + batch_sizes_map['max_batch_size'] = self.embedding_config.max_batch_size + if self.embedding_config.min_batch_size: + batch_sizes_map['min_batch_size'] = self.embedding_config.min_batch_size + return (self._underlying.batch_elements_kwargs() or batch_sizes_map) diff --git a/sdks/python/apache_beam/ml/transforms/base_test.py b/sdks/python/apache_beam/ml/transforms/base_test.py index 2e447964541b..1f9e5a85d1c2 100644 --- a/sdks/python/apache_beam/ml/transforms/base_test.py +++ b/sdks/python/apache_beam/ml/transforms/base_test.py @@ -20,7 +20,11 @@ import tempfile import typing import unittest +from typing import Any +from typing import Dict from typing import List +from typing import Optional +from typing import Sequence import numpy as np from parameterized import param @@ -28,28 +32,30 @@ import apache_beam as beam from apache_beam.metrics.metric import MetricsFilter +from apache_beam.ml.inference.base import ModelHandler +from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.transforms import base from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to # pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports try: - from 
apache_beam.ml.transforms import base from apache_beam.ml.transforms import tft from apache_beam.ml.transforms.tft import TFTOperation except ImportError: tft = None # type: ignore -if tft is None: - raise unittest.SkipTest('tensorflow_transform is not installed') - +try: -class _FakeOperation(TFTOperation): - def __init__(self, name, *args, **kwargs): - super().__init__(*args, **kwargs) - self.name = name + class _FakeOperation(TFTOperation): + def __init__(self, name, *args, **kwargs): + super().__init__(*args, **kwargs) + self.name = name - def apply_transform(self, inputs, output_column_name, **kwargs): - return {output_column_name: inputs} + def apply_transform(self, inputs, output_column_name, **kwargs): + return {output_column_name: inputs} +except: # pylint: disable=bare-except + pass class BaseMLTransformTest(unittest.TestCase): @@ -59,6 +65,7 @@ def setUp(self) -> None: def tearDown(self): shutil.rmtree(self.artifact_location) + @unittest.skipIf(tft is None, 'tft module is not installed.') def test_ml_transform_appends_transforms_to_process_handler_correctly(self): fake_fn_1 = _FakeOperation(name='fake_fn_1', columns=['x']) transforms = [fake_fn_1] @@ -67,12 +74,11 @@ def test_ml_transform_appends_transforms_to_process_handler_correctly(self): ml_transform = ml_transform.with_transform( transform=_FakeOperation(name='fake_fn_2', columns=['x'])) - self.assertEqual(len(ml_transform._process_handler.transforms), 2) - self.assertEqual( - ml_transform._process_handler.transforms[0].name, 'fake_fn_1') - self.assertEqual( - ml_transform._process_handler.transforms[1].name, 'fake_fn_2') + self.assertEqual(len(ml_transform.transforms), 2) + self.assertEqual(ml_transform.transforms[0].name, 'fake_fn_1') + self.assertEqual(ml_transform.transforms[1].name, 'fake_fn_2') + @unittest.skipIf(tft is None, 'tft module is not installed.') def test_ml_transform_on_dict(self): transforms = [tft.ScaleTo01(columns=['x'])] data = [{'x': 1}, {'x': 2}] @@ -91,6 +97,7 @@ def test_ml_transform_on_dict(self): assert_that( actual_output, equal_to(expected_output, equals_fn=np.array_equal)) + @unittest.skipIf(tft is None, 'tft module is not installed.') def test_ml_transform_on_list_dict(self): transforms = [tft.ScaleTo01(columns=['x'])] data = [{'x': [1, 2, 3]}, {'x': [4, 5, 6]}] @@ -162,6 +169,7 @@ def test_ml_transform_on_list_dict(self): }, ), ]) + @unittest.skipIf(tft is None, 'tft module is not installed.') def test_ml_transform_dict_output_pcoll_schema( self, input_data, input_types, expected_dtype): transforms = [tft.ScaleTo01(columns=['x'])] @@ -178,6 +186,7 @@ def test_ml_transform_dict_output_pcoll_schema( if name in expected_dtype: self.assertEqual(expected_dtype[name], typ) + @unittest.skipIf(tft is None, 'tft module is not installed.') def test_ml_transform_fail_for_non_global_windows_in_produce_mode(self): transforms = [tft.ScaleTo01(columns=['x'])] with beam.Pipeline() as p: @@ -193,6 +202,7 @@ def test_ml_transform_fail_for_non_global_windows_in_produce_mode(self): write_artifact_location=self.artifact_location, )) + @unittest.skipIf(tft is None, 'tft module is not installed.') def test_ml_transform_on_multiple_columns_single_transform(self): transforms = [tft.ScaleTo01(columns=['x', 'y'])] data = [{'x': [1, 2, 3], 'y': [1.0, 10.0, 20.0]}] @@ -217,6 +227,7 @@ def test_ml_transform_on_multiple_columns_single_transform(self): equal_to(expected_output_y, equals_fn=np.array_equal), label='y') + @unittest.skipIf(tft is None, 'tft module is not installed.') def 
test_ml_transforms_on_multiple_columns_multiple_transforms(self): transforms = [ tft.ScaleTo01(columns=['x']), @@ -245,6 +256,7 @@ def test_ml_transforms_on_multiple_columns_multiple_transforms(self): equal_to(expected_output_y, equals_fn=np.array_equal), label='actual_output_y') + @unittest.skipIf(tft is None, 'tft module is not installed.') def test_mltransform_with_counter(self): transforms = [ tft.ComputeAndApplyVocabulary(columns=['y']), @@ -269,6 +281,149 @@ def test_mltransform_with_counter(self): self.assertEqual( result.metrics().query(mltransform_counter)['counters'][0].result, 1) + def test_non_ptransfrom_provider_class_to_mltransform(self): + class Add: + def __call__(self, x): + return x + 1 + + with self.assertRaisesRegex( + TypeError, 'transform must be a subclass of PTransformProvider'): + with beam.Pipeline() as p: + _ = ( + p + | beam.Create([{ + 'x': 1 + }]) + | base.MLTransform( + write_artifact_location=self.artifact_location).with_transform( + Add())) + + +class FakeModel: + def __call__(self, example: List[str]) -> List[str]: + for i in range(len(example)): + example[i] = example[i][::-1] + return example + + +class FakeModelHandler(ModelHandler): + def run_inference( + self, + batch: Sequence[str], + model: Any, + inference_args: Optional[Dict[str, Any]] = None): + return model(batch) + + def load_model(self): + return FakeModel() + + +class FakeEmbeddingsManager(base.EmbeddingsManager): + def __init__(self, columns): + super().__init__(columns=columns) + + def get_model_handler(self) -> ModelHandler: + return FakeModelHandler() + + def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: + return (RunInference(model_handler=base._TextEmbeddingHandler(self))) + + +class TextEmbeddingHandlerTest(unittest.TestCase): + def setUp(self) -> None: + self.embedding_conig = FakeEmbeddingsManager(columns=['x']) + self.artifact_location = tempfile.mkdtemp() + + def tearDown(self) -> None: + shutil.rmtree(self.artifact_location) + + def test_handler_with_incompatible_datatype(self): + text_handler = base._TextEmbeddingHandler( + embeddings_manager=self.embedding_conig) + data = [ + ('x', 1), + ('x', 2), + ('x', 3), + ] + with self.assertRaises(TypeError): + text_handler.run_inference(data, None, None) + + def test_handler_with_dict_inputs(self): + data = [ + { + 'x': "Hello world" + }, + { + 'x': "Apache Beam" + }, + ] + expected_data = [{key: value[::-1] + for key, value in d.items()} for d in data] + with beam.Pipeline() as p: + result = ( + p + | beam.Create(data) + | base.MLTransform( + write_artifact_location=self.artifact_location).with_transform( + self.embedding_conig)) + assert_that( + result, + equal_to(expected_data), + ) + + def test_handler_with_batch_sizes(self): + self.embedding_conig.max_batch_size = 100 + self.embedding_conig.min_batch_size = 10 + data = [ + { + 'x': "Hello world" + }, + { + 'x': "Apache Beam" + }, + ] * 100 + expected_data = [{key: value[::-1] + for key, value in d.items()} for d in data] + with beam.Pipeline() as p: + result = ( + p + | beam.Create(data) + | base.MLTransform( + write_artifact_location=self.artifact_location).with_transform( + self.embedding_conig)) + assert_that( + result, + equal_to(expected_data), + ) + + def test_handler_on_multiple_columns(self): + self.embedding_conig.columns = ['x', 'y'] + data = [ + { + 'x': "Hello world", 'y': "Apache Beam", 'z': 'unchanged' + }, + { + 'x': "Apache Beam", 'y': "Hello world", 'z': 'unchanged' + }, + ] + self.embedding_conig.columns = ['x', 'y'] + expected_data = [{ + key: 
(value[::-1] if key in self.embedding_conig.columns else value) + for key, + value in d.items() + } for d in data] + with beam.Pipeline() as p: + result = ( + p + | beam.Create(data) + | base.MLTransform( + write_artifact_location=self.artifact_location).with_transform( + self.embedding_conig)) + assert_that( + result, + equal_to(expected_data), + ) + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/__init__.py b/sdks/python/apache_beam/ml/transforms/embeddings/__init__.py new file mode 100644 index 000000000000..bda6256b79ef --- /dev/null +++ b/sdks/python/apache_beam/ml/transforms/embeddings/__init__.py @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# TODO: Add dead letter queue for RunInference transforms. + +""" +This module contains embedding configs that can be used to generate +embeddings using MLTransform. +""" diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/sentence_transformer.py b/sdks/python/apache_beam/ml/transforms/embeddings/sentence_transformer.py new file mode 100644 index 000000000000..5b31dbca0082 --- /dev/null +++ b/sdks/python/apache_beam/ml/transforms/embeddings/sentence_transformer.py @@ -0,0 +1,128 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = ["SentenceTransformerEmbeddings"] + +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Mapping +from typing import Optional +from typing import Sequence + +import apache_beam as beam +from apache_beam.ml.inference.base import ModelHandler +from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.transforms.base import EmbeddingsManager +from apache_beam.ml.transforms.base import _TextEmbeddingHandler +from sentence_transformers import SentenceTransformer + + +# TODO: Use HuggingFaceModelHandlerTensor once the import issue is fixed. 
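For orientation, the sentence-transformers call wrapped by the handler below reduces to the following (checkpoint name borrowed from the tests; for this model, encode returns one 768-dimensional vector per input):

    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
    embeddings = model.encode(['This is a test'])
    # embeddings.shape == (1, 768)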
+# Right now, the hugging face model handler imports torch and tensorflow
+# at the same time, which adds too much weight to the container unnecessarily.
+class _SentenceTransformerModelHandler(ModelHandler):
+  """
+  Note: Intended for internal use and guarantees no backwards compatibility.
+  """
+  def __init__(
+      self,
+      model_name: str,
+      model_class: Callable,
+      load_model_args: Optional[dict] = None,
+      min_batch_size: Optional[int] = None,
+      max_batch_size: Optional[int] = None,
+      max_seq_length: Optional[int] = None,
+      large_model: bool = False,
+      **kwargs):
+    self._max_seq_length = max_seq_length
+    self._model_uri = model_name
+    self._model_class = model_class
+    self._load_model_args = load_model_args
+    self._min_batch_size = min_batch_size
+    self._max_batch_size = max_batch_size
+    self._large_model = large_model
+    self._kwargs = kwargs
+
+  def run_inference(
+      self,
+      batch: Sequence[str],
+      model: SentenceTransformer,
+      inference_args: Optional[Dict[str, Any]] = None,
+  ):
+    inference_args = inference_args or {}
+    return model.encode(batch, **inference_args)
+
+  def load_model(self):
+    model = self._model_class(self._model_uri)
+    if self._max_seq_length:
+      model.max_seq_length = self._max_seq_length
+    return model
+
+  def share_model_across_processes(self) -> bool:
+    return self._large_model
+
+  def batch_elements_kwargs(self) -> Mapping[str, Any]:
+    batch_sizes = {}
+    if self._min_batch_size:
+      batch_sizes["min_batch_size"] = self._min_batch_size
+    if self._max_batch_size:
+      batch_sizes["max_batch_size"] = self._max_batch_size
+    return batch_sizes
+
+
+class SentenceTransformerEmbeddings(EmbeddingsManager):
+  def __init__(
+      self,
+      model_name: str,
+      columns: List[str],
+      max_seq_length: Optional[int] = None,
+      **kwargs):
+    """
+    Embedding config for sentence-transformers. This config can be used with
+    MLTransform to embed text data. Models are loaded using the RunInference
+    PTransform with the help of a ModelHandler.
+    Args:
+      model_name: Name of the model to use. The model should be hosted on
+        HuggingFace Hub or compatible with sentence_transformers.
+      columns: List of columns to be embedded.
+      max_seq_length: Max sequence length to use for the model if applicable.
+      min_batch_size: The minimum batch size to be used for inference.
+      max_batch_size: The maximum batch size to be used for inference.
+      large_model: Whether to share the model across processes.
+    """
+    super().__init__(columns, **kwargs)
+    self.model_name = model_name
+    self.max_seq_length = max_seq_length
+
+  def get_model_handler(self):
+    return _SentenceTransformerModelHandler(
+        model_class=SentenceTransformer,
+        max_seq_length=self.max_seq_length,
+        model_name=self.model_name,
+        load_model_args=self.load_model_args,
+        min_batch_size=self.min_batch_size,
+        max_batch_size=self.max_batch_size,
+        large_model=self.large_model)
+
+  def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform:
+    # Wrap the model handler in a _TextEmbeddingHandler since
+    # the SentenceTransformerEmbeddings works on text input data.
+    return (RunInference(model_handler=_TextEmbeddingHandler(self)))
+
+  def requires_chaining(self):
+    return False
diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/sentence_transformer_test.py b/sdks/python/apache_beam/ml/transforms/embeddings/sentence_transformer_test.py
new file mode 100644
index 000000000000..63f401180dc2
--- /dev/null
+++ b/sdks/python/apache_beam/ml/transforms/embeddings/sentence_transformer_test.py
@@ -0,0 +1,212 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import shutil
+import tempfile
+import unittest
+
+import apache_beam as beam
+from apache_beam.ml.transforms.base import MLTransform
+
+# pylint: disable=ungrouped-imports
+try:
+  from apache_beam.ml.transforms.embeddings.sentence_transformer import SentenceTransformerEmbeddings
+except ImportError:
+  SentenceTransformerEmbeddings = None  # type: ignore
+
+# pylint: disable=ungrouped-imports
+try:
+  import tensorflow_transform as tft
+  from apache_beam.ml.transforms.tft import ScaleTo01
+except ImportError:
+  tft = None
+
+test_query = "This is a test"
+test_query_column = "feature_1"
+DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
+
+
+def get_pipeline_with_embedding_config(
+    pipeline: beam.Pipeline, embedding_config, artifact_location):
+  transformed_pcoll = (
+      pipeline
+      | "CreateData" >> beam.Create([{
+          test_query_column: test_query
+      }])
+      | "MLTransform" >> MLTransform(write_artifact_location=artifact_location).
+      with_transform(embedding_config))
+  return transformed_pcoll
+
+
+@unittest.skipIf(
+    SentenceTransformerEmbeddings is None,
+    'sentence-transformers is not installed.')
+class SentenceTransformerEmbeddingsTest(unittest.TestCase):
+  def setUp(self) -> None:
+    self.artifact_location = tempfile.mkdtemp()
+
+  def tearDown(self) -> None:
+    shutil.rmtree(self.artifact_location)
+
+  def test_sentence_transformer_embeddings(self):
+    model_name = DEFAULT_MODEL_NAME
+    embedding_config = SentenceTransformerEmbeddings(
+        model_name=model_name, columns=[test_query_column])
+    with beam.Pipeline() as pipeline:
+      result_pcoll = get_pipeline_with_embedding_config(
+          pipeline=pipeline,
+          embedding_config=embedding_config,
+          artifact_location=self.artifact_location)
+
+      def assert_element(element):
+        assert len(element[test_query_column]) == 768
+
+      _ = (result_pcoll | beam.Map(assert_element))
+
+  @unittest.skipIf(tft is None, 'Tensorflow Transform is not installed.')
+  def test_embeddings_with_scale_to_0_1(self):
+    model_name = DEFAULT_MODEL_NAME
+    embedding_config = SentenceTransformerEmbeddings(
+        model_name=model_name,
+        columns=[test_query_column],
+    )
+    with beam.Pipeline() as pipeline:
+      transformed_pcoll = (
+          pipeline
+          | "CreateData" >> beam.Create([{
+              test_query_column: test_query
+          }])
+          | "MLTransform" >> MLTransform(
+              write_artifact_location=self.artifact_location).with_transform(
+                  embedding_config).with_transform(
+                      ScaleTo01(columns=[test_query_column])))
+
+      def assert_element(element):
+        assert max(element.feature_1) == 1
+
+      _ = (transformed_pcoll | beam.Map(assert_element))
+
+  def pipeline_with_configurable_artifact_location(
+      self,
+      pipeline,
+      embedding_config=None,
+      read_artifact_location=None,
+      write_artifact_location=None):
+    if write_artifact_location:
+      return (
+          pipeline
+          | MLTransform(write_artifact_location=write_artifact_location).
+ with_transform(embedding_config)) + elif read_artifact_location: + return ( + pipeline + | MLTransform(read_artifact_location=read_artifact_location)) + else: + raise NotImplementedError + + def test_embeddings_with_read_artifact_location(self): + with beam.Pipeline() as p: + model_name = DEFAULT_MODEL_NAME + embedding_config = SentenceTransformerEmbeddings( + model_name=model_name, columns=[test_query_column]) + + with beam.Pipeline() as p: + data = ( + p + | "CreateData" >> beam.Create([{ + test_query_column: test_query + }])) + _ = self.pipeline_with_configurable_artifact_location( + pipeline=data, + embedding_config=embedding_config, + write_artifact_location=self.artifact_location) + + with beam.Pipeline() as p: + data = ( + p + | "CreateData" >> beam.Create([{ + test_query_column: test_query + }, { + test_query_column: test_query + }])) + result_pcoll = self.pipeline_with_configurable_artifact_location( + pipeline=data, read_artifact_location=self.artifact_location) + + def assert_element(element): + assert round(element, 2) == 0.13 + + _ = ( + result_pcoll + | beam.Map(lambda x: max(x[test_query_column])) + # 0.1342099905014038 + | beam.Map(assert_element)) + + def test_sentence_transformer_with_int_data_types(self): + model_name = DEFAULT_MODEL_NAME + embedding_config = SentenceTransformerEmbeddings( + model_name=model_name, columns=[test_query_column]) + with self.assertRaises(TypeError): + with beam.Pipeline() as pipeline: + _ = ( + pipeline + | "CreateData" >> beam.Create([{ + test_query_column: 1 + }]) + | "MLTransform" >> MLTransform( + write_artifact_location=self.artifact_location).with_transform( + embedding_config)) + + def test_with_gcs_artifact_location(self): + artifact_location = ('gs://apache-beam-ml/testing/sentence_transformers') + with beam.Pipeline() as p: + model_name = DEFAULT_MODEL_NAME + embedding_config = SentenceTransformerEmbeddings( + model_name=model_name, columns=[test_query_column]) + + with beam.Pipeline() as p: + data = ( + p + | "CreateData" >> beam.Create([{ + test_query_column: test_query + }])) + _ = self.pipeline_with_configurable_artifact_location( + pipeline=data, + embedding_config=embedding_config, + write_artifact_location=artifact_location) + + with beam.Pipeline() as p: + data = ( + p + | "CreateData" >> beam.Create([{ + test_query_column: test_query + }, { + test_query_column: test_query + }])) + result_pcoll = self.pipeline_with_configurable_artifact_location( + pipeline=data, read_artifact_location=artifact_location) + + def assert_element(element): + assert round(element, 2) == 0.13 + + _ = ( + result_pcoll + | beam.Map(lambda x: max(x[test_query_column])) + # 0.1342099905014038 + | beam.Map(assert_element)) + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py b/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py new file mode 100644 index 000000000000..4b01f7ec44b9 --- /dev/null +++ b/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py @@ -0,0 +1,124 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Iterable
+from typing import List
+from typing import Optional
+
+import apache_beam as beam
+import tensorflow as tf
+import tensorflow_hub as hub
+import tensorflow_text as text  # required to register TF ops. # pylint: disable=unused-import
+from apache_beam.ml.inference import utils
+from apache_beam.ml.inference.base import ModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor
+from apache_beam.ml.inference.tensorflow_inference import default_tensor_inference_fn
+from apache_beam.ml.transforms.base import EmbeddingsManager
+from apache_beam.ml.transforms.base import _TextEmbeddingHandler
+
+__all__ = ['TensorflowHubTextEmbeddings']
+
+
+class _TensorflowHubModelHandler(TFModelHandlerTensor):
+  """
+  Note: Intended for internal use only. No backwards compatibility guarantees.
+  """
+  def __init__(self, preprocessing_url: Optional[str], *args, **kwargs):
+    self.preprocessing_url = preprocessing_url
+    super().__init__(*args, **kwargs)
+
+  def load_model(self):
+    # Unable to load the models with tf.keras.models.load_model, so
+    # use hub.KerasLayer instead.
+    model = hub.KerasLayer(self._model_uri)
+    return model
+
+  def _convert_prediction_result_to_list(
+      self, predictions: Iterable[PredictionResult]):
+    result = []
+    for prediction in predictions:
+      inference = prediction.inference.numpy().tolist()
+      result.append(inference)
+    return result
+
+  def run_inference(self, batch, model, inference_args, model_id=None):
+    if not inference_args:
+      inference_args = {}
+    if not self.preprocessing_url:
+      predictions = default_tensor_inference_fn(
+          model=model,
+          batch=batch,
+          inference_args=inference_args,
+          model_id=model_id)
+      return self._convert_prediction_result_to_list(predictions)
+
+    vectorized_batch = tf.stack(batch, axis=0)
+    preprocessor_fn = hub.KerasLayer(self.preprocessing_url)
+    vectorized_batch = preprocessor_fn(vectorized_batch)
+    predictions = model(vectorized_batch)
+    # https://www.tensorflow.org/text/tutorials/classify_text_with_bert#using_the_bert_model # pylint: disable=line-too-long
+    # pooled_output -> represents the text as a whole. This is an embedding
+    # of the whole text. The shape is [batch_size, embedding_dimension]
+    # sequence_output -> represents the text as a sequence of tokens. This is
+    # an embedding of each token in the text. The shape is
+    # [batch_size, max_sequence_length, embedding_dimension]
+    # pooled_output is the embedding per the documentation, so use that.
+    embeddings = predictions['pooled_output']
+    predictions = utils._convert_to_result(batch, embeddings, model_id)
+    return self._convert_prediction_result_to_list(predictions)
+
+
+class TensorflowHubTextEmbeddings(EmbeddingsManager):
+  def __init__(
+      self,
+      columns: List[str],
+      hub_url: str,
+      preprocessing_url: Optional[str] = None,
+      **kwargs):
+    """
+    Embedding config for tensorflow hub models. This config can be used with
+    MLTransform to embed text data.
Models are loaded using the RunInference + PTransform with the help of a ModelHandler. + + Args: + columns: The columns containing the text to be embedded. + hub_url: The url of the tensorflow hub model. + preprocessing_url: The url of the preprocessing model. This is optional. + If provided, the preprocessing model will be used to preprocess the + text before feeding it to the main model. + min_batch_size: The minimum batch size to be used for inference. + max_batch_size: The maximum batch size to be used for inference. + large_model: Whether to share the model across processes. + """ + super().__init__(columns=columns, **kwargs) + self.model_uri = hub_url + self.preprocessing_url = preprocessing_url + + def get_model_handler(self) -> ModelHandler: + # override the default inference function + return _TensorflowHubModelHandler( + model_uri=self.model_uri, + preprocessing_url=self.preprocessing_url, + min_batch_size=self.min_batch_size, + max_batch_size=self.max_batch_size, + large_model=self.large_model, + ) + + def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: + return (RunInference(model_handler=_TextEmbeddingHandler(self))) diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub_test.py b/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub_test.py new file mode 100644 index 000000000000..6b918153945a --- /dev/null +++ b/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub_test.py @@ -0,0 +1,198 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
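Typical wiring of TensorflowHubTextEmbeddings into MLTransform (hub URL borrowed from the test below; artifact path illustrative):

    import apache_beam as beam
    from apache_beam.ml.transforms.base import MLTransform
    from apache_beam.ml.transforms.embeddings.tensorflow_hub import TensorflowHubTextEmbeddings

    embedding_config = TensorflowHubTextEmbeddings(
        columns=['test_query'],
        hub_url='https://tfhub.dev/google/LEALLA/LEALLA-small/1')

    with beam.Pipeline() as p:
      _ = (
          p
          | beam.Create([{'test_query': 'This is a test query'}])
          | MLTransform(write_artifact_location='/tmp/ml_artifacts').with_transform(
              embedding_config))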
+ +import shutil +import tempfile +import unittest + +import apache_beam as beam +from apache_beam.ml.transforms.base import MLTransform + +hub_url = 'https://tfhub.dev/google/LEALLA/LEALLA-small/1' +test_query_column = 'test_query' +test_query = 'This is a test query' + +# pylint: disable=ungrouped-imports +try: + import tensorflow as tf # disable=unused-import + from apache_beam.ml.transforms.embeddings.tensorflow_hub import TensorflowHubTextEmbeddings +except ImportError: + tf = None + +try: + from apache_beam.ml.transforms.tft import ScaleTo01 +except ImportError: + ScaleTo01 = None # type: ignore + + +@unittest.skipIf(tf is None, 'Tensorflow is not installed.') +class TFHubEmbeddingsTest(unittest.TestCase): + def setUp(self) -> None: + self.artifact_location = tempfile.mkdtemp() + + def tearDown(self) -> None: + shutil.rmtree(self.artifact_location) + + def test_tfhub_text_embeddings(self): + embedding_config = TensorflowHubTextEmbeddings( + hub_url=hub_url, columns=[test_query_column]) + with beam.Pipeline() as pipeline: + transformed_pcoll = ( + pipeline + | "CreateData" >> beam.Create([{ + test_query_column: test_query + }]) + | "MLTransform" >> MLTransform( + write_artifact_location=self.artifact_location).with_transform( + embedding_config)) + + def assert_element(element): + assert len(element[test_query_column]) == 128 + + _ = (transformed_pcoll | beam.Map(assert_element)) + + @unittest.skipIf(ScaleTo01 is None, 'Tensorflow Transform is not installed.') + def test_embeddings_with_scale_to_0_1(self): + embedding_config = TensorflowHubTextEmbeddings( + hub_url=hub_url, + columns=[test_query_column], + ) + with beam.Pipeline() as pipeline: + transformed_pcoll = ( + pipeline + | "CreateData" >> beam.Create([{ + test_query_column: test_query + }]) + | "MLTransform" >> MLTransform( + write_artifact_location=self.artifact_location).with_transform( + embedding_config).with_transform( + ScaleTo01(columns=[test_query_column]))) + + def assert_element(element): + assert max(element[test_query_column]) == 1 + + _ = ( + transformed_pcoll | beam.Map(lambda x: x.as_dict()) + | beam.Map(assert_element)) + + def pipeline_with_configurable_artifact_location( + self, + pipeline, + embedding_config=None, + read_artifact_location=None, + write_artifact_location=None): + if write_artifact_location: + return ( + pipeline + | MLTransform(write_artifact_location=write_artifact_location). 
+ with_transform(embedding_config)) + elif read_artifact_location: + return ( + pipeline + | MLTransform(read_artifact_location=read_artifact_location)) + else: + raise NotImplementedError + + def test_embeddings_with_read_artifact_location(self): + with beam.Pipeline() as p: + embedding_config = TensorflowHubTextEmbeddings( + hub_url=hub_url, columns=[test_query_column]) + + with beam.Pipeline() as p: + data = ( + p + | "CreateData" >> beam.Create([{ + test_query_column: test_query + }])) + _ = self.pipeline_with_configurable_artifact_location( + pipeline=data, + embedding_config=embedding_config, + write_artifact_location=self.artifact_location) + + with beam.Pipeline() as p: + data = ( + p + | "CreateData" >> beam.Create([{ + test_query_column: test_query + }, { + test_query_column: test_query + }])) + result_pcoll = self.pipeline_with_configurable_artifact_location( + pipeline=data, read_artifact_location=self.artifact_location) + + def assert_element(element): + assert round(element, 2) == 0.21 + + _ = ( + result_pcoll + | beam.Map(lambda x: max(x[test_query_column])) + # 0.14797046780586243 + | beam.Map(assert_element)) + + def test_with_int_data_types(self): + embedding_config = TensorflowHubTextEmbeddings( + hub_url=hub_url, columns=[test_query_column]) + with self.assertRaises(TypeError): + with beam.Pipeline() as pipeline: + _ = ( + pipeline + | "CreateData" >> beam.Create([{ + test_query_column: 1 + }]) + | "MLTransform" >> MLTransform( + write_artifact_location=self.artifact_location).with_transform( + embedding_config)) + + def test_with_gcs_artifact_location(self): + artifact_location = 'gs://apache-beam-ml/testing/tensorflow_hub' + with beam.Pipeline() as p: + embedding_config = TensorflowHubTextEmbeddings( + hub_url=hub_url, columns=[test_query_column]) + + with beam.Pipeline() as p: + data = ( + p + | "CreateData" >> beam.Create([{ + test_query_column: test_query + }])) + _ = self.pipeline_with_configurable_artifact_location( + pipeline=data, + embedding_config=embedding_config, + write_artifact_location=artifact_location) + + with beam.Pipeline() as p: + data = ( + p + | "CreateData" >> beam.Create([{ + test_query_column: test_query + }, { + test_query_column: test_query + }])) + result_pcoll = self.pipeline_with_configurable_artifact_location( + pipeline=data, read_artifact_location=artifact_location) + + def assert_element(element): + assert round(element, 2) == 0.21 + + _ = ( + result_pcoll + | beam.Map(lambda x: max(x[test_query_column])) + # 0.14797046780586243 + | beam.Map(assert_element)) + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py b/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py new file mode 100644 index 000000000000..e4c6745bb566 --- /dev/null +++ b/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py @@ -0,0 +1,158 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Vertex AI Python SDK is required for this module. +# Follow https://cloud.google.com/vertex-ai/docs/python-sdk/use-vertex-ai-python-sdk # pylint: disable=line-too-long +# to install Vertex AI Python SDK. + +from typing import Any +from typing import Dict +from typing import Iterable +from typing import List +from typing import Optional +from typing import Sequence + +from google.auth.credentials import Credentials + +import apache_beam as beam +import vertexai +from apache_beam.ml.inference.base import ModelHandler +from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.transforms.base import EmbeddingsManager +from apache_beam.ml.transforms.base import _TextEmbeddingHandler +from vertexai.language_models import TextEmbeddingInput +from vertexai.language_models import TextEmbeddingModel + +__all__ = ["VertexAITextEmbeddings"] + +TASK_TYPE = "RETRIEVAL_DOCUMENT" +TASK_TYPE_INPUTS = [ + "RETRIEVAL_DOCUMENT", + "RETRIEVAL_QUERY", + "SEMANTIC_SIMILARITY", + "CLASSIFICATION", + "CLUSTERING" +] + + +class _VertexAITextEmbeddingHandler(ModelHandler): + """ + Note: Intended for internal use and guarantees no backwards compatibility. + """ + def __init__( + self, + model_name: str, + title: Optional[str] = None, + task_type: str = TASK_TYPE, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[Credentials] = None, + ): + vertexai.init(project=project, location=location, credentials=credentials) + self.model_name = model_name + if task_type not in TASK_TYPE_INPUTS: + raise ValueError( + f"task_type must be one of {TASK_TYPE_INPUTS}, got {task_type}") + self.task_type = task_type + self.title = title + + def run_inference( + self, + batch: Sequence[str], + model: Any, + inference_args: Optional[Dict[str, Any]] = None, + ) -> Iterable: + embeddings = [] + batch_size = 5 # Vertex AI limits requests to 5 at a time. + for i in range(0, len(batch), batch_size): + text_batch = batch[i:i + batch_size] + text_batch = [ + TextEmbeddingInput( + text=text, title=self.title, task_type=self.task_type) + for text in text_batch + ] + embeddings_batch = model.get_embeddings(text_batch) + embeddings.extend([el.values for el in embeddings_batch]) + return embeddings + + def load_model(self): + model = TextEmbeddingModel.from_pretrained(self.model_name) + return model + + +class VertexAITextEmbeddings(EmbeddingsManager): + def __init__( + self, + model_name: str, + columns: List[str], + title: Optional[str] = None, + task_type: str = TASK_TYPE, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[Credentials] = None, + **kwargs): + """ + Embedding Config for Vertex AI Text Embedding models following + https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings # pylint: disable=line-too-long + Text Embeddings are generated for a batch of text using the Vertex AI SDK. + Embeddings are returned in a list for each text in the batch. 
Look at + https://cloud.google.com/vertex-ai/docs/generative-ai/learn/model-versioning#stable-versions-available.md # pylint: disable=line-too-long + for more information on model versions and lifecycle. + + Args: + model_name: The name of the Vertex AI Text Embedding model. + columns: The columns containing the text to be embedded. + task_type: The name of the downstream task the embeddings will be used for. + Valid values: + RETRIEVAL_QUERY + Specifies the given text is a query in a search/retrieval setting. + RETRIEVAL_DOCUMENT + Specifies the given text is a document from the corpus being searched. + SEMANTIC_SIMILARITY + Specifies the given text will be used for STS. + CLASSIFICATION + Specifies that the given text will be classified. + CLUSTERING + Specifies that the embeddings will be used for clustering. + title: Optional identifier of the text content. + project: The default GCP project to make Vertex API calls. + location: The default location to use when making API calls. + credentials: The default custom + credentials to use when making API calls. If not provided credentials + will be ascertained from the environment. + + """ + self.model_name = model_name + self.project = project + self.location = location + self.credentials = credentials + self.title = title + self.task_type = task_type + super().__init__(columns=columns, **kwargs) + + def get_model_handler(self) -> ModelHandler: + return _VertexAITextEmbeddingHandler( + model_name=self.model_name, + project=self.project, + location=self.location, + credentials=self.credentials, + title=self.title, + task_type=self.task_type, + ) + + def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: + return (RunInference(model_handler=_TextEmbeddingHandler(self))) diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai_test.py b/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai_test.py new file mode 100644 index 000000000000..7124aab9cbf2 --- /dev/null +++ b/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai_test.py @@ -0,0 +1,197 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
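The request chunking in _VertexAITextEmbeddingHandler.run_inference above, shown in isolation as a plain-Python sketch:

    batch = ['text-%d' % i for i in range(12)]
    batch_size = 5  # Vertex AI caps each get_embeddings() call at 5 texts.
    chunks = [batch[i:i + batch_size] for i in range(0, len(batch), batch_size)]
    # Chunk sizes are [5, 5, 2]; one model.get_embeddings() call is made per
    # chunk, and the per-text .values vectors are concatenated back in order.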
+ +import shutil +import tempfile +import unittest + +import apache_beam as beam +from apache_beam.ml.transforms.base import MLTransform + +try: + from apache_beam.ml.transforms.embeddings.vertex_ai import VertexAITextEmbeddings +except ImportError: + VertexAITextEmbeddings = None # type: ignore + +# pylint: disable=ungrouped-imports +try: + import tensorflow_transform as tft + from apache_beam.ml.transforms.tft import ScaleTo01 +except ImportError: + tft = None + +test_query = "This is a test" +test_query_column = "feature_1" +model_name: str = "textembedding-gecko@002" + + +@unittest.skipIf( + VertexAITextEmbeddings is None, 'Vertex AI Python SDK is not installed.') +class VertexAIEmbeddingsTest(unittest.TestCase): + def setUp(self) -> None: + self.artifact_location = tempfile.mkdtemp() + + def tearDown(self) -> None: + shutil.rmtree(self.artifact_location) + + def test_vertex_ai_text_embeddings(self): + embedding_config = VertexAITextEmbeddings( + model_name=model_name, columns=[test_query_column]) + with beam.Pipeline() as pipeline: + transformed_pcoll = ( + pipeline + | "CreateData" >> beam.Create([{ + test_query_column: test_query + }]) + | "MLTransform" >> MLTransform( + write_artifact_location=self.artifact_location).with_transform( + embedding_config)) + + def assert_element(element): + assert len(element[test_query_column]) == 768 + + _ = (transformed_pcoll | beam.Map(assert_element)) + + @unittest.skipIf(tft is None, 'Tensorflow Transform is not installed.') + def test_embeddings_with_scale_to_0_1(self): + embedding_config = VertexAITextEmbeddings( + model_name=model_name, + columns=[test_query_column], + ) + with beam.Pipeline() as pipeline: + transformed_pcoll = ( + pipeline + | "CreateData" >> beam.Create([{ + test_query_column: test_query + }]) + | "MLTransform" >> MLTransform( + write_artifact_location=self.artifact_location).with_transform( + embedding_config).with_transform( + ScaleTo01(columns=[test_query_column]))) + + def assert_element(element): + assert max(element.feature_1) == 1 + + _ = (transformed_pcoll | beam.Map(assert_element)) + + def pipeline_with_configurable_artifact_location( + self, + pipeline, + embedding_config=None, + read_artifact_location=None, + write_artifact_location=None): + if write_artifact_location: + return ( + pipeline + | MLTransform(write_artifact_location=write_artifact_location). 
+ with_transform(embedding_config)) + elif read_artifact_location: + return ( + pipeline + | MLTransform(read_artifact_location=read_artifact_location)) + else: + raise NotImplementedError + + def test_embeddings_with_read_artifact_location(self): + embedding_config = VertexAITextEmbeddings( + model_name=model_name, columns=[test_query_column]) + + with beam.Pipeline() as p: + data = ( + p + | "CreateData" >> beam.Create([{ + test_query_column: test_query + }])) + _ = self.pipeline_with_configurable_artifact_location( + pipeline=data, + embedding_config=embedding_config, + write_artifact_location=self.artifact_location) + + with beam.Pipeline() as p: + data = ( + p + | "CreateData" >> beam.Create([{ + test_query_column: test_query + }, { + test_query_column: test_query + }])) + result_pcoll = self.pipeline_with_configurable_artifact_location( + pipeline=data, read_artifact_location=self.artifact_location) + + def assert_element(element): + assert round(element, 2) == 0.15 + + _ = ( + result_pcoll + | beam.Map(lambda x: max(x[test_query_column])) + # 0.14797046780586243 + | beam.Map(assert_element)) + + def test_with_int_data_types(self): + embedding_config = VertexAITextEmbeddings( + model_name=model_name, columns=[test_query_column]) + with self.assertRaises(TypeError): + with beam.Pipeline() as pipeline: + _ = ( + pipeline + | "CreateData" >> beam.Create([{ + test_query_column: 1 + }]) + | "MLTransform" >> MLTransform( + write_artifact_location=self.artifact_location).with_transform( + embedding_config)) + + def test_with_gcs_artifact_location(self): + artifact_location = 'gs://apache-beam-ml/testing/vertex_ai' + embedding_config = VertexAITextEmbeddings( + model_name=model_name, columns=[test_query_column]) + + with beam.Pipeline() as p: + data = ( + p + | "CreateData" >> beam.Create([{ + test_query_column: test_query + }])) + _ = self.pipeline_with_configurable_artifact_location( + pipeline=data, + embedding_config=embedding_config, + write_artifact_location=artifact_location) + + with beam.Pipeline() as p: + data = ( + p + | "CreateData" >> beam.Create([{ + test_query_column: test_query + }, { + test_query_column: test_query + }])) + result_pcoll = self.pipeline_with_configurable_artifact_location( + pipeline=data, read_artifact_location=artifact_location) + + def assert_element(element): + assert round(element, 2) == 0.15 + + _ = ( + result_pcoll + | beam.Map(lambda x: max(x[test_query_column])) + # 0.14797046780586243 + | beam.Map(assert_element)) + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/ml/transforms/handlers.py b/sdks/python/apache_beam/ml/transforms/handlers.py index 8695d5146efa..1a673c51df26 100644 --- a/sdks/python/apache_beam/ml/transforms/handlers.py +++ b/sdks/python/apache_beam/ml/transforms/handlers.py @@ -15,6 +15,7 @@ # limitations under the License. # # pytype: skip-file +# pylint: skip-file import collections import hashlib @@ -217,6 +218,9 @@ def __init__( def append_transform(self, transform): self.transforms.append(transform) + def get_transforms(self): + return self.transforms + def _map_column_names_to_types(self, row_type): """ Return a dictionary of column names and types.
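The next handlers.py hunks rename TFTProcessHandler.process_data to expand, which makes the handler a true PTransform that is applied with the | operator. A minimal sketch of the resulting call pattern, assuming a local artifact location and the ScaleTo01 transform the updated tests use:

import tempfile

import apache_beam as beam
import numpy as np
from apache_beam.ml.transforms import handlers
from apache_beam.ml.transforms import tft

artifact_location = tempfile.mkdtemp()
with beam.Pipeline() as pipeline:
  raw_data = pipeline | beam.Create([{'x': np.array([1, 6])}])
  # TFTProcessHandler is now a PTransform: it composes with | via expand(),
  # replacing the old process_handler.process_data(raw_data) call style.
  process_handler = handlers.TFTProcessHandler(
      transforms=[tft.ScaleTo01(columns=['x'])],
      artifact_location=artifact_location)
  transformed = raw_data | process_handler
  _ = transformed | beam.Map(lambda row: row.x)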
@@ -387,7 +392,7 @@ def _get_transformed_data_schema( transformed_types[name] = typing.Sequence[bytes] # type: ignore[assignment] return transformed_types - def process_data( + def expand( self, raw_data: beam.PCollection[tft_process_handler_input_type] ) -> beam.PCollection[tft_process_handler_output_type]: """ @@ -512,7 +517,7 @@ def process_data( # The schema only contains the columns that are transformed. transformed_dataset = ( - transformed_dataset | "ConvertToRowType" >> + transformed_dataset + | "ConvertToRowType" >> beam.Map(lambda x: beam.Row(**x)).with_output_types(row_type)) - return transformed_dataset diff --git a/sdks/python/apache_beam/ml/transforms/handlers_test.py b/sdks/python/apache_beam/ml/transforms/handlers_test.py index 327c8c76c0e9..d39a1d775f3f 100644 --- a/sdks/python/apache_beam/ml/transforms/handlers_test.py +++ b/sdks/python/apache_beam/ml/transforms/handlers_test.py @@ -298,7 +298,7 @@ def test_tft_process_handler_verify_artifacts(self): transforms=[tft.ScaleTo01(columns=['x'])], artifact_location=self.artifact_location, ) - _ = process_handler.process_data(raw_data) + _ = raw_data | process_handler self.assertTrue( os.path.exists( @@ -315,7 +315,7 @@ def test_tft_process_handler_verify_artifacts(self): raw_data = (p | beam.Create([{'x': np.array([2, 5])}])) process_handler = handlers.TFTProcessHandler( artifact_location=self.artifact_location, artifact_mode='consume') - transformed_data = process_handler.process_data(raw_data) + transformed_data = raw_data | process_handler transformed_data |= beam.Map(lambda x: x.x) # the previous min is 1 and max is 6.
So this should scale by (1, 6) @@ -494,7 +494,7 @@ def test_tft_process_handler_unused_column(self): transforms=[scale_to_0_1_fn], artifact_location=self.artifact_location, ) - transformed_pcoll = process_handler.process_data(raw_data) + transformed_pcoll = raw_data | process_handler transformed_pcoll_x = transformed_pcoll | beam.Map(lambda x: x.x) transformed_pcoll_y = transformed_pcoll | beam.Map(lambda x: x.y) assert_that( @@ -520,7 +520,7 @@ def test_consume_mode_with_extra_columns_in_the_input(self): transforms=[tft.ScaleTo01(columns=['x'])], artifact_location=self.artifact_location, ) - _ = process_handler.process_data(raw_data) + _ = raw_data | process_handler test_data = [{ 'x': np.array([2, 5]), 'y': np.array([1, 2]), 'z': 'fake_string' @@ -548,7 +548,7 @@ def test_consume_mode_with_extra_columns_in_the_input(self): raw_data = (p | beam.Create(test_data)) process_handler = handlers.TFTProcessHandler( artifact_location=self.artifact_location, artifact_mode='consume') - transformed_data = process_handler.process_data(raw_data) + transformed_data = raw_data | process_handler transformed_data_x = transformed_data | beam.Map(lambda x: x.x) transformed_data_y = transformed_data | beam.Map(lambda x: x.y) diff --git a/sdks/python/apache_beam/ml/transforms/tft.py b/sdks/python/apache_beam/ml/transforms/tft.py index c7b8ff015324..8705b79aa309 100644 --- a/sdks/python/apache_beam/ml/transforms/tft.py +++ b/sdks/python/apache_beam/ml/transforms/tft.py @@ -42,6 +42,7 @@ from typing import Tuple from typing import Union +import apache_beam as beam import tensorflow as tf import tensorflow_transform as tft from apache_beam.ml.transforms.base import BaseOperation @@ -95,6 +96,27 @@ def __init__(self, columns: List[str]) -> None: "Columns are not specified. Please specify the column for the " " op %s" % self.__class__.__name__) + def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: + from apache_beam.ml.transforms.handlers import TFTProcessHandler + params = {} + artifact_location = kwargs.get('artifact_location') + if not artifact_location: + raise RuntimeError( + "artifact_location is not specified. 
Please specify the " + "artifact_location for the op %s" % self.__class__.__name__) + + transforms = kwargs.get('transforms') + if transforms: + params['transforms'] = transforms + + artifact_mode = kwargs.get('artifact_mode') + if artifact_mode: + params['artifact_mode'] = artifact_mode + return TFTProcessHandler(artifact_location=artifact_location, **params) + + def requires_chaining(self): + return True + @tf.function def _split_string_with_delimiter(self, data, delimiter): """ diff --git a/sdks/python/apache_beam/ml/transforms/tft_test.py b/sdks/python/apache_beam/ml/transforms/tft_test.py index 38ded6a809af..9f15db45bd28 100644 --- a/sdks/python/apache_beam/ml/transforms/tft_test.py +++ b/sdks/python/apache_beam/ml/transforms/tft_test.py @@ -711,8 +711,13 @@ def test_count_per_key_on_list(self): ])) def validate_count_per_key(key_vocab_filename): + files = os.listdir(self.artifact_location) + files.remove(base._ATTRIBUTE_FILE_NAME) key_vocab_location = os.path.join( - self.artifact_location, 'transform_fn/assets', key_vocab_filename) + self.artifact_location, + files[0], + 'transform_fn/assets', + key_vocab_filename) with open(key_vocab_location, 'r') as f: key_vocab_list = [line.strip() for line in f] return key_vocab_list diff --git a/sdks/python/apache_beam/ml/transforms/utils.py b/sdks/python/apache_beam/ml/transforms/utils.py index 19bb02c5ae1b..b66cb4162ce2 100644 --- a/sdks/python/apache_beam/ml/transforms/utils.py +++ b/sdks/python/apache_beam/ml/transforms/utils.py @@ -17,9 +17,11 @@ __all__ = ['ArtifactsFetcher'] +import os import typing import tensorflow_transform as tft +from apache_beam.ml.transforms import base class ArtifactsFetcher(): @@ -28,8 +30,13 @@ class ArtifactsFetcher(): to the TFTProcessHandlers in MLTransform. """ def __init__(self, artifact_location): - self.artifact_location = artifact_location - self.transform_output = tft.TFTransformOutput(self.artifact_location) + files = os.listdir(artifact_location) + files.remove(base._ATTRIBUTE_FILE_NAME) + if len(files) > 1: + raise NotImplementedError( + 'Multiple files in artifact location not supported yet.') + self._artifact_location = os.path.join(artifact_location, files[0]) + self.transform_output = tft.TFTransformOutput(self._artifact_location) def get_vocab_list( self, diff --git a/sdks/python/container/py310/base_image_requirements.txt b/sdks/python/container/py310/base_image_requirements.txt index a9f94104374e..fc1ce3f28eea 100644 --- a/sdks/python/container/py310/base_image_requirements.txt +++ b/sdks/python/container/py310/base_image_requirements.txt @@ -82,6 +82,7 @@ idna==3.4 iniconfig==2.0.0 joblib==1.3.2 Js2Py==0.74 +jsonpickle==3.0.2 jsonschema==4.19.1 jsonschema-specifications==2023.7.1 mmh3==4.0.1 diff --git a/sdks/python/container/py311/base_image_requirements.txt b/sdks/python/container/py311/base_image_requirements.txt index 865b856683a4..7b55936530a0 100644 --- a/sdks/python/container/py311/base_image_requirements.txt +++ b/sdks/python/container/py311/base_image_requirements.txt @@ -79,6 +79,7 @@ idna==3.4 iniconfig==2.0.0 joblib==1.3.2 Js2Py==0.74 +jsonpickle==3.0.2 jsonschema==4.19.1 jsonschema-specifications==2023.7.1 mmh3==4.0.1 diff --git a/sdks/python/container/py38/base_image_requirements.txt b/sdks/python/container/py38/base_image_requirements.txt index 5dffff5f80d9..fb8928496716 100644 --- a/sdks/python/container/py38/base_image_requirements.txt +++ b/sdks/python/container/py38/base_image_requirements.txt @@ -85,6 +85,7 @@ importlib-resources==6.1.0 iniconfig==2.0.0 joblib==1.3.2 
Js2Py==0.74 +jsonpickle==3.0.2 jsonschema==4.19.1 jsonschema-specifications==2023.7.1 mmh3==4.0.1 diff --git a/sdks/python/container/py39/base_image_requirements.txt b/sdks/python/container/py39/base_image_requirements.txt index 1b8ad7a2e748..c0dcd6baf6a3 100644 --- a/sdks/python/container/py39/base_image_requirements.txt +++ b/sdks/python/container/py39/base_image_requirements.txt @@ -83,6 +83,7 @@ importlib-metadata==6.8.0 iniconfig==2.0.0 joblib==1.3.2 Js2Py==0.74 +jsonpickle==3.0.2 jsonschema==4.19.1 jsonschema-specifications==2023.7.1 mmh3==4.0.1 diff --git a/sdks/python/scripts/generate_pydoc.sh b/sdks/python/scripts/generate_pydoc.sh index 06ad06320fcf..8d5b43167dd1 100755 --- a/sdks/python/scripts/generate_pydoc.sh +++ b/sdks/python/scripts/generate_pydoc.sh @@ -133,7 +133,9 @@ autodoc_inherit_docstrings = False autodoc_member_order = 'bysource' autodoc_mock_imports = ["tensorrt", "cuda", "torch", "onnxruntime", "onnx", "tensorflow", "tensorflow_hub", - "tensorflow_transform", "tensorflow_metadata", "transformers"] + "tensorflow_transform", "tensorflow_metadata", "transformers", "tensorflow_text", + "sentence_transformers", + ] # Allow a special section for documenting DataFrame API napoleon_custom_sections = ['Differences from pandas'] diff --git a/sdks/python/setup.py b/sdks/python/setup.py index 1785cd75df80..6c99dad55504 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -286,6 +286,7 @@ def get_portability_package_data(): 'httplib2>=0.8,<0.23.0', 'js2py>=0.74,<1', 'jsonschema>=4.0.0,<5.0.0', + 'jsonpickle>=3.0.0,<4.0.0', # numpy can have breaking changes in minor versions. # Use a strict upper bound. 'numpy>=1.14.3,<1.25.0', # Update pyproject.toml as well. diff --git a/sdks/python/test-suites/tox/py38/build.gradle b/sdks/python/test-suites/tox/py38/build.gradle index b1ed5f88c7c9..c4fd300ca943 100644 --- a/sdks/python/test-suites/tox/py38/build.gradle +++ b/sdks/python/test-suites/tox/py38/build.gradle @@ -141,6 +141,18 @@ toxTask "testPy38transformers-430", "py38-transformers-430", "${posargs}" test.dependsOn "testPy38transformers-430" preCommitPyCoverage.dependsOn "testPy38transformers-430" +toxTask "testPy38sentenceTransformers-222", "py38-sentence-transformers-222", "${posargs}" +test.dependsOn "testPy38sentenceTransformers-222" +preCommitPyCoverage.dependsOn "testPy38sentenceTransformers-222" + +toxTask "testPy38tensorflowHub-014", "py38-tfhub-014", "${posargs}" +test.dependsOn "testPy38tensorflowHub-014" +preCommitPyCoverage.dependsOn "testPy38tensorflowHub-014" + +toxTask "testPy38tensorflowHub-015", "py38-tfhub-015", "${posargs}" +test.dependsOn "testPy38tensorflowHub-015" +preCommitPyCoverage.dependsOn "testPy38tensorflowHub-015" + toxTask "whitespacelint", "whitespacelint", "${posargs}" task archiveFilesToLint(type: Zip) { diff --git a/sdks/python/tox.ini b/sdks/python/tox.ini index e4cf09cacba4..1cea858e8bbc 100644 --- a/sdks/python/tox.ini +++ b/sdks/python/tox.ini @@ -423,3 +423,27 @@ commands = # Run all Vertex AI unit tests # Allow exit code 5 (no tests run) so that we can run this command safely on arbitrary subdirectories. 
/bin/sh -c 'pytest -o junit_suite_name={envname} --junitxml=pytest_{envname}.xml -n 6 -m uses_vertex_ai {posargs}; ret=$?; [ $ret = 5 ] && exit 0 || exit $ret' + + +[testenv:py{38,39,310,311}-sentence-transformers-222] +deps = + sentence-transformers==2.2.2 +extras = test,gcp +commands = + # Log sentence-transformers and its dependencies version for debugging + /bin/sh -c "pip freeze | grep -E sentence-transformers" + # Allow exit code 5 (no tests run) so that we can run this command safely on arbitrary subdirectories. + bash {toxinidir}/scripts/run_pytest.sh {envname} 'apache_beam/ml/transforms/embeddings' + +[testenv:py{38,39,310,311}-tfhub-{014,015}] +deps = + 014: tensorflow-hub>=0.14.0,<0.15.0 + 015: tensorflow-hub>=0.15.0,<0.16.0 + tensorflow-text + +extras = test,gcp +commands = + # Log tensorflow-hub and its dependencies version for debugging + /bin/sh -c "pip freeze | grep -E tensorflow" + # Allow exit code 5 (no tests run) so that we can run this command safely on arbitrary subdirectories. + bash {toxinidir}/scripts/run_pytest.sh {envname} 'apache_beam/ml/transforms/embeddings'
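Finally, a sketch of reading artifacts back with the updated ArtifactsFetcher. It now resolves the single generated subdirectory inside the artifact location (skipping attributes.json) before handing the path to TFTransformOutput; the location below and the presence of a vocabulary artifact are illustrative assumptions.

from apache_beam.ml.transforms.utils import ArtifactsFetcher

# The fetcher resolves the single artifact subdirectory automatically and
# raises NotImplementedError if more than one directory is present besides
# attributes.json.
fetcher = ArtifactsFetcher(artifact_location='/tmp/ml_transform_artifacts')
vocab_list = fetcher.get_vocab_list()  # assumes the default vocab filename

The new tox environments added above can also be run locally in the usual way, e.g. tox -e py38-sentence-transformers-222 from sdks/python, mirroring the Gradle tasks registered in build.gradle.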