diff --git a/.github/trigger_files/beam_PostCommit_Python.json b/.github/trigger_files/beam_PostCommit_Python.json
index 1eb60f6e4959..9e1d1e1b80dd 100644
--- a/.github/trigger_files/beam_PostCommit_Python.json
+++ b/.github/trigger_files/beam_PostCommit_Python.json
@@ -1,5 +1,5 @@
 {
   "comment": "Modify this file in a trivial way to cause this test suite to run.",
-  "modification": 3
+  "modification": 4
 }
diff --git a/sdks/python/apache_beam/ml/inference/vllm_inference.py b/sdks/python/apache_beam/ml/inference/vllm_inference.py
index 28890083d93e..e1ba4f49b8fd 100644
--- a/sdks/python/apache_beam/ml/inference/vllm_inference.py
+++ b/sdks/python/apache_beam/ml/inference/vllm_inference.py
@@ -17,6 +17,7 @@
 
 # pytype: skip-file
 
+import asyncio
 import logging
 import os
 import subprocess
@@ -35,6 +36,7 @@
 from apache_beam.ml.inference.base import ModelHandler
 from apache_beam.ml.inference.base import PredictionResult
 from apache_beam.utils import subprocess_server
+from openai import AsyncOpenAI
 from openai import OpenAI
 
 try:
@@ -94,6 +96,15 @@ def getVLLMClient(port) -> OpenAI:
   )
 
 
+def getAsyncVLLMClient(port) -> AsyncOpenAI:
+  openai_api_key = "EMPTY"
+  openai_api_base = f"http://localhost:{port}/v1"
+  return AsyncOpenAI(
+      api_key=openai_api_key,
+      base_url=openai_api_base,
+  )
+
+
 class _VLLMModelServer():
   def __init__(self, model_name: str, vllm_server_kwargs: Dict[str, str]):
     self._model_name = model_name
@@ -184,6 +195,34 @@ def __init__(
   def load_model(self) -> _VLLMModelServer:
     return _VLLMModelServer(self._model_name, self._vllm_server_kwargs)
 
+  async def _async_run_inference(
+      self,
+      batch: Sequence[str],
+      model: _VLLMModelServer,
+      inference_args: Optional[Dict[str, Any]] = None
+  ) -> Iterable[PredictionResult]:
+    client = getAsyncVLLMClient(model.get_server_port())
+    inference_args = inference_args or {}
+    async_predictions = []
+    for prompt in batch:
+      try:
+        completion = client.completions.create(
+            model=self._model_name, prompt=prompt, **inference_args)
+        async_predictions.append(completion)
+      except Exception as e:
+        model.check_connectivity()
+        raise e
+
+    predictions = []
+    for p in async_predictions:
+      try:
+        predictions.append(await p)
+      except Exception as e:
+        model.check_connectivity()
+        raise e
+
+    return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
+
   def run_inference(
       self,
       batch: Sequence[str],
@@ -200,22 +239,7 @@ def run_inference(
     Returns:
       An Iterable of type PredictionResult.
     """
-    client = getVLLMClient(model.get_server_port())
-    inference_args = inference_args or {}
-    predictions = []
-    # TODO(https://github.com/apache/beam/issues/32528): We should add support
-    # for taking in batches and doing a bunch of async calls. That will end up
-    # being more efficient when we can do in bundle batching.
-    for prompt in batch:
-      try:
-        completion = client.completions.create(
-            model=self._model_name, prompt=prompt, **inference_args)
-        predictions.append(completion)
-      except Exception as e:
-        model.check_connectivity()
-        raise e
-
-    return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
+    return asyncio.run(self._async_run_inference(batch, model, inference_args))
 
   def share_model_across_processes(self) -> bool:
     return True
@@ -272,28 +296,15 @@
   def load_model(self) -> _VLLMModelServer:
     return _VLLMModelServer(self._model_name, self._vllm_server_kwargs)
 
-  def run_inference(
+  async def _async_run_inference(
       self,
       batch: Sequence[Sequence[OpenAIChatMessage]],
       model: _VLLMModelServer,
       inference_args: Optional[Dict[str, Any]] = None
   ) -> Iterable[PredictionResult]:
-    """Runs inferences on a batch of text strings.
-
-    Args:
-      batch: A sequence of examples as OpenAI messages.
-      model: A _VLLMModelServer for connecting to the spun up server.
-      inference_args: Any additional arguments for an inference.
-
-    Returns:
-      An Iterable of type PredictionResult.
-    """
-    client = getVLLMClient(model.get_server_port())
+    client = getAsyncVLLMClient(model.get_server_port())
     inference_args = inference_args or {}
-    predictions = []
-    # TODO(https://github.com/apache/beam/issues/32528): We should add support
-    # for taking in batches and doing a bunch of async calls. That will end up
-    # being more efficient when we can do in bundle batching.
+    async_predictions = []
     for messages in batch:
       formatted = []
       for message in messages:
@@ -301,12 +312,38 @@ def run_inference(
       try:
         completion = client.chat.completions.create(
             model=self._model_name, messages=formatted, **inference_args)
-        predictions.append(completion)
+        async_predictions.append(completion)
+      except Exception as e:
+        model.check_connectivity()
+        raise e
+
+    predictions = []
+    for p in async_predictions:
+      try:
+        predictions.append(await p)
       except Exception as e:
         model.check_connectivity()
         raise e
 
     return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
 
+  def run_inference(
+      self,
+      batch: Sequence[Sequence[OpenAIChatMessage]],
+      model: _VLLMModelServer,
+      inference_args: Optional[Dict[str, Any]] = None
+  ) -> Iterable[PredictionResult]:
+    """Runs inferences on a batch of text strings.
+
+    Args:
+      batch: A sequence of examples as OpenAI messages.
+      model: A _VLLMModelServer for connecting to the spun up server.
+      inference_args: Any additional arguments for an inference.
+
+    Returns:
+      An Iterable of type PredictionResult.
+    """
+    return asyncio.run(self._async_run_inference(batch, model, inference_args))
+
   def share_model_across_processes(self) -> bool:
     return True
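
Note on usage: the handlers changed above are driven through Beam's RunInference transform, and the new asyncio.run call stays inside run_inference, so pipeline code is unchanged. A minimal sketch of that wiring, assuming a worker environment with vllm and openai installed; the model name and prompts below are illustrative only, not part of this change:

    import apache_beam as beam
    from apache_beam.ml.inference.base import RunInference
    from apache_beam.ml.inference.vllm_inference import VLLMCompletionsModelHandler

    # Illustrative prompts and model name; any model servable by vLLM works here.
    with beam.Pipeline() as p:
      _ = (
          p
          | 'Prompts' >> beam.Create(['Hello, my name is', 'The capital of France is'])
          | 'Generate' >> RunInference(
              VLLMCompletionsModelHandler(model_name='facebook/opt-125m'))
          | 'Print' >> beam.Map(print))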