diff --git a/sdks/python/apache_beam/io/requestresponse.py b/sdks/python/apache_beam/io/requestresponse.py index 910e4570b4a8..706bce95f5ee 100644 --- a/sdks/python/apache_beam/io/requestresponse.py +++ b/sdks/python/apache_beam/io/requestresponse.py @@ -392,26 +392,30 @@ def get_write(self): """returns a PTransform that writes to the cache.""" pass + @property @abc.abstractmethod - def has_request_coder(self) -> bool: - """returns `True` if the request coder is present.""" + def request_coder(self): + """request coder to use with Cache.""" pass + @request_coder.setter @abc.abstractmethod - def set_request_coder(self, request_coder: coders.Coder): + def request_coder(self, request_coder: coders.Coder): """sets the request coder to use with Cache.""" pass + @property @abc.abstractmethod - def set_response_coder(self, response_coder: coders.Coder): - """sets the response coder to use with Cache.""" + def source_caller(self): + """Actual caller that is using the cache.""" pass + @source_caller.setter @abc.abstractmethod - def set_source_caller(self, caller: Caller): - """This method allows + def source_caller(self, caller: Caller): + """Sets the source caller for :class:`apache_beam.io.requestresponse.RequestResponseIO` to pull - cache requests from respective callers.""" + cache request key from respective callers.""" pass @@ -624,8 +628,8 @@ def expand( def ensure_coders_exist(request_coder): """checks if the coder exists to encode the request for caching.""" if not request_coder: - _LOGGER.warning( - 'need request coder to be able to use' + raise ValueError( + 'need request coder to be able to use ' 'Cache with RequestResponseIO.') @@ -687,23 +691,21 @@ def get_write(self): response_coder=self._response_coder, source_caller=self._source_caller) - def has_request_coder(self) -> bool: - """returns True if the request coder exists.""" - return self._request_coder is not None + @property + def source_caller(self): + return self._source_caller - def set_request_coder(self, request_coder: coders.Coder): - """sets the request coder to encode request for `RedisCache`.""" - if request_coder and not self._request_coder: - self._request_coder = request_coder + @source_caller.setter + def source_caller(self, source_caller: Caller): + self._source_caller = source_caller - def set_response_coder(self, response_coder: coders.Coder): - """sets the response coder to encode/decode response for `RedisCache`.""" - if response_coder and not self._response_coder: - self._response_coder = response_coder + @property + def request_coder(self): + return self._request_coder - def set_source_caller(self, caller: Caller[RequestT, ResponseT]): - """sets the actual caller using the `RedisCache`.""" - self._source_caller = caller + @request_coder.setter + def request_coder(self, request_coder: coders.Coder): + self._request_coder = request_coder class RequestResponseIO(beam.PTransform[beam.PCollection[RequestT], @@ -758,11 +760,11 @@ def expand( # TODO(riteshghorse): handle Throttle PTransforms when available. if self._cache: - self._cache.set_source_caller(caller=self._caller) + self._cache.source_caller = self._caller inputs = requests - if self._cache and self._cache.has_request_coder(): + if self._cache: # read from cache. outputs = inputs | self._cache.get_read() # filter responses that are None and send them to the Call transform @@ -794,7 +796,7 @@ def expand( should_backoff=self._should_backoff, repeater=self._repeater)) - if self._cache and self._cache.has_request_coder(): + if self._cache: # write to cache. _ = responses | self._cache.get_write() return (cached_responses, responses) | beam.Flatten() diff --git a/sdks/python/apache_beam/io/requestresponse_it_test.py b/sdks/python/apache_beam/io/requestresponse_it_test.py index 7a84ff208e2e..bd8c63dea587 100644 --- a/sdks/python/apache_beam/io/requestresponse_it_test.py +++ b/sdks/python/apache_beam/io/requestresponse_it_test.py @@ -274,6 +274,19 @@ def test_rrio_cache_miss_and_hit(self): | RequestResponseIO(caller, cache=cache) | beam.ParDo(ValidateCallerResponses())) + def test_rrio_no_coder_exception(self): + caller = FakeCallerForCache() + requests = ['beam', 'flink', 'spark'] + cache = RedisCache(self.host, self.port) + with self.assertRaises(ValueError): + test_pipeline = beam.Pipeline() + _ = ( + test_pipeline + | beam.Create(requests) + | RequestResponseIO(caller, cache=cache)) + res = test_pipeline.run() + res.wait_until_finish() + def tearDown(self) -> None: self.container.stop() diff --git a/sdks/python/apache_beam/transforms/enrichment.py b/sdks/python/apache_beam/transforms/enrichment.py index f2cb6c088629..93344835e930 100644 --- a/sdks/python/apache_beam/transforms/enrichment.py +++ b/sdks/python/apache_beam/transforms/enrichment.py @@ -151,7 +151,7 @@ def expand(self, # request for that row. request_coder = coders.StrUtf8Coder() if self._cache: - self._cache.set_request_coder(request_coder) + self._cache.request_coder = request_coder fetched_data = input_row | RequestResponseIO( caller=self._source_handler,