Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding support for Pathways proxy #690

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions axlearn/cloud/gcp/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""

import atexit
import importlib
import io
import logging
import math
Expand Down Expand Up @@ -392,6 +393,7 @@ class Config(GKEJob.Config):
enable_tpu_ici_resiliency: Optional[bool] = None
location_hint: Optional[str] = None
enable_tpu_smart_repair: bool = False
use_pathways: Optional[bool] = False

@classmethod
def define_flags(cls, fv: flags.FlagValues):
Expand All @@ -406,6 +408,9 @@ def define_flags(cls, fv: flags.FlagValues):
"not all TPU types support this flag.",
**common_kwargs,
)
# Opt-in switch for Pathways-enabled workloads; defaults to off.
flags.DEFINE_boolean(
"use_pathways", False, "Whether the workload is pathways-enabled.", **common_kwargs
)

@classmethod
def from_flags(cls, fv: flags.FlagValues, **kwargs) -> Config:
Expand All @@ -430,6 +435,14 @@ def __init__(self, cfg: Config):
super().__init__(cfg)
self._gcsfuse_volume = "gcs-fuse-csi-ephemeral"
self._output_volume_mount = dict(name="shared-output", mountPath="/output")
if cfg.use_pathways:
self._import_pathways()

def _import_pathways(self):
try:
importlib.import_module("pathwaysutils")
except ModuleNotFoundError:
logging.error("An error occurred while importing pathways-utils.")

def _build_container(self) -> Nested[Any]:
"""Builds a config for a single container.
Expand Down
4 changes: 3 additions & 1 deletion axlearn/common/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ def setup():
logging.info("Devices: %s", devices)
local_devices = jax.local_devices()
logging.info("Local Devices: %s", local_devices)
if not devices or not all(device.platform == FLAGS.jax_backend for device in devices):
if FLAGS.jax_backend != "proxy" and (
not devices or not all(device.platform == FLAGS.jax_backend for device in devices)
):
raise RuntimeError(f"Expected backend {FLAGS.jax_backend}. Got {devices}.")
if FLAGS.data_dir:
# TODO(ruoming): Get rid of --data_dir and use only env var DATA_DIR.
Expand Down
8 changes: 5 additions & 3 deletions axlearn/common/utils_spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ def setup(
if initialization_timeout is not None:
init_kwargs["initialization_timeout"] = initialization_timeout

if jax_backend == "tpu":
# TPU resources orchestrated by Pathways use 'proxy' as the JAX backend
if jax_backend in ("tpu", "proxy"):
jesus-orozco marked this conversation as resolved.
Show resolved Hide resolved
if not (
distributed_coordinator is None and num_processes is None and process_id is None
):
Expand Down Expand Up @@ -115,5 +116,6 @@ def setup(
f"({initialization_timeout} seconds)."
)
else:
jax.distributed.initialize(**init_kwargs)
_jax_distributed_initialized = True
if jax_backend != "proxy":
jax.distributed.initialize(**init_kwargs)
_jax_distributed_initialized = True
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,11 @@ audio = [
"levenshtein==0.25.1",
]

jesus-orozco marked this conversation as resolved.
Show resolved Hide resolved
# Pathways utilities.
pathways = [
"pathwaysutils@git+https://github.com/google/[email protected]", # for JAX+Pathways single-controller accelerator coordinator
]

[tool.flit.module]
# This defines the import name. https://flit.pypa.io/en/stable/pyproject_toml.html#module-section
name = "axlearn"
Expand Down