diff --git a/sdks/python/apache_beam/ml/inference/onnx_inference.py b/sdks/python/apache_beam/ml/inference/onnx_inference.py index e7af114ad43..4ac85645674 100644 --- a/sdks/python/apache_beam/ml/inference/onnx_inference.py +++ b/sdks/python/apache_beam/ml/inference/onnx_inference.py @@ -116,7 +116,15 @@ def load_model(self) -> ort.InferenceSession: # when path is remote, we should first load into memory then deserialize f = FileSystems.open(self._model_uri, "rb") model_proto = onnx.load(f) - model_proto_bytes = onnx._serialize(model_proto) + model_proto_bytes = model_proto + if not isinstance(model_proto, bytes): + if (hasattr(model_proto, "SerializeToString") and + callable(model_proto.SerializeToString)): + model_proto_bytes = model_proto.SerializeToString() + else: + raise TypeError( + "No SerializeToString method is detected on loaded model. " + + f"Type of model: {type(model_proto)}") ort_session = ort.InferenceSession( model_proto_bytes, sess_options=self._session_options,