tpu: add support for async postprocessing
AlpinDale committed Dec 23, 2024
1 parent a8bdd48 commit 7526df4
Showing 2 changed files with 20 additions and 24 deletions.
aphrodite/common/config.py: 6 changes (3 additions & 3 deletions)
@@ -466,10 +466,10 @@ def verify_async_output_proc(self, parallel_config, speculative_config,
                            "with pipeline parallel")
             self.use_async_output_proc = False
             return
-        if device_config.device_type != "cuda":
+        if device_config.device_type not in ("cuda", "tpu"):
             logger.warning(
-                "Async output processing is only supported for CUDA."
-                " Disabling it for other platforms.")
+                "Async output processing is only supported for CUDA or TPU. "
+                "Disabling it for other platforms.")
             self.use_async_output_proc = False
             return
         if envs.APHRODITE_USE_RAY_SPMD_WORKER:
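Note: with this change, async output processing stays enabled on TPU instead of being force-disabled. A minimal sketch of the updated guard's behavior (the standalone helper name is an assumption, not part of the diff):

    def async_output_proc_supported(device_type: str) -> bool:
        # Mirrors the updated check: CUDA and TPU keep async output
        # processing; all other platforms fall back to synchronous.
        return device_type in ("cuda", "tpu")

    assert async_output_proc_supported("tpu")      # newly allowed by this commit
    assert not async_output_proc_supported("cpu")  # still disabled elsewhere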
aphrodite/task_handler/tpu_model_runner.py: 38 changes (17 additions & 21 deletions)
@@ -1,6 +1,7 @@
 import time
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
+from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
+                    Type, Union)
 from unittest.mock import patch
 
 import numpy as np
@@ -51,6 +52,8 @@ class ModelInputForTPU(ModelRunnerInputBase):
     num_samples: int
     best_of: List[int]
     seq_groups: List[List[int]]
+    virtual_engine: int = 0
+    async_callback: Optional[Callable] = None
 
     def as_broadcastable_tensor_dict(
             self) -> Dict[str, Union[int, torch.Tensor]]:
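The two new fields give each step a per-engine index and a host-side postprocessing hook. A hypothetical sketch of how a worker might attach the callback to a prepared input (with_async_callback and its argument names are illustrative; the actual wiring lives outside this diff):

    import dataclasses
    from typing import Callable, Optional

    def with_async_callback(model_input, callback: Optional[Callable]):
        # ModelInputForTPU is a dataclass, so replace() returns a copy
        # with async_callback populated instead of mutating the input.
        return dataclasses.replace(model_input, async_callback=callback)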
@@ -520,27 +523,19 @@ def execute_model(
             raise ValueError(
                 "TPUModelRunner does not support multi-step execution.")
 
-        def _execute_model(*args, clone: bool = False) -> torch.Tensor:
+        def _execute_model(*args):
             """Move input args from CPU to device and execute the model."""
 
-            def _copy_to_device(x: torch.Tensor) -> torch.Tensor:
-                if clone:
-                    # When x is a slice of a CPU tensor, XLA may copy the whole
-                    # original tensor to TPU instead of only copying x.
-                    # To avoid this, we copy x after cloning.
-                    x = x.clone()
-                return x.to(self.device)
-
             new_args = []
             for arg in args:
                 if isinstance(arg, torch.Tensor):
-                    arg = _copy_to_device(arg)
+                    arg = arg.to(self.device)
                 elif isinstance(arg, AttentionMetadata):
-                    arg.slot_mapping = _copy_to_device(arg.slot_mapping)
+                    arg.slot_mapping = arg.slot_mapping.to(self.device)
                     if getattr(arg, "block_tables", None) is not None:
-                        arg.block_tables = _copy_to_device(arg.block_tables)
+                        arg.block_tables = arg.block_tables.to(self.device)
                     if getattr(arg, "context_lens", None) is not None:
-                        arg.context_lens = _copy_to_device(arg.context_lens)
+                        arg.context_lens = arg.context_lens.to(self.device)
                 new_args.append(arg)
             return self.model(*new_args, is_prompt=is_prompt)
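For context, the deleted _copy_to_device helper worked around an XLA behavior where moving a slice (a view) of a CPU tensor to the device could transfer the whole backing tensor; cloning first materialized only the slice. A standalone sketch of that now-removed workaround, based on the deleted comment:

    import torch

    def copy_to_device(x: torch.Tensor, device: torch.device,
                       clone: bool = False) -> torch.Tensor:
        # When x is a slice of a larger CPU tensor, XLA may copy the
        # whole backing tensor to the device. Cloning first materializes
        # just the slice, so only that data is transferred.
        if clone:
            x = x.clone()
        return x.to(device)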

@@ -567,13 +562,11 @@ def _copy_to_device(x: torch.Tensor) -> torch.Tensor:
             output_token_ids = _execute_model(
                 model_input.token_ids[None, start_idx:end_idx],
                 model_input.position_ids[None, start_idx:end_idx],
-                model_input.attn_metadata,
-                model_input.input_lens[i:i + 1],
-                model_input.t[i:i + 1],
-                model_input.p[i:i + 1],
-                model_input.num_samples,
-                kv_caches,
-                clone=True)
+                model_input.attn_metadata, model_input.input_lens[i:i + 1],
+                model_input.t[i:i + 1], model_input.p[i:i + 1],
+                model_input.num_samples, kv_caches)
+            if i == 0 and model_input.async_callback is not None:
+                model_input.async_callback()
             # Retrieve the outputs to CPU.
             next_token_ids += output_token_ids.cpu().tolist()
             start_idx = end_idx
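Invoking async_callback after the first prefill is dispatched but before .cpu() forces a device-to-host copy appears intended to overlap host-side postprocessing with device execution: PyTorch/XLA dispatches work lazily, and the .cpu() call is the blocking synchronization point. A simplified sketch of the pattern (run_step and its arguments are illustrative, not part of this diff):

    import torch

    def run_step(model, inputs: torch.Tensor, async_callback=None) -> list:
        # On XLA the forward pass is dispatched lazily and returns
        # immediately; .cpu() below is the synchronization point.
        output = model(inputs)
        if async_callback is not None:
            # Host-side postprocessing of earlier results runs here,
            # overlapping with the device computing this step's output.
            async_callback()
        return output.cpu().tolist()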
@@ -584,6 +577,8 @@ def _copy_to_device(x: torch.Tensor) -> torch.Tensor:
                 model_input.attn_metadata, model_input.input_lens,
                 model_input.t, model_input.p, model_input.num_samples,
                 kv_caches)
+            if model_input.async_callback is not None:
+                model_input.async_callback()
             # Retrieve the outputs to CPU.
             next_token_ids = output_token_ids.cpu().tolist()

@@ -626,6 +621,7 @@ def __init__(self, model: nn.Module):
                              fullgraph=True,
                              dynamic=False)
         super().__init__(compiled_callable)
+
     def __call__(self, *args, is_prompt: bool, **kwargs):
         if len(self.compiled_codes) < 3 or not self.use_custom_dispatcher:
             # not fully compiled yet, or not using the custom dispatcher,
