From 63d89e56d70983b3033cdbda17f7e996179d4810 Mon Sep 17 00:00:00 2001 From: Iman Tabrizian Date: Thu, 27 Jul 2023 11:51:04 -0400 Subject: [PATCH 1/4] Refactor as_json in asyncio client --- .../library/tritonclient/grpc/aio/__init__.py | 83 ++++--------------- 1 file changed, 17 insertions(+), 66 deletions(-) diff --git a/src/python/library/tritonclient/grpc/aio/__init__.py b/src/python/library/tritonclient/grpc/aio/__init__.py index 11aa083de..4f9a5ea0b 100755 --- a/src/python/library/tritonclient/grpc/aio/__init__.py +++ b/src/python/library/tritonclient/grpc/aio/__init__.py @@ -112,6 +112,12 @@ def __init__( self._client_stub = service_pb2_grpc.GRPCInferenceServiceStub(self._channel) self._verbose = verbose + def _return_response(self, response, as_json): + if as_json: + return json.loads(MessageToJson(response, preserving_proto_field_name=True)) + else: + return response + async def __aenter__(self): return self @@ -198,12 +204,7 @@ async def get_server_metadata(self, headers=None, as_json=False): ) if self._verbose: print(response) - if as_json: - return json.loads( - MessageToJson(response, preserving_proto_field_name=True) - ) - else: - return response + return self._return_response(response) except grpc.RpcError as rpc_error: raise_error_grpc(rpc_error) @@ -225,12 +226,7 @@ async def get_model_metadata( ) if self._verbose: print(response) - if as_json: - return json.loads( - MessageToJson(response, preserving_proto_field_name=True) - ) - else: - return response + return self._return_response(response, as_json) except grpc.RpcError as rpc_error: raise_error_grpc(rpc_error) @@ -252,12 +248,7 @@ async def get_model_config( ) if self._verbose: print(response) - if as_json: - return json.loads( - MessageToJson(response, preserving_proto_field_name=True) - ) - else: - return response + return self._return_response(response, as_json) except grpc.RpcError as rpc_error: raise_error_grpc(rpc_error) @@ -277,12 +268,7 @@ async def get_model_repository_index(self, headers=None, as_json=False): ) if self._verbose: print(response) - if as_json: - return json.loads( - MessageToJson(response, preserving_proto_field_name=True) - ) - else: - return response + return self._return_response(response, as_json) except grpc.RpcError as rpc_error: raise_error_grpc(rpc_error) @@ -349,12 +335,7 @@ async def get_inference_statistics( ) if self._verbose: print(response) - if as_json: - return json.loads( - MessageToJson(response, preserving_proto_field_name=True) - ) - else: - return response + return self._return_response(response, as_json) except grpc.RpcError as rpc_error: raise_error_grpc(rpc_error) @@ -384,12 +365,7 @@ async def update_trace_settings( ) if self._verbose: print(response) - if as_json: - return json.loads( - MessageToJson(response, preserving_proto_field_name=True) - ) - else: - return response + return self._return_response(response, as_json) except grpc.RpcError as rpc_error: raise_error_grpc(rpc_error) @@ -407,12 +383,7 @@ async def get_trace_settings(self, model_name=None, headers=None, as_json=False) ) if self._verbose: print(response) - if as_json: - return json.loads( - MessageToJson(response, preserving_proto_field_name=True) - ) - else: - return response + return self._return_response(response, as_json) except grpc.RpcError as rpc_error: raise_error_grpc(rpc_error) @@ -439,12 +410,7 @@ async def update_log_settings(self, settings, headers=None, as_json=False): ) if self._verbose: print(response) - if as_json: - return json.loads( - MessageToJson(response, 
preserving_proto_field_name=True) - ) - else: - return response + return self._return_response(response, as_json) except grpc.RpcError as rpc_error: raise_error_grpc(rpc_error) @@ -460,12 +426,7 @@ async def get_log_settings(self, headers=None, as_json=False): ) if self._verbose: print(response) - if as_json: - return json.loads( - MessageToJson(response, preserving_proto_field_name=True) - ) - else: - return response + return self._return_response(response, as_json) except grpc.RpcError as rpc_error: raise_error_grpc(rpc_error) @@ -487,12 +448,7 @@ async def get_system_shared_memory_status( ) if self._verbose: print(response) - if as_json: - return json.loads( - MessageToJson(response, preserving_proto_field_name=True) - ) - else: - return response + return self._return_response(response, as_json) except grpc.RpcError as rpc_error: raise_error_grpc(rpc_error) @@ -562,12 +518,7 @@ async def get_cuda_shared_memory_status( ) if self._verbose: print(response) - if as_json: - return json.loads( - MessageToJson(response, preserving_proto_field_name=True) - ) - else: - return response + return self._return_response(response, as_json) except grpc.RpcError as rpc_error: raise_error_grpc(rpc_error) From 821297dca61210528bc7bf0de9bde0759e5906a3 Mon Sep 17 00:00:00 2001 From: Iman Tabrizian Date: Thu, 27 Jul 2023 11:52:35 -0400 Subject: [PATCH 2/4] Add unit testing for asyncio client --- src/python/library/tests/test_grpc_asyncio.py | 89 +++++++++++++++++++ 1 file changed, 89 insertions(+) create mode 100644 src/python/library/tests/test_grpc_asyncio.py diff --git a/src/python/library/tests/test_grpc_asyncio.py b/src/python/library/tests/test_grpc_asyncio.py new file mode 100644 index 000000000..6757fb945 --- /dev/null +++ b/src/python/library/tests/test_grpc_asyncio.py @@ -0,0 +1,89 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+import unittest + +from tritonclient.grpc.aio import InferenceServerClient + + +class GRPCAsyncIOTest(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + await super().asyncSetUp() + self._client = InferenceServerClient(url="localhost:8001") + self._model_name = "resnet50" + + async def test_server_live(self): + self.assertTrue(await self._client.is_server_live()) + + async def test_server_ready(self): + self.assertTrue(await self._client.is_server_ready()) + + async def test_is_model_ready(self): + self.assertTrue(await self._client.is_model_ready(self._model_name)) + + async def test_get_server_metadata(self): + server_metadata = await self._client.get_server_metadata() + self.assertIn("trace", server_metadata.extensions) + + server_metadata = await self._client.get_server_metadata(as_json=True) + self.assertIn("trace", server_metadata["extensions"]) + + async def test_get_model_metadata(self): + model_metadata = await self._client.get_model_metadata(self._model_name) + self.assertEqual(model_metadata.name, self._model_name) + + model_metadata = await self._client.get_model_metadata( + self._model_name, as_json=True + ) + self.assertEqual(model_metadata["name"], self._model_name) + + async def test_get_model_config(self): + model_config = await self._client.get_model_config(self._model_name) + self.assertEqual(model_config.config.name, self._model_name) + + async def test_get_model_repository_index(self): + models = await self._client.get_model_repository_index() + for model in models.models: + if model.name == self._model_name: + break + else: + self.assertFalse( + f"Failed to find model ({self._model_name}) in the list of models." + ) + + async def test_model_load_unload(self): + self.assertTrue(await self._client.is_model_ready(self._model_name)) + await self._client.unload_model(self._model_name) + self.assertFalse(await self._client.is_model_ready(self._model_name)) + await self._client.load_model(self._model_name) + self.assertTrue(await self._client.is_model_ready(self._model_name)) + + async def test_get_inference_statistics(self): + statistics = await self._client.get_inference_statistics(self._model_name) + self.assertEqual(statistics.model_stats[0].name, self._model_name) + + +if __name__ == "__main__": + unittest.main() From c0bfeeb44e4ebd40d2019a1a7ec000f8f8739531 Mon Sep 17 00:00:00 2001 From: Iman Tabrizian Date: Fri, 28 Jul 2023 13:05:21 -0400 Subject: [PATCH 3/4] Remove the testing since it already exists --- src/python/library/tests/test_grpc_asyncio.py | 89 ------------------- 1 file changed, 89 deletions(-) delete mode 100644 src/python/library/tests/test_grpc_asyncio.py diff --git a/src/python/library/tests/test_grpc_asyncio.py b/src/python/library/tests/test_grpc_asyncio.py deleted file mode 100644 index 6757fb945..000000000 --- a/src/python/library/tests/test_grpc_asyncio.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. 
-# * Neither the name of NVIDIA CORPORATION nor the names of its -# contributors may be used to endorse or promote products derived -# from this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY -# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR -# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY -# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import unittest - -from tritonclient.grpc.aio import InferenceServerClient - - -class GRPCAsyncIOTest(unittest.IsolatedAsyncioTestCase): - async def asyncSetUp(self): - await super().asyncSetUp() - self._client = InferenceServerClient(url="localhost:8001") - self._model_name = "resnet50" - - async def test_server_live(self): - self.assertTrue(await self._client.is_server_live()) - - async def test_server_ready(self): - self.assertTrue(await self._client.is_server_ready()) - - async def test_is_model_ready(self): - self.assertTrue(await self._client.is_model_ready(self._model_name)) - - async def test_get_server_metadata(self): - server_metadata = await self._client.get_server_metadata() - self.assertIn("trace", server_metadata.extensions) - - server_metadata = await self._client.get_server_metadata(as_json=True) - self.assertIn("trace", server_metadata["extensions"]) - - async def test_get_model_metadata(self): - model_metadata = await self._client.get_model_metadata(self._model_name) - self.assertEqual(model_metadata.name, self._model_name) - - model_metadata = await self._client.get_model_metadata( - self._model_name, as_json=True - ) - self.assertEqual(model_metadata["name"], self._model_name) - - async def test_get_model_config(self): - model_config = await self._client.get_model_config(self._model_name) - self.assertEqual(model_config.config.name, self._model_name) - - async def test_get_model_repository_index(self): - models = await self._client.get_model_repository_index() - for model in models.models: - if model.name == self._model_name: - break - else: - self.assertFalse( - f"Failed to find model ({self._model_name}) in the list of models." 
- ) - - async def test_model_load_unload(self): - self.assertTrue(await self._client.is_model_ready(self._model_name)) - await self._client.unload_model(self._model_name) - self.assertFalse(await self._client.is_model_ready(self._model_name)) - await self._client.load_model(self._model_name) - self.assertTrue(await self._client.is_model_ready(self._model_name)) - - async def test_get_inference_statistics(self): - statistics = await self._client.get_inference_statistics(self._model_name) - self.assertEqual(statistics.model_stats[0].name, self._model_name) - - -if __name__ == "__main__": - unittest.main() From 2c150b767f273655e0c330e040f95a637c1d76ad Mon Sep 17 00:00:00 2001 From: Iman Tabrizian Date: Fri, 28 Jul 2023 13:10:47 -0400 Subject: [PATCH 4/4] Fix up --- src/python/library/tritonclient/grpc/aio/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/python/library/tritonclient/grpc/aio/__init__.py b/src/python/library/tritonclient/grpc/aio/__init__.py index 4f9a5ea0b..fc5eaccdb 100755 --- a/src/python/library/tritonclient/grpc/aio/__init__.py +++ b/src/python/library/tritonclient/grpc/aio/__init__.py @@ -204,7 +204,7 @@ async def get_server_metadata(self, headers=None, as_json=False): ) if self._verbose: print(response) - return self._return_response(response) + return self._return_response(response, as_json) except grpc.RpcError as rpc_error: raise_error_grpc(rpc_error)
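
A minimal end-to-end sketch of the refactored asyncio client follows. It is not part of the patches above: it assumes a running Triton server reachable at localhost:8001 (the same address the removed test used) and exercises only methods touched by this series. The consolidated _return_response helper returns the raw protobuf message by default and a plain dict (via MessageToJson + json.loads) when as_json=True.

import asyncio

from tritonclient.grpc.aio import InferenceServerClient


async def main():
    # The client supports "async with"; __aenter__ is visible in the diff above.
    async with InferenceServerClient(url="localhost:8001") as client:
        # Default path: the protobuf response is returned unchanged.
        metadata = await client.get_server_metadata()
        print(metadata.extensions)

        # as_json=True path: _return_response converts the message to a dict,
        # so fields are accessed with subscript syntax instead of attributes.
        metadata_json = await client.get_server_metadata(as_json=True)
        print(metadata_json["extensions"])


asyncio.run(main())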