Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TESTPR] #26307

Closed
wants to merge 13 commits into from
445 changes: 401 additions & 44 deletions sdks/python/apache_beam/ml/transforms/base.py

Large diffs are not rendered by default.

185 changes: 170 additions & 15 deletions sdks/python/apache_beam/ml/transforms/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,36 +20,42 @@
import tempfile
import typing
import unittest
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence

import numpy as np
from parameterized import param
from parameterized import parameterized

import apache_beam as beam
from apache_beam.metrics.metric import MetricsFilter
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.transforms import base
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to

# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
try:
from apache_beam.ml.transforms import base
from apache_beam.ml.transforms import tft
from apache_beam.ml.transforms.tft import TFTOperation
except ImportError:
tft = None # type: ignore

if tft is None:
raise unittest.SkipTest('tensorflow_transform is not installed')

try:

class _FakeOperation(TFTOperation):
def __init__(self, name, *args, **kwargs):
super().__init__(*args, **kwargs)
self.name = name
class _FakeOperation(TFTOperation):
def __init__(self, name, *args, **kwargs):
super().__init__(*args, **kwargs)
self.name = name

def apply_transform(self, inputs, output_column_name, **kwargs):
return {output_column_name: inputs}
def apply_transform(self, inputs, output_column_name, **kwargs):
return {output_column_name: inputs}
except: # pylint: disable=bare-except
pass


class BaseMLTransformTest(unittest.TestCase):
Expand All @@ -59,6 +65,7 @@ def setUp(self) -> None:
def tearDown(self):
shutil.rmtree(self.artifact_location)

@unittest.skipIf(tft is None, 'tft module is not installed.')
def test_ml_transform_appends_transforms_to_process_handler_correctly(self):
fake_fn_1 = _FakeOperation(name='fake_fn_1', columns=['x'])
transforms = [fake_fn_1]
Expand All @@ -67,12 +74,11 @@ def test_ml_transform_appends_transforms_to_process_handler_correctly(self):
ml_transform = ml_transform.with_transform(
transform=_FakeOperation(name='fake_fn_2', columns=['x']))

self.assertEqual(len(ml_transform._process_handler.transforms), 2)
self.assertEqual(
ml_transform._process_handler.transforms[0].name, 'fake_fn_1')
self.assertEqual(
ml_transform._process_handler.transforms[1].name, 'fake_fn_2')
self.assertEqual(len(ml_transform.transforms), 2)
self.assertEqual(ml_transform.transforms[0].name, 'fake_fn_1')
self.assertEqual(ml_transform.transforms[1].name, 'fake_fn_2')

@unittest.skipIf(tft is None, 'tft module is not installed.')
def test_ml_transform_on_dict(self):
transforms = [tft.ScaleTo01(columns=['x'])]
data = [{'x': 1}, {'x': 2}]
Expand All @@ -91,6 +97,7 @@ def test_ml_transform_on_dict(self):
assert_that(
actual_output, equal_to(expected_output, equals_fn=np.array_equal))

@unittest.skipIf(tft is None, 'tft module is not installed.')
def test_ml_transform_on_list_dict(self):
transforms = [tft.ScaleTo01(columns=['x'])]
data = [{'x': [1, 2, 3]}, {'x': [4, 5, 6]}]
Expand Down Expand Up @@ -162,6 +169,7 @@ def test_ml_transform_on_list_dict(self):
},
),
])
@unittest.skipIf(tft is None, 'tft module is not installed.')
def test_ml_transform_dict_output_pcoll_schema(
self, input_data, input_types, expected_dtype):
transforms = [tft.ScaleTo01(columns=['x'])]
Expand All @@ -178,6 +186,7 @@ def test_ml_transform_dict_output_pcoll_schema(
if name in expected_dtype:
self.assertEqual(expected_dtype[name], typ)

@unittest.skipIf(tft is None, 'tft module is not installed.')
def test_ml_transform_fail_for_non_global_windows_in_produce_mode(self):
transforms = [tft.ScaleTo01(columns=['x'])]
with beam.Pipeline() as p:
Expand All @@ -193,6 +202,7 @@ def test_ml_transform_fail_for_non_global_windows_in_produce_mode(self):
write_artifact_location=self.artifact_location,
))

@unittest.skipIf(tft is None, 'tft module is not installed.')
def test_ml_transform_on_multiple_columns_single_transform(self):
transforms = [tft.ScaleTo01(columns=['x', 'y'])]
data = [{'x': [1, 2, 3], 'y': [1.0, 10.0, 20.0]}]
Expand All @@ -217,6 +227,7 @@ def test_ml_transform_on_multiple_columns_single_transform(self):
equal_to(expected_output_y, equals_fn=np.array_equal),
label='y')

@unittest.skipIf(tft is None, 'tft module is not installed.')
def test_ml_transforms_on_multiple_columns_multiple_transforms(self):
transforms = [
tft.ScaleTo01(columns=['x']),
Expand Down Expand Up @@ -245,6 +256,7 @@ def test_ml_transforms_on_multiple_columns_multiple_transforms(self):
equal_to(expected_output_y, equals_fn=np.array_equal),
label='actual_output_y')

@unittest.skipIf(tft is None, 'tft module is not installed.')
def test_mltransform_with_counter(self):
transforms = [
tft.ComputeAndApplyVocabulary(columns=['y']),
Expand All @@ -269,6 +281,149 @@ def test_mltransform_with_counter(self):
self.assertEqual(
result.metrics().query(mltransform_counter)['counters'][0].result, 1)

def test_non_ptransfrom_provider_class_to_mltransform(self):
class Add:
def __call__(self, x):
return x + 1

with self.assertRaisesRegex(
TypeError, 'transform must be a subclass of PTransformProvider'):
with beam.Pipeline() as p:
_ = (
p
| beam.Create([{
'x': 1
}])
| base.MLTransform(
write_artifact_location=self.artifact_location).with_transform(
Add()))


class FakeModel:
def __call__(self, example: List[str]) -> List[str]:
for i in range(len(example)):
example[i] = example[i][::-1]
return example


class FakeModelHandler(ModelHandler):
def run_inference(
self,
batch: Sequence[str],
model: Any,
inference_args: Optional[Dict[str, Any]] = None):
return model(batch)

def load_model(self):
return FakeModel()


class FakeEmbeddingsManager(base.EmbeddingsManager):
def __init__(self, columns):
super().__init__(columns=columns)

def get_model_handler(self) -> ModelHandler:
return FakeModelHandler()

def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform:
return (RunInference(model_handler=base._TextEmbeddingHandler(self)))


class TextEmbeddingHandlerTest(unittest.TestCase):
def setUp(self) -> None:
self.embedding_conig = FakeEmbeddingsManager(columns=['x'])
self.artifact_location = tempfile.mkdtemp()

def tearDown(self) -> None:
shutil.rmtree(self.artifact_location)

def test_handler_with_incompatible_datatype(self):
text_handler = base._TextEmbeddingHandler(
embeddings_manager=self.embedding_conig)
data = [
('x', 1),
('x', 2),
('x', 3),
]
with self.assertRaises(TypeError):
text_handler.run_inference(data, None, None)

def test_handler_with_dict_inputs(self):
data = [
{
'x': "Hello world"
},
{
'x': "Apache Beam"
},
]
expected_data = [{key: value[::-1]
for key, value in d.items()} for d in data]
with beam.Pipeline() as p:
result = (
p
| beam.Create(data)
| base.MLTransform(
write_artifact_location=self.artifact_location).with_transform(
self.embedding_conig))
assert_that(
result,
equal_to(expected_data),
)

def test_handler_with_batch_sizes(self):
self.embedding_conig.max_batch_size = 100
self.embedding_conig.min_batch_size = 10
data = [
{
'x': "Hello world"
},
{
'x': "Apache Beam"
},
] * 100
expected_data = [{key: value[::-1]
for key, value in d.items()} for d in data]
with beam.Pipeline() as p:
result = (
p
| beam.Create(data)
| base.MLTransform(
write_artifact_location=self.artifact_location).with_transform(
self.embedding_conig))
assert_that(
result,
equal_to(expected_data),
)

def test_handler_on_multiple_columns(self):
self.embedding_conig.columns = ['x', 'y']
data = [
{
'x': "Hello world", 'y': "Apache Beam", 'z': 'unchanged'
},
{
'x': "Apache Beam", 'y': "Hello world", 'z': 'unchanged'
},
]
self.embedding_conig.columns = ['x', 'y']
expected_data = [{
key: (value[::-1] if key in self.embedding_conig.columns else value)
for key,
value in d.items()
} for d in data]
with beam.Pipeline() as p:
result = (
p
| beam.Create(data)
| base.MLTransform(
write_artifact_location=self.artifact_location).with_transform(
self.embedding_conig))
assert_that(
result,
equal_to(expected_data),
)


if __name__ == '__main__':
unittest.main()
21 changes: 21 additions & 0 deletions sdks/python/apache_beam/ml/transforms/embeddings/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO: Add dead letter queue for RunInference transforms.

"""
This module contains embedding configs that can be used to generate
embeddings using MLTransform.
"""
Loading
Loading