From de25f9da898e81c27e4ac0665a2516a2fd98c3e1 Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Fri, 1 Nov 2024 12:03:59 -0400 Subject: [PATCH 1/6] Remove usage of deprecated _serialize --- .../python/apache_beam/ml/inference/onnx_inference.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/inference/onnx_inference.py b/sdks/python/apache_beam/ml/inference/onnx_inference.py index e7af114ad431..5bd9e0eeb5ef 100644 --- a/sdks/python/apache_beam/ml/inference/onnx_inference.py +++ b/sdks/python/apache_beam/ml/inference/onnx_inference.py @@ -116,7 +116,16 @@ def load_model(self) -> ort.InferenceSession: # when path is remote, we should first load into memory then deserialize f = FileSystems.open(self._model_uri, "rb") model_proto = onnx.load(f) - model_proto_bytes = onnx._serialize(model_proto) + model_proto_bytes = model_proto + if not isinstance(model_proto, bytes): + if (hasattr(model_proto, "SerializeToString") + and callable(model_proto.SerializeToString)): + model_proto = model_proto.SerializeToString() + else: + raise TypeError( + f"No SerializeToString method is detected on loaded model. 
" + + f"Type of model: {type(model_proto)}" + ) ort_session = ort.InferenceSession( model_proto_bytes, sess_options=self._session_options, From b13f481d5bbf44f90f90c4cd66f145395ab9cf13 Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Fri, 1 Nov 2024 12:58:44 -0400 Subject: [PATCH 2/6] Correct assignment --- sdks/python/apache_beam/ml/inference/onnx_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/inference/onnx_inference.py b/sdks/python/apache_beam/ml/inference/onnx_inference.py index 5bd9e0eeb5ef..c241c72a451b 100644 --- a/sdks/python/apache_beam/ml/inference/onnx_inference.py +++ b/sdks/python/apache_beam/ml/inference/onnx_inference.py @@ -120,7 +120,7 @@ def load_model(self) -> ort.InferenceSession: if not isinstance(model_proto, bytes): if (hasattr(model_proto, "SerializeToString") and callable(model_proto.SerializeToString)): - model_proto = model_proto.SerializeToString() + model_proto_bytes = model_proto.SerializeToString() else: raise TypeError( f"No SerializeToString method is detected on loaded model. 
" From 6d277aacb8b865be90765220d3ccde2cff6ae5ad Mon Sep 17 00:00:00 2001 From: Danny Mccormick Date: Fri, 1 Nov 2024 13:00:52 -0400 Subject: [PATCH 3/6] indentation --- .../apache_beam/ml/inference/onnx_inference.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/onnx_inference.py b/sdks/python/apache_beam/ml/inference/onnx_inference.py index c241c72a451b..17983f8d2d14 100644 --- a/sdks/python/apache_beam/ml/inference/onnx_inference.py +++ b/sdks/python/apache_beam/ml/inference/onnx_inference.py @@ -118,14 +118,14 @@ def load_model(self) -> ort.InferenceSession: model_proto = onnx.load(f) model_proto_bytes = model_proto if not isinstance(model_proto, bytes): - if (hasattr(model_proto, "SerializeToString") - and callable(model_proto.SerializeToString)): - model_proto_bytes = model_proto.SerializeToString() - else: - raise TypeError( - f"No SerializeToString method is detected on loaded model. " - + f"Type of model: {type(model_proto)}" - ) + if (hasattr(model_proto, "SerializeToString") + and callable(model_proto.SerializeToString)): + model_proto_bytes = model_proto.SerializeToString() + else: + raise TypeError( + f"No SerializeToString method is detected on loaded model. 
" + + f"Type of model: {type(model_proto)}" + ) ort_session = ort.InferenceSession( model_proto_bytes, sess_options=self._session_options, From 0eb36f0ec1b11db707c265cbd778c01960ebc875 Mon Sep 17 00:00:00 2001 From: Danny Mccormick Date: Fri, 1 Nov 2024 14:47:35 -0400 Subject: [PATCH 4/6] lint --- sdks/python/apache_beam/ml/inference/onnx_inference.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/onnx_inference.py b/sdks/python/apache_beam/ml/inference/onnx_inference.py index 17983f8d2d14..e1ce47675a45 100644 --- a/sdks/python/apache_beam/ml/inference/onnx_inference.py +++ b/sdks/python/apache_beam/ml/inference/onnx_inference.py @@ -118,13 +118,13 @@ def load_model(self) -> ort.InferenceSession: model_proto = onnx.load(f) model_proto_bytes = model_proto if not isinstance(model_proto, bytes): - if (hasattr(model_proto, "SerializeToString") - and callable(model_proto.SerializeToString)): + if (hasattr(model_proto, "SerializeToString") and + callable(model_proto.SerializeToString)): model_proto_bytes = model_proto.SerializeToString() else: raise TypeError( - f"No SerializeToString method is detected on loaded model. " - + f"Type of model: {type(model_proto)}" + "No SerializeToString method is detected on loaded model. 
" + + f"Type of model: {type(model_proto)}" ) ort_session = ort.InferenceSession( model_proto_bytes, From 8c55d9ee94fd185286b44aa2a55d9de6fa5023f8 Mon Sep 17 00:00:00 2001 From: Danny Mccormick Date: Fri, 1 Nov 2024 16:03:05 -0400 Subject: [PATCH 5/6] fmt --- sdks/python/apache_beam/ml/inference/onnx_inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/onnx_inference.py b/sdks/python/apache_beam/ml/inference/onnx_inference.py index e1ce47675a45..91ba94c6c29d 100644 --- a/sdks/python/apache_beam/ml/inference/onnx_inference.py +++ b/sdks/python/apache_beam/ml/inference/onnx_inference.py @@ -123,8 +123,8 @@ def load_model(self) -> ort.InferenceSession: model_proto_bytes = model_proto.SerializeToString() else: raise TypeError( - "No SerializeToString method is detected on loaded model. " + - f"Type of model: {type(model_proto)}" + "No SerializeToString method is detected on loaded model. " + + f"Type of model: {type(model_proto)}" ) ort_session = ort.InferenceSession( model_proto_bytes, From d986a440ceee27c2c580818a1c63197f9f4875a4 Mon Sep 17 00:00:00 2001 From: Danny Mccormick Date: Fri, 1 Nov 2024 16:37:30 -0400 Subject: [PATCH 6/6] fmt --- sdks/python/apache_beam/ml/inference/onnx_inference.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/onnx_inference.py b/sdks/python/apache_beam/ml/inference/onnx_inference.py index 91ba94c6c29d..4ac856456748 100644 --- a/sdks/python/apache_beam/ml/inference/onnx_inference.py +++ b/sdks/python/apache_beam/ml/inference/onnx_inference.py @@ -124,8 +124,7 @@ def load_model(self) -> ort.InferenceSession: else: raise TypeError( "No SerializeToString method is detected on loaded model. " + - f"Type of model: {type(model_proto)}" - ) + f"Type of model: {type(model_proto)}") ort_session = ort.InferenceSession( model_proto_bytes, sess_options=self._session_options,