diff --git a/sdks/python/apache_beam/transforms/combinefn_lifecycle_pipeline.py b/sdks/python/apache_beam/transforms/combinefn_lifecycle_pipeline.py index 3cb5f32c3114..56610e95297f 100644 --- a/sdks/python/apache_beam/transforms/combinefn_lifecycle_pipeline.py +++ b/sdks/python/apache_beam/transforms/combinefn_lifecycle_pipeline.py @@ -17,6 +17,7 @@ # pytype: skip-file +import math from typing import Set from typing import Tuple @@ -124,6 +125,38 @@ def run_combine(pipeline, input_elements=5, lift_combiners=True): assert_that(pcoll, equal_to([(expected_result, expected_result)])) +def run_combine_uncopyable_attr( + pipeline, input_elements=5, lift_combiners=True): + # Calculate the expected result, which is the sum of an arithmetic sequence. + # By default, this is equal to: 0 + 1 + 2 + 3 + 4 = 10 + expected_result = input_elements * (input_elements - 1) / 2 + + # Enable runtime type checking in order to cover TypeCheckCombineFn by + # the test. + pipeline.get_pipeline_options().view_as(TypeOptions).runtime_type_check = True + pipeline.get_pipeline_options().view_as( + TypeOptions).allow_unsafe_triggers = True + + with pipeline as p: + pcoll = p | 'Start' >> beam.Create(range(input_elements)) + + # Certain triggers, such as AfterCount, are incompatible with combiner + # lifting. We can use that fact to prevent combiners from being lifted. + if not lift_combiners: + pcoll |= beam.WindowInto( + window.GlobalWindows(), + trigger=trigger.AfterCount(input_elements), + accumulation_mode=trigger.AccumulationMode.DISCARDING) + + combine_fn = CallSequenceEnforcingCombineFn() + # Modules are not deep copyable. Ensure fanout falls back to pickling for + # copying combine_fn. + combine_fn.module_attribute = math + pcoll |= 'Do' >> beam.CombineGlobally(combine_fn).with_fanout(fanout=1) + + assert_that(pcoll, equal_to([expected_result])) + + def run_pardo(pipeline, input_elements=10): with pipeline as p: _ = ( diff --git a/sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py b/sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py index 62dbbc5fb77c..647e08db7aaa 100644 --- a/sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py +++ b/sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py @@ -31,6 +31,7 @@ from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.transforms.combinefn_lifecycle_pipeline import CallSequenceEnforcingCombineFn from apache_beam.transforms.combinefn_lifecycle_pipeline import run_combine +from apache_beam.transforms.combinefn_lifecycle_pipeline import run_combine_uncopyable_attr from apache_beam.transforms.combinefn_lifecycle_pipeline import run_pardo @@ -53,15 +54,24 @@ def test_combining_value_state(self): @parameterized_class([ - {'runner': direct_runner.BundleBasedDirectRunner}, - {'runner': fn_api_runner.FnApiRunner}, -]) # yapf: disable + {'runner': direct_runner.BundleBasedDirectRunner, 'pickler': 'dill'}, + {'runner': direct_runner.BundleBasedDirectRunner, 'pickler': 'cloudpickle'}, + {'runner': fn_api_runner.FnApiRunner, 'pickler': 'dill'}, + {'runner': fn_api_runner.FnApiRunner, 'pickler': 'cloudpickle'}, + ]) # yapf: disable class LocalCombineFnLifecycleTest(unittest.TestCase): def tearDown(self): CallSequenceEnforcingCombineFn.instances.clear() def test_combine(self): - run_combine(TestPipeline(runner=self.runner())) + test_options = PipelineOptions(flags=[f"--pickle_library={self.pickler}"]) + run_combine(TestPipeline(runner=self.runner(), options=test_options)) + self._assert_teardown_called() + + def test_combine_deepcopy_fails(self): + test_options = PipelineOptions(flags=[f"--pickle_library={self.pickler}"]) + run_combine_uncopyable_attr( + TestPipeline(runner=self.runner(), options=test_options)) self._assert_teardown_called() def test_non_liftable_combine(self): diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index 91ca4c8e33c3..8a5bb00eeb98 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -3170,33 +3170,48 @@ def process(self, element): yield pvalue.TaggedOutput('hot', ((self._nonce % fanout, key), value)) class PreCombineFn(CombineFn): + def __init__(self): + # Deepcopy of the combine_fn to avoid sharing state between lifted + # stages when using cloudpickle. + try: + self._combine_fn_copy = copy.deepcopy(combine_fn) + except Exception: + self._combine_fn_copy = pickler.loads(pickler.dumps(combine_fn)) + + self.setup = self._combine_fn_copy.setup + self.create_accumulator = self._combine_fn_copy.create_accumulator + self.add_input = self._combine_fn_copy.add_input + self.merge_accumulators = self._combine_fn_copy.merge_accumulators + self.compact = self._combine_fn_copy.compact + self.teardown = self._combine_fn_copy.teardown + @staticmethod def extract_output(accumulator): # Boolean indicates this is an accumulator. return (True, accumulator) - setup = combine_fn.setup - create_accumulator = combine_fn.create_accumulator - add_input = combine_fn.add_input - merge_accumulators = combine_fn.merge_accumulators - compact = combine_fn.compact - teardown = combine_fn.teardown - class PostCombineFn(CombineFn): - @staticmethod - def add_input(accumulator, element): + def __init__(self): + # Deepcopy of the combine_fn to avoid sharing state between lifted + # stages when using cloudpickle. + try: + self._combine_fn_copy = copy.deepcopy(combine_fn) + except Exception: + self._combine_fn_copy = pickler.loads(pickler.dumps(combine_fn)) + + self.setup = self._combine_fn_copy.setup + self.create_accumulator = self._combine_fn_copy.create_accumulator + self.merge_accumulators = self._combine_fn_copy.merge_accumulators + self.compact = self._combine_fn_copy.compact + self.extract_output = self._combine_fn_copy.extract_output + self.teardown = self._combine_fn_copy.teardown + + def add_input(self, accumulator, element): is_accumulator, value = element if is_accumulator: - return combine_fn.merge_accumulators([accumulator, value]) + return self._combine_fn_copy.merge_accumulators([accumulator, value]) else: - return combine_fn.add_input(accumulator, value) - - setup = combine_fn.setup - create_accumulator = combine_fn.create_accumulator - merge_accumulators = combine_fn.merge_accumulators - compact = combine_fn.compact - extract_output = combine_fn.extract_output - teardown = combine_fn.teardown + return self._combine_fn_copy.add_input(accumulator, value) def StripNonce(nonce_key_value): (_, key), value = nonce_key_value