diff --git a/src/python/library/tritonclient/grpc/aio/__init__.py b/src/python/library/tritonclient/grpc/aio/__init__.py index 11aa083de..fc5eaccdb 100755 --- a/src/python/library/tritonclient/grpc/aio/__init__.py +++ b/src/python/library/tritonclient/grpc/aio/__init__.py @@ -112,6 +112,12 @@ def __init__( self._client_stub = service_pb2_grpc.GRPCInferenceServiceStub(self._channel) self._verbose = verbose + def _return_response(self, response, as_json): + if as_json: + return json.loads(MessageToJson(response, preserving_proto_field_name=True)) + else: + return response + async def __aenter__(self): return self @@ -198,12 +204,7 @@ async def get_server_metadata(self, headers=None, as_json=False): ) if self._verbose: print(response) - if as_json: - return json.loads( - MessageToJson(response, preserving_proto_field_name=True) - ) - else: - return response + return self._return_response(response, as_json) except grpc.RpcError as rpc_error: raise_error_grpc(rpc_error) @@ -225,12 +226,7 @@ async def get_model_metadata( ) if self._verbose: print(response) - if as_json: - return json.loads( - MessageToJson(response, preserving_proto_field_name=True) - ) - else: - return response + return self._return_response(response, as_json) except grpc.RpcError as rpc_error: raise_error_grpc(rpc_error) @@ -252,12 +248,7 @@ async def get_model_config( ) if self._verbose: print(response) - if as_json: - return json.loads( - MessageToJson(response, preserving_proto_field_name=True) - ) - else: - return response + return self._return_response(response, as_json) except grpc.RpcError as rpc_error: raise_error_grpc(rpc_error) @@ -277,12 +268,7 @@ async def get_model_repository_index(self, headers=None, as_json=False): ) if self._verbose: print(response) - if as_json: - return json.loads( - MessageToJson(response, preserving_proto_field_name=True) - ) - else: - return response + return self._return_response(response, as_json) except grpc.RpcError as rpc_error: raise_error_grpc(rpc_error) @@ -349,12 +335,7 @@ async def get_inference_statistics( ) if self._verbose: print(response) - if as_json: - return json.loads( - MessageToJson(response, preserving_proto_field_name=True) - ) - else: - return response + return self._return_response(response, as_json) except grpc.RpcError as rpc_error: raise_error_grpc(rpc_error) @@ -384,12 +365,7 @@ async def update_trace_settings( ) if self._verbose: print(response) - if as_json: - return json.loads( - MessageToJson(response, preserving_proto_field_name=True) - ) - else: - return response + return self._return_response(response, as_json) except grpc.RpcError as rpc_error: raise_error_grpc(rpc_error) @@ -407,12 +383,7 @@ async def get_trace_settings(self, model_name=None, headers=None, as_json=False) ) if self._verbose: print(response) - if as_json: - return json.loads( - MessageToJson(response, preserving_proto_field_name=True) - ) - else: - return response + return self._return_response(response, as_json) except grpc.RpcError as rpc_error: raise_error_grpc(rpc_error) @@ -439,12 +410,7 @@ async def update_log_settings(self, settings, headers=None, as_json=False): ) if self._verbose: print(response) - if as_json: - return json.loads( - MessageToJson(response, preserving_proto_field_name=True) - ) - else: - return response + return self._return_response(response, as_json) except grpc.RpcError as rpc_error: raise_error_grpc(rpc_error) @@ -460,12 +426,7 @@ async def get_log_settings(self, headers=None, as_json=False): ) if self._verbose: print(response) - if as_json: - return json.loads( - MessageToJson(response, preserving_proto_field_name=True) - ) - else: - return response + return self._return_response(response, as_json) except grpc.RpcError as rpc_error: raise_error_grpc(rpc_error) @@ -487,12 +448,7 @@ async def get_system_shared_memory_status( ) if self._verbose: print(response) - if as_json: - return json.loads( - MessageToJson(response, preserving_proto_field_name=True) - ) - else: - return response + return self._return_response(response, as_json) except grpc.RpcError as rpc_error: raise_error_grpc(rpc_error) @@ -562,12 +518,7 @@ async def get_cuda_shared_memory_status( ) if self._verbose: print(response) - if as_json: - return json.loads( - MessageToJson(response, preserving_proto_field_name=True) - ) - else: - return response + return self._return_response(response, as_json) except grpc.RpcError as rpc_error: raise_error_grpc(rpc_error)