Skip to content

Commit

Permalink
Fix stateful processing using direct runner with type checks enabled (a…
Browse files Browse the repository at this point in the history
  • Loading branch information
sadovnychyi authored Aug 10, 2023
1 parent 49ed58f commit 6f60a6c
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
19 changes: 19 additions & 0 deletions sdks/python/apache_beam/transforms/util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from apache_beam.metrics import MetricsFilter
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.options.pipeline_options import TypeOptions
from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.pvalue import AsList
Expand Down Expand Up @@ -1041,6 +1042,24 @@ def test_output_typehints(self):
ShardedKeyType[typehints.Tuple[int, int]], # type: ignore[misc]
typehints.Iterable[str]])

def test_runtime_type_check(self):
options = PipelineOptions()
options.view_as(TypeOptions).runtime_type_check = True
with TestPipeline(options=options) as pipeline:
collection = (
pipeline
| beam.Create(GroupIntoBatchesTest._create_test_data())
| util.GroupIntoBatches(GroupIntoBatchesTest.BATCH_SIZE))
num_batches = collection | beam.combiners.Count.Globally()
assert_that(
num_batches,
equal_to([
int(
math.ceil(
GroupIntoBatchesTest.NUM_ELEMENTS /
GroupIntoBatchesTest.BATCH_SIZE))
]))

def _test_runner_api_round_trip(self, transform, urn):
context = pipeline_context.PipelineContext()
proto = transform.to_runner_api(context)
Expand Down
7 changes: 7 additions & 0 deletions sdks/python/apache_beam/typehints/typecheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,13 @@ def __init__(self, dofn):
super().__init__()
self.dofn = dofn

def __getattribute__(self, name):
if (name.startswith('_') or name in self.__dict__ or
hasattr(type(self), name)):
return object.__getattribute__(self, name)
else:
return getattr(self.dofn, name)

def _inspect_start_bundle(self):
return self.dofn.get_function_arguments('start_bundle')

Expand Down

0 comments on commit 6f60a6c

Please sign in to comment.