From 1bc793f79b082e864dc815bca1214e68d437d8bc Mon Sep 17 00:00:00 2001
From: AlpinDale
Date: Mon, 23 Dec 2024 14:37:36 +0000
Subject: [PATCH] tpu: support single and multi-host TPUs on GKE and RayServe

---
 aphrodite/attention/backends/pallas.py        |  4 ++-
 .../device_communicators/tpu_communicator.py  | 23 ++++++++++++++--
 aphrodite/executor/ray_tpu_executor.py        | 15 +++++++++++
 aphrodite/executor/ray_utils.py               | 27 +++++++++++++++++++
 requirements-tpu.txt                          |  2 +-
 5 files changed, 67 insertions(+), 4 deletions(-)

diff --git a/aphrodite/attention/backends/pallas.py b/aphrodite/attention/backends/pallas.py
index 3463f8584..586b3ac03 100644
--- a/aphrodite/attention/backends/pallas.py
+++ b/aphrodite/attention/backends/pallas.py
@@ -125,7 +125,9 @@ def __init__(
             raise NotImplementedError("TPU version must be 4 or higher.")

         self.megacore_mode = None
-        tpu_type = torch_xla.tpu.get_tpu_env()["TYPE"].lower()
+        tpu_env = torch_xla.tpu.get_tpu_env()
+        tpu_type = tpu_env.get("TYPE") or tpu_env.get("ACCELERATOR_TYPE")
+        tpu_type = tpu_type.lower()
         if "lite" not in tpu_type:
             if self.num_kv_heads % 2 == 0:
                 self.megacore_mode = "kv_head"
diff --git a/aphrodite/distributed/device_communicators/tpu_communicator.py b/aphrodite/distributed/device_communicators/tpu_communicator.py
index 1e6ebe547..e67c1b983 100644
--- a/aphrodite/distributed/device_communicators/tpu_communicator.py
+++ b/aphrodite/distributed/device_communicators/tpu_communicator.py
@@ -1,3 +1,5 @@
+import os
+
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
@@ -5,11 +7,12 @@
 from aphrodite.platforms import current_platform

 if current_platform.is_tpu():
-    import ray
     import torch_xla.core.xla_model as xm
     import torch_xla.runtime as xr
     from torch_xla._internal import pjrt

+    from aphrodite.executor import ray_utils
+

 class TpuCommunicator:

@@ -24,9 +27,25 @@ def __init__(self, group: ProcessGroup):
         # size can be simply calculated as follows.
         global_rank = dist.get_rank(group)
         global_world_size = dist.get_world_size(group)
-        num_nodes = len(ray.nodes())
+        # Calculate how many TPU nodes are in the current deployment. This
+        # is the Ray placement group if it is deployed with Ray. Default
+        # to the number of TPU nodes in the Ray cluster. The number of TPU
+        # nodes is computed by the total number of TPUs divided by the
+        # number of TPU accelerators per node, to account for clusters
+        # with both CPUs and TPUs.
+        num_nodes = ray_utils.get_num_tpu_nodes()
+        num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group()
+        if num_nodes_in_pg > 0:
+            num_nodes = num_nodes_in_pg
         local_world_size = global_world_size // num_nodes
         local_rank = global_rank % local_world_size
+        # Ensure environment variables are set for multihost deployments.
+        # On GKE, this is needed for libtpu and TPU driver to know which TPU
+        # chip is actually visible. Otherwise the TPU driver will fail to
+        # initialize because the number of devices would be different from
+        # the number of visible worker addresses.
+        os.environ["CLOUD_TPU_TASK_ID"] = str(global_rank)
+        os.environ["TPU_VISIBLE_CHIPS"] = str(local_rank)

         pjrt.initialize_multiprocess(local_rank, local_world_size)
         xr._init_world_size_ordinal()
diff --git a/aphrodite/executor/ray_tpu_executor.py b/aphrodite/executor/ray_tpu_executor.py
index 3cf4c98fe..9c24ed3df 100644
--- a/aphrodite/executor/ray_tpu_executor.py
+++ b/aphrodite/executor/ray_tpu_executor.py
@@ -73,6 +73,19 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
             worker_module_name = "aphrodite.task_handler.tpu_worker"
             worker_class_name = "TPUWorker"

+            # GKE does not fetch environment information from metadata server
+            # and instead sets these from within the Ray process. Therefore we
+            # need to override the Ray environment variables manually.
+            override_env = {}
+            if "TPU_CHIPS_PER_HOST_BOUNDS" in os.environ:
+                override_env.update({
+                    "TPU_CHIPS_PER_HOST_BOUNDS":
+                    os.environ["TPU_CHIPS_PER_HOST_BOUNDS"]
+                })
+            if "TPU_HOST_BOUNDS" in os.environ:
+                override_env.update(
+                    {"TPU_HOST_BOUNDS": os.environ["TPU_HOST_BOUNDS"]})
+
             worker = ray.remote(
                 num_cpus=0,
                 resources={"TPU": 1},
@@ -83,6 +96,8 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
                 worker_class_name=worker_class_name,
                 trust_remote_code=self.model_config.trust_remote_code,
             )
+            if override_env:
+                worker.override_env_vars.remote(override_env)

             worker_ip = ray.get(worker.get_node_ip.remote())
             if worker_ip == driver_ip and self.driver_dummy_worker is None:
diff --git a/aphrodite/executor/ray_utils.py b/aphrodite/executor/ray_utils.py
index 4dd5881bc..4431921d2 100644
--- a/aphrodite/executor/ray_utils.py
+++ b/aphrodite/executor/ray_utils.py
@@ -1,3 +1,4 @@
+import os
 import time
 from collections import defaultdict
 from typing import Dict, List, Optional, Tuple, Union
@@ -80,6 +81,9 @@ def execute_model_spmd(
             output = self.output_encoder.encode(output)
             return output

+        def override_env_vars(self, vars: Dict[str, str]):
+            os.environ.update(vars)
+
     ray_import_err = None

 except ImportError as e:
@@ -139,6 +143,7 @@ def _verify_bundles(placement_group: "PlacementGroup",
             "sure you have more than "
             f"than {parallel_config.tensor_parallel_size} GPUs available "
             "at each node.")
+
 def _wait_until_pg_ready(current_placement_group: "PlacementGroup"):
     """Wait until a placement group is ready.
     It prints the informative log messages if the placement group is
@@ -271,3 +276,25 @@ def initialize_ray_cluster(
     _verify_bundles(current_placement_group, parallel_config, device_str)
     # Set the placement group in the parallel config
     parallel_config.placement_group = current_placement_group
+
+
+def get_num_tpu_nodes() -> int:
+    from ray._private.accelerators import TPUAcceleratorManager
+    cluster_resources = ray.cluster_resources()
+    total_tpus = int(cluster_resources["TPU"])
+    tpus_per_node = TPUAcceleratorManager.get_current_node_num_accelerators()
+    assert total_tpus % tpus_per_node == 0
+    return total_tpus // tpus_per_node
+
+def get_num_nodes_in_placement_group() -> int:
+    pg_table = ray.util.placement_group_table()
+    current_pg = ray.util.get_current_placement_group()
+    num_nodes = 0
+    if current_pg:
+        nodes_in_pg = set()
+        for pg_key, pg in pg_table.items():
+            if pg_key == current_pg.id.hex():
+                for _, node in pg["bundles_to_node_id"].items():
+                    nodes_in_pg.add(node)
+        num_nodes = len(nodes_in_pg)
+    return num_nodes
diff --git a/requirements-tpu.txt b/requirements-tpu.txt
index 45914fef8..b57210754 100644
--- a/requirements-tpu.txt
+++ b/requirements-tpu.txt
@@ -5,4 +5,4 @@
 # Dependencies for TPU
 # Currently, the TPU backend uses a nightly version of PyTorch XLA.
 # You can install the dependencies in Dockerfile.tpu.
-ray
\ No newline at end of file
+ray[default]
\ No newline at end of file
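
Reviewer note (not part of the patch): the sketch below walks through the local-rank arithmetic that TpuCommunicator.__init__ performs once get_num_tpu_nodes() or get_num_nodes_in_placement_group() has produced a node count, assuming every host contributes the same number of TPU chips. The helper name derive_local_topology and the example topology (8 workers across 2 hosts) are hypothetical and added only for illustration.

import os


def derive_local_topology(global_rank: int, global_world_size: int,
                          num_nodes: int):
    # Mirrors the patch: each host runs global_world_size // num_nodes
    # workers, and a worker's position on its host is its global rank
    # modulo that per-host count.
    local_world_size = global_world_size // num_nodes
    local_rank = global_rank % local_world_size
    return local_rank, local_world_size


if __name__ == "__main__":
    # Hypothetical deployment: 8 workers spread over 2 TPU hosts.
    local_rank, local_world_size = derive_local_topology(
        global_rank=5, global_world_size=8, num_nodes=2)
    assert (local_rank, local_world_size) == (1, 4)

    # The patch exports these two variables so that libtpu and the TPU
    # driver only see the chip that belongs to this worker.
    os.environ["CLOUD_TPU_TASK_ID"] = "5"              # global rank
    os.environ["TPU_VISIBLE_CHIPS"] = str(local_rank)  # chip index on host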