vLLM model handler efficiency improvements (#32687)
* vLLM model handler efficiency improvements

* fmt

* Remove bad exceptions

* lint

* lint
damccorm authored Oct 15, 2024
1 parent 89dd088 commit 06ecee9
Showing 2 changed files with 71 additions and 34 deletions.
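In short, this change switches both vLLM model handlers from issuing one blocking OpenAI request per prompt to creating all of the requests for a batch with an AsyncOpenAI client and then awaiting the responses, with the synchronous run_inference entry point driving the async helper through asyncio.run. A minimal standalone sketch of that idea (not the handler code itself; it uses asyncio.gather to await the batch, and the port, model name, and prompts are placeholders for a locally running vLLM OpenAI-compatible server):

import asyncio

from openai import AsyncOpenAI


async def _complete_batch(prompts, model_name, port):
  # One client per call, pointed at the local vLLM OpenAI-compatible server.
  client = AsyncOpenAI(api_key="EMPTY", base_url=f"http://localhost:{port}/v1")
  # Schedule every completion request on the event loop and collect the
  # responses in the same order as the prompts.
  return await asyncio.gather(
      *(client.completions.create(model=model_name, prompt=p) for p in prompts))


def complete_batch(prompts, model_name, port=8000):
  # Synchronous wrapper, mirroring how run_inference drives the async helper.
  return asyncio.run(_complete_batch(prompts, model_name, port))

Driving the coroutines from asyncio.run keeps the ModelHandler API synchronous while the request/response handling moves onto an event loop. The handler-level changes follow in the diff below.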
2 changes: 1 addition & 1 deletion .github/trigger_files/beam_PostCommit_Python.json
@@ -1,5 +1,5 @@
 {
   "comment": "Modify this file in a trivial way to cause this test suite to run.",
-  "modification": 3
+  "modification": 4
 }

103 changes: 70 additions & 33 deletions sdks/python/apache_beam/ml/inference/vllm_inference.py
@@ -17,6 +17,7 @@
 
 # pytype: skip-file
 
+import asyncio
 import logging
 import os
 import subprocess
@@ -35,6 +36,7 @@
 from apache_beam.ml.inference.base import ModelHandler
 from apache_beam.ml.inference.base import PredictionResult
 from apache_beam.utils import subprocess_server
+from openai import AsyncOpenAI
 from openai import OpenAI
 
 try:
@@ -94,6 +96,15 @@ def getVLLMClient(port) -> OpenAI:
   )
 
 
+def getAsyncVLLMClient(port) -> AsyncOpenAI:
+  openai_api_key = "EMPTY"
+  openai_api_base = f"http://localhost:{port}/v1"
+  return AsyncOpenAI(
+      api_key=openai_api_key,
+      base_url=openai_api_base,
+  )
+
+
 class _VLLMModelServer():
   def __init__(self, model_name: str, vllm_server_kwargs: Dict[str, str]):
     self._model_name = model_name
Expand Down Expand Up @@ -184,6 +195,34 @@ def __init__(
def load_model(self) -> _VLLMModelServer:
return _VLLMModelServer(self._model_name, self._vllm_server_kwargs)

async def _async_run_inference(
self,
batch: Sequence[str],
model: _VLLMModelServer,
inference_args: Optional[Dict[str, Any]] = None
) -> Iterable[PredictionResult]:
client = getAsyncVLLMClient(model.get_server_port())
inference_args = inference_args or {}
async_predictions = []
for prompt in batch:
try:
completion = client.completions.create(
model=self._model_name, prompt=prompt, **inference_args)
async_predictions.append(completion)
except Exception as e:
model.check_connectivity()
raise e

predictions = []
for p in async_predictions:
try:
predictions.append(await p)
except Exception as e:
model.check_connectivity()
raise e

return [PredictionResult(x, y) for x, y in zip(batch, predictions)]

def run_inference(
self,
batch: Sequence[str],
@@ -200,22 +239,7 @@ def run_inference(
     Returns:
       An Iterable of type PredictionResult.
     """
-    client = getVLLMClient(model.get_server_port())
-    inference_args = inference_args or {}
-    predictions = []
-    # TODO(https://github.com/apache/beam/issues/32528): We should add support
-    # for taking in batches and doing a bunch of async calls. That will end up
-    # being more efficient when we can do in bundle batching.
-    for prompt in batch:
-      try:
-        completion = client.completions.create(
-            model=self._model_name, prompt=prompt, **inference_args)
-        predictions.append(completion)
-      except Exception as e:
-        model.check_connectivity()
-        raise e
-
-    return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
+    return asyncio.run(self._async_run_inference(batch, model, inference_args))
 
   def share_model_across_processes(self) -> bool:
     return True
@@ -272,41 +296,54 @@ def load_model(self) -> _VLLMModelServer:
 
     return _VLLMModelServer(self._model_name, self._vllm_server_kwargs)
 
-  def run_inference(
+  async def _async_run_inference(
       self,
       batch: Sequence[Sequence[OpenAIChatMessage]],
       model: _VLLMModelServer,
       inference_args: Optional[Dict[str, Any]] = None
   ) -> Iterable[PredictionResult]:
-    """Runs inferences on a batch of text strings.
-
-    Args:
-      batch: A sequence of examples as OpenAI messages.
-      model: A _VLLMModelServer for connecting to the spun up server.
-      inference_args: Any additional arguments for an inference.
-
-    Returns:
-      An Iterable of type PredictionResult.
-    """
-    client = getVLLMClient(model.get_server_port())
+    client = getAsyncVLLMClient(model.get_server_port())
     inference_args = inference_args or {}
-    predictions = []
-    # TODO(https://github.com/apache/beam/issues/32528): We should add support
-    # for taking in batches and doing a bunch of async calls. That will end up
-    # being more efficient when we can do in bundle batching.
+    async_predictions = []
     for messages in batch:
       formatted = []
       for message in messages:
         formatted.append({"role": message.role, "content": message.content})
       try:
         completion = client.chat.completions.create(
             model=self._model_name, messages=formatted, **inference_args)
-        predictions.append(completion)
+        async_predictions.append(completion)
      except Exception as e:
         model.check_connectivity()
         raise e
 
+    predictions = []
+    for p in async_predictions:
+      try:
+        predictions.append(await p)
+      except Exception as e:
+        model.check_connectivity()
+        raise e
+
     return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
 
+  def run_inference(
+      self,
+      batch: Sequence[Sequence[OpenAIChatMessage]],
+      model: _VLLMModelServer,
+      inference_args: Optional[Dict[str, Any]] = None
+  ) -> Iterable[PredictionResult]:
+    """Runs inferences on a batch of text strings.
+
+    Args:
+      batch: A sequence of examples as OpenAI messages.
+      model: A _VLLMModelServer for connecting to the spun up server.
+      inference_args: Any additional arguments for an inference.
+
+    Returns:
+      An Iterable of type PredictionResult.
+    """
+    return asyncio.run(self._async_run_inference(batch, model, inference_args))
+
   def share_model_across_processes(self) -> bool:
     return True

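For context, a hedged usage sketch (not part of this commit) of running the completions handler from a Beam pipeline with RunInference; the handler class name VLLMCompletionsModelHandler and its model_name argument are assumptions based on this module, the model id and prompts are placeholders, and a worker with vllm installed (typically GPU-backed) is required:

import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
# Assumed handler name; adjust if the exported class differs.
from apache_beam.ml.inference.vllm_inference import VLLMCompletionsModelHandler

with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | beam.Create([
          'What is Apache Beam?',
          'Write a haiku about streaming data.',
      ])
      # Each batch of prompts goes to the local vLLM server; results come back
      # as PredictionResult(example, inference) pairs.
      | RunInference(VLLMCompletionsModelHandler(model_name='facebook/opt-125m'))
      | beam.Map(print))

Because share_model_across_processes returns True, Beam shares one loaded _VLLMModelServer across the worker's processes rather than starting a separate vLLM server per process.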