diff --git a/README.md b/README.md
index d969468bc..787d21677 100644
--- a/README.md
+++ b/README.md
@@ -550,6 +550,25 @@ sent via this stream.
 See more details about these APIs in
 [grpc/\_client.py](src/python/library/tritonclient/grpc/_client.py).
 
+For gRPC AsyncIO requests, an AsyncIO task wrapping an `infer()` coroutine can
+be safely cancelled.
+
+```python
+infer_task = asyncio.create_task(aio_client.infer(...))
+infer_task.cancel()
+```
+
+For gRPC AsyncIO streaming requests, `cancel()` can be called on the
+asynchronous iterator returned by the `stream_infer()` API.
+
+```python
+responses_iterator = aio_client.stream_infer(...)
+responses_iterator.cancel()
+```
+
+See more details about these APIs in
+[grpc/aio/\_\_init\_\_.py](src/python/library/tritonclient/grpc/aio/__init__.py).
+
 See [request_cancellation](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/request_cancellation.md)
 in the server user-guide to learn about how this is handled on the
 server side.
diff --git a/src/python/library/tritonclient/grpc/aio/__init__.py b/src/python/library/tritonclient/grpc/aio/__init__.py
index fc5eaccdb..37414dacb 100755
--- a/src/python/library/tritonclient/grpc/aio/__init__.py
+++ b/src/python/library/tritonclient/grpc/aio/__init__.py
@@ -624,7 +624,7 @@ async def infer(
         except grpc.RpcError as rpc_error:
             raise_error_grpc(rpc_error)
 
-    async def stream_infer(
+    def stream_infer(
         self,
         inputs_iterator,
         stream_timeout=None,
@@ -636,7 +636,7 @@ async def stream_infer(
 
         Parameters
         ----------
-        inputs_iterator : async_generator
+        inputs_iterator : asynchronous iterator
             Async iterator that yields a dict(s) consists of the input
             parameters to the async_stream_infer function defined in
             tritonclient.grpc.InferenceServerClient.
@@ -653,9 +653,15 @@
         Returns
         -------
-        async_generator
+        asynchronous iterator
             Yield tuple holding (InferResult, InferenceServerException)
             objects.
 
+            This object can be used to cancel the inference request, as
+            shown below::
+
+                it = stream_infer(...)
+                ret = it.cancel()
+
         Raises
         ------
         InferenceServerException
@@ -708,14 +714,16 @@ async def _request_iterator(inputs_iterator):
                     parameters=inputs["parameters"],
                 )
 
-        try:
-            response_iterator = self._client_stub.ModelStreamInfer(
-                _request_iterator(inputs_iterator),
-                metadata=metadata,
-                timeout=stream_timeout,
-                compression=_grpc_compression_type(compression_algorithm),
-            )
-            async for response in response_iterator:
+        class _ResponseIterator:
+            def __init__(self, grpc_call, verbose):
+                self._grpc_call = grpc_call
+                self._verbose = verbose
+
+            def __aiter__(self):
+                return self
+
+            async def __anext__(self):
+                response = await self._grpc_call.__aiter__().__anext__()
                 if self._verbose:
                     print(response)
                 result = error = None
@@ -723,6 +731,18 @@
                     error = InferenceServerException(msg=response.error_message)
                 else:
                     result = InferResult(response.infer_response)
-                yield (result, error)
+                return result, error
+
+            def cancel(self):
+                return self._grpc_call.cancel()
+
+        try:
+            grpc_call = self._client_stub.ModelStreamInfer(
+                _request_iterator(inputs_iterator),
+                metadata=metadata,
+                timeout=stream_timeout,
+                compression=_grpc_compression_type(compression_algorithm),
+            )
+            return _ResponseIterator(grpc_call, self._verbose)
         except grpc.RpcError as rpc_error:
             raise_error_grpc(rpc_error)
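
Usage note (illustrative, not part of the diff): a minimal end-to-end sketch of the `infer()` cancellation path documented in the README hunk above. The server URL `localhost:8001`, the model name `simple`, and the input tensor name/shape/datatype are assumptions for the example, not something this change prescribes.

```python
import asyncio

import numpy as np
import tritonclient.grpc.aio as grpcclient
from tritonclient.grpc import InferInput


async def main():
    # Assumed deployment details; adjust to match your server and model.
    client = grpcclient.InferenceServerClient(url="localhost:8001")

    inp = InferInput("INPUT0", [1, 16], "FP32")
    inp.set_data_from_numpy(np.zeros((1, 16), dtype=np.float32))

    # Wrap the infer() coroutine in a task so it can be cancelled;
    # cancelling the task cancels the underlying gRPC call.
    infer_task = asyncio.create_task(client.infer("simple", [inp]))
    await asyncio.sleep(0)  # let the task start before cancelling
    infer_task.cancel()

    try:
        await infer_task
    except asyncio.CancelledError:
        print("infer() was cancelled")

    await client.close()


if __name__ == "__main__":
    asyncio.run(main())
```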
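
Similarly, a sketch of the new streaming cancellation: because `stream_infer()` is no longer a coroutine, it returns the response iterator immediately (no `await`), and that iterator now exposes `cancel()`. The dict yielded by the inputs iterator carries the `async_stream_infer` parameters referenced in the docstring; the model name, input details, and the assumption that optional dict keys may be omitted are all illustrative.

```python
import asyncio

import numpy as np
import tritonclient.grpc.aio as grpcclient
from tritonclient.grpc import InferInput


async def main():
    client = grpcclient.InferenceServerClient(url="localhost:8001")

    async def inputs_iterator():
        # Each yielded dict mirrors the async_stream_infer() parameters;
        # "model_name" and "inputs" are the essential keys here (this
        # minimal dict assumes the client defaults the optional ones).
        inp = InferInput("INPUT0", [1, 16], "FP32")
        inp.set_data_from_numpy(np.zeros((1, 16), dtype=np.float32))
        yield {"model_name": "simple", "inputs": [inp]}

    # No `await` here: stream_infer() now returns the iterator directly.
    responses_iterator = client.stream_infer(inputs_iterator())

    async for result, error in responses_iterator:
        if error is not None:
            print(f"received error: {error}")
        else:
            print("received a response")
        # Stop early: cancel() aborts the underlying gRPC stream, which
        # the server observes as a request cancellation.
        responses_iterator.cancel()
        break

    await client.close()


if __name__ == "__main__":
    asyncio.run(main())
```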