diff --git a/aphrodite/common/config.py b/aphrodite/common/config.py
index 4925f3ce7..e0e0ff210 100644
--- a/aphrodite/common/config.py
+++ b/aphrodite/common/config.py
@@ -466,10 +466,10 @@ def verify_async_output_proc(self, parallel_config, speculative_config,
                            "with pipeline parallel")
             self.use_async_output_proc = False
             return
-        if device_config.device_type != "cuda":
+        if device_config.device_type not in ("cuda", "tpu"):
             logger.warning(
-                "Async output processing is only supported for CUDA."
-                " Disabling it for other platforms.")
+                "Async output processing is only supported for CUDA or TPU. "
+                "Disabling it for other platforms.")
             self.use_async_output_proc = False
             return
         if envs.APHRODITE_USE_RAY_SPMD_WORKER:
diff --git a/aphrodite/task_handler/tpu_model_runner.py b/aphrodite/task_handler/tpu_model_runner.py
index 5166d4baf..7e33eaa04 100644
--- a/aphrodite/task_handler/tpu_model_runner.py
+++ b/aphrodite/task_handler/tpu_model_runner.py
@@ -1,6 +1,7 @@
 import time
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
+from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
+                    Type, Union)
 from unittest.mock import patch
 
 import numpy as np
@@ -51,6 +52,8 @@ class ModelInputForTPU(ModelRunnerInputBase):
     num_samples: int
     best_of: List[int]
     seq_groups: List[List[int]]
+    virtual_engine: int = 0
+    async_callback: Optional[Callable] = None
 
     def as_broadcastable_tensor_dict(
             self) -> Dict[str, Union[int, torch.Tensor]]:
@@ -520,27 +523,19 @@ def execute_model(
             raise ValueError(
                 "TPUModelRunner does not support multi-step execution.")
 
-        def _execute_model(*args, clone: bool = False) -> torch.Tensor:
+        def _execute_model(*args):
             """Move input args from CPU to device and execute the model."""
 
-            def _copy_to_device(x: torch.Tensor) -> torch.Tensor:
-                if clone:
-                    # When x is a slice of a CPU tensor, XLA may copy the whole
-                    # original tensor to TPU instead of only copying x.
-                    # To avoid this, we copy x after cloning.
-                    x = x.clone()
-                return x.to(self.device)
-
             new_args = []
             for arg in args:
                 if isinstance(arg, torch.Tensor):
-                    arg = _copy_to_device(arg)
+                    arg = arg.to(self.device)
                 elif isinstance(arg, AttentionMetadata):
-                    arg.slot_mapping = _copy_to_device(arg.slot_mapping)
+                    arg.slot_mapping = arg.slot_mapping.to(self.device)
                     if getattr(arg, "block_tables", None) is not None:
-                        arg.block_tables = _copy_to_device(arg.block_tables)
+                        arg.block_tables = arg.block_tables.to(self.device)
                     if getattr(arg, "context_lens", None) is not None:
-                        arg.context_lens = _copy_to_device(arg.context_lens)
+                        arg.context_lens = arg.context_lens.to(self.device)
                 new_args.append(arg)
             return self.model(*new_args, is_prompt=is_prompt)
 
@@ -567,13 +562,11 @@ def _copy_to_device(x: torch.Tensor) -> torch.Tensor:
                 output_token_ids = _execute_model(
                     model_input.token_ids[None, start_idx:end_idx],
                     model_input.position_ids[None, start_idx:end_idx],
-                    model_input.attn_metadata,
-                    model_input.input_lens[i:i + 1],
-                    model_input.t[i:i + 1],
-                    model_input.p[i:i + 1],
-                    model_input.num_samples,
-                    kv_caches,
-                    clone=True)
+                    model_input.attn_metadata, model_input.input_lens[i:i + 1],
+                    model_input.t[i:i + 1], model_input.p[i:i + 1],
+                    model_input.num_samples, kv_caches)
+                if i == 0 and model_input.async_callback is not None:
+                    model_input.async_callback()
                 # Retrieve the outputs to CPU.
                 next_token_ids += output_token_ids.cpu().tolist()
                 start_idx = end_idx
@@ -584,6 +577,8 @@ def _copy_to_device(x: torch.Tensor) -> torch.Tensor:
                 model_input.attn_metadata, model_input.input_lens,
                 model_input.t, model_input.p, model_input.num_samples,
                 kv_caches)
+            if model_input.async_callback is not None:
+                model_input.async_callback()
             # Retrieve the outputs to CPU.
             next_token_ids = output_token_ids.cpu().tolist()
 
@@ -626,6 +621,7 @@ def __init__(self, model: nn.Module):
                                               fullgraph=True,
                                               dynamic=False)
         super().__init__(compiled_callable)
+
     def __call__(self, *args, is_prompt: bool, **kwargs):
         if len(self.compiled_codes) < 3 or not self.use_custom_dispatcher:
             # not fully compiled yet, or not using the custom dispatcher,
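
For context on what the new `async_callback` hook buys: the runner invokes the callback right after dispatching the model step to the TPU and before the blocking `output_token_ids.cpu()` retrieval, so host-side output processing for the previous step can overlap with device execution of the current one. Below is a minimal, self-contained sketch of that overlap pattern (not part of the patch); `StepInput`, `FakeDevice`, and `execute_step` are hypothetical stand-ins for illustration, not Aphrodite APIs.

import threading
import time
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class StepInput:
    token_ids: List[int]
    # Mirrors ModelInputForTPU.async_callback in the diff: set by the engine
    # when async output processing is enabled, otherwise left as None.
    async_callback: Optional[Callable[[], None]] = None


class FakeDevice:
    """Stand-in for the TPU: launches work and lets the caller overlap CPU work."""

    def run(self, step: StepInput) -> threading.Thread:
        # Pretend device compute: a short-lived background thread.
        t = threading.Thread(target=time.sleep, args=(0.05,))
        t.start()
        return t


def execute_step(device: FakeDevice, step: StepInput) -> None:
    handle = device.run(step)            # launch device work (non-blocking)
    if step.async_callback is not None:  # same guard the diff adds
        step.async_callback()            # process the previous step's outputs now
    handle.join()                        # only then block on the device result


if __name__ == "__main__":
    processed = []
    step = StepInput(token_ids=[1, 2, 3],
                     async_callback=lambda: processed.append("prev step done"))
    execute_step(FakeDevice(), step)
    print(processed)  # ["prev step done"]: callback ran while the device was busy

The key property is that the callback runs between dispatch and the blocking retrieval, mirroring where the diff places `model_input.async_callback()` relative to `output_token_ids.cpu()`.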