tpu: support single and multi-host TPUs on GKE and RayServe #970

Merged: 1 commit, Dec 23, 2024

aphrodite/attention/backends/pallas.py (4 changes: 3 additions & 1 deletion)
@@ -125,7 +125,9 @@ def __init__(
             raise NotImplementedError("TPU version must be 4 or higher.")

         self.megacore_mode = None
-        tpu_type = torch_xla.tpu.get_tpu_env()["TYPE"].lower()
+        tpu_env = torch_xla.tpu.get_tpu_env()
+        tpu_type = tpu_env.get("TYPE") or tpu_env.get("ACCELERATOR_TYPE")
+        tpu_type = tpu_type.lower()
         if "lite" not in tpu_type:
             if self.num_kv_heads % 2 == 0:
                 self.megacore_mode = "kv_head"
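A small, self-contained sketch of the new lookup (illustrative only, not part of the diff; the env contents are assumed GKE-style metadata where only an ACCELERATOR_TYPE key is present):

```python
# Hypothetical TPU env as GKE might expose it: no "TYPE" key, only an
# "ACCELERATOR_TYPE" entry (the value below is an assumption, not from the PR).
tpu_env = {"ACCELERATOR_TYPE": "v5litepod-16"}

tpu_type = (tpu_env.get("TYPE") or tpu_env.get("ACCELERATOR_TYPE")).lower()
megacore_mode = None
if "lite" not in tpu_type:
    megacore_mode = "kv_head"  # simplified: the real code also checks num_kv_heads

print(tpu_type, megacore_mode)  # v5litepod-16 None  (lite chips have no megacore)
```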
aphrodite/distributed/device_communicators/tpu_communicator.py (23 changes: 21 additions & 2 deletions)
@@ -1,15 +1,18 @@
+import os
+
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup

 from aphrodite.platforms import current_platform

 if current_platform.is_tpu():
-    import ray
     import torch_xla.core.xla_model as xm
     import torch_xla.runtime as xr
     from torch_xla._internal import pjrt
+
+    from aphrodite.executor import ray_utils


 class TpuCommunicator:

@@ -24,9 +27,25 @@ def __init__(self, group: ProcessGroup):
         # size can be simply calculated as follows.
         global_rank = dist.get_rank(group)
         global_world_size = dist.get_world_size(group)
-        num_nodes = len(ray.nodes())
+        # Calculate how many TPU nodes are in the current deployment. This
+        # is the Ray placement group if it is deployed with Ray. Default
+        # to the number of TPU nodes in the Ray cluster. The number of TPU
+        # nodes is computed by the total number of TPUs divided by the
+        # number of TPU accelerators per node, to account for clusters
+        # with both CPUs and TPUs.
+        num_nodes = ray_utils.get_num_tpu_nodes()
+        num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group()
+        if num_nodes_in_pg > 0:
+            num_nodes = num_nodes_in_pg
         local_world_size = global_world_size // num_nodes
         local_rank = global_rank % local_world_size
+        # Ensure environment variables are set for multihost deployments.
+        # On GKE, this is needed for libtpu and TPU driver to know which TPU
+        # chip is actually visible. Otherwise the TPU driver will fail to
+        # initialize because the number of devices would be different from
+        # the number of visible worker addresses.
+        os.environ["CLOUD_TPU_TASK_ID"] = str(global_rank)
+        os.environ["TPU_VISIBLE_CHIPS"] = str(local_rank)
         pjrt.initialize_multiprocess(local_rank, local_world_size)
         xr._init_world_size_ordinal()

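For reference, the rank arithmetic above plays out as follows in a hypothetical two-host deployment (the numbers are illustrative, not taken from this PR):

```python
# Assumed deployment: 2 TPU hosts, 4 chips per host, one Ray worker per chip.
global_world_size = 8
num_nodes = 2  # from get_num_tpu_nodes() or the placement group
local_world_size = global_world_size // num_nodes  # 4 chips visible per host

for global_rank in range(global_world_size):
    local_rank = global_rank % local_world_size
    # Each worker would then set CLOUD_TPU_TASK_ID=<global_rank> and
    # TPU_VISIBLE_CHIPS=<local_rank> before initializing PJRT.
    print(f"global rank {global_rank} -> host {global_rank // local_world_size}, chip {local_rank}")
```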
aphrodite/executor/ray_tpu_executor.py (15 changes: 15 additions & 0 deletions)
@@ -73,6 +73,19 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
             worker_module_name = "aphrodite.task_handler.tpu_worker"
             worker_class_name = "TPUWorker"

+            # GKE does not fetch environment information from metadata server
+            # and instead sets these from within the Ray process. Therefore we
+            # need to override the Ray environment variables manually.
+            override_env = {}
+            if "TPU_CHIPS_PER_HOST_BOUNDS" in os.environ:
+                override_env.update({
+                    "TPU_CHIPS_PER_HOST_BOUNDS":
+                    os.environ["TPU_CHIPS_PER_HOST_BOUNDS"]
+                })
+            if "TPU_HOST_BOUNDS" in os.environ:
+                override_env.update(
+                    {"TPU_HOST_BOUNDS": os.environ["TPU_HOST_BOUNDS"]})
+
             worker = ray.remote(
                 num_cpus=0,
                 resources={"TPU": 1},
@@ -83,6 +96,8 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
                 worker_class_name=worker_class_name,
                 trust_remote_code=self.model_config.trust_remote_code,
             )
+            if override_env:
+                worker.override_env_vars.remote(override_env)

             worker_ip = ray.get(worker.get_node_ip.remote())
             if worker_ip == driver_ip and self.driver_dummy_worker is None:
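A rough sketch of what ends up being forwarded to each worker; the topology strings below are assumptions for a small multi-host slice, not values taken from this PR:

```python
import os

# Assumed GKE-injected values for an illustrative two-host slice.
os.environ.setdefault("TPU_CHIPS_PER_HOST_BOUNDS", "2,2,1")
os.environ.setdefault("TPU_HOST_BOUNDS", "1,1,2")

# Mirror of the logic above: collect whichever bounds variables are present
# so they can be pushed to each Ray worker via override_env_vars().
override_env = {
    key: os.environ[key]
    for key in ("TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS")
    if key in os.environ
}
print(override_env)
# On the worker side, override_env_vars() simply does os.environ.update(override_env).
```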
aphrodite/executor/ray_utils.py (27 changes: 27 additions & 0 deletions)
@@ -1,3 +1,4 @@
+import os
 import time
 from collections import defaultdict
 from typing import Dict, List, Optional, Tuple, Union
@@ -80,6 +81,9 @@ def execute_model_spmd(
             output = self.output_encoder.encode(output)
             return output

+        def override_env_vars(self, vars: Dict[str, str]):
+            os.environ.update(vars)
+
     ray_import_err = None

 except ImportError as e:
@@ -139,6 +143,7 @@ def _verify_bundles(placement_group: "PlacementGroup",
             "sure you have more than "
             f"than {parallel_config.tensor_parallel_size} GPUs available "
             "at each node.")
+
 def _wait_until_pg_ready(current_placement_group: "PlacementGroup"):
     """Wait until a placement group is ready.
     It prints the informative log messages if the placement group is
@@ -271,3 +276,25 @@ def initialize_ray_cluster(
     _verify_bundles(current_placement_group, parallel_config, device_str)
     # Set the placement group in the parallel config
     parallel_config.placement_group = current_placement_group
+
+
+def get_num_tpu_nodes() -> int:
+    from ray._private.accelerators import TPUAcceleratorManager
+    cluster_resources = ray.cluster_resources()
+    total_tpus = int(cluster_resources["TPU"])
+    tpus_per_node = TPUAcceleratorManager.get_current_node_num_accelerators()
+    assert total_tpus % tpus_per_node == 0
+    return total_tpus // tpus_per_node
+
+def get_num_nodes_in_placement_group() -> int:
+    pg_table = ray.util.placement_group_table()
+    current_pg = ray.util.get_current_placement_group()
+    num_nodes = 0
+    if current_pg:
+        nodes_in_pg = set()
+        for pg_key, pg in pg_table.items():
+            if pg_key == current_pg.id.hex():
+                for _, node in pg["bundles_to_node_id"].items():
+                    nodes_in_pg.add(node)
+        num_nodes = len(nodes_in_pg)
+    return num_nodes
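A worked example of get_num_tpu_nodes() with made-up cluster numbers (not part of the diff):

```python
# Hypothetical Ray cluster state: a CPU-only head node plus two TPU hosts
# with 4 chips each. Only the "TPU" resource counts toward TPU node count.
cluster_resources = {"CPU": 64.0, "TPU": 8.0}   # assumed ray.cluster_resources()
tpus_per_node = 4                               # assumed accelerators per TPU host

total_tpus = int(cluster_resources["TPU"])
assert total_tpus % tpus_per_node == 0
print(total_tpus // tpus_per_node)  # -> 2 TPU nodes
```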
requirements-tpu.txt (2 changes: 1 addition & 1 deletion)
@@ -5,4 +5,4 @@
 # Dependencies for TPU
 # Currently, the TPU backend uses a nightly version of PyTorch XLA.
 # You can install the dependencies in Dockerfile.tpu.
-ray
+ray[default]