diff --git a/axlearn/cloud/gcp/job.py b/axlearn/cloud/gcp/job.py index bdeeb609..7c394838 100644 --- a/axlearn/cloud/gcp/job.py +++ b/axlearn/cloud/gcp/job.py @@ -6,6 +6,7 @@ """ import atexit +import importlib import io import logging import math @@ -392,6 +393,7 @@ class Config(GKEJob.Config): enable_tpu_ici_resiliency: Optional[bool] = None location_hint: Optional[str] = None enable_tpu_smart_repair: bool = False + use_pathways: Optional[bool] = False @classmethod def define_flags(cls, fv: flags.FlagValues): @@ -406,6 +408,9 @@ def define_flags(cls, fv: flags.FlagValues): "not all TPU types support this flag.", **common_kwargs, ) + flags.DEFINE_boolean( + "use_pathways", False, "Wether the workload is pathways-enabled.", **common_kwargs + ) @classmethod def from_flags(cls, fv: flags.FlagValues, **kwargs) -> Config: @@ -430,6 +435,14 @@ def __init__(self, cfg: Config): super().__init__(cfg) self._gcsfuse_volume = "gcs-fuse-csi-ephemeral" self._output_volume_mount = dict(name="shared-output", mountPath="/output") + if cfg.use_pathways: + self._import_pathways() + + def _import_pathways(self): + try: + importlib.import_module("pathwaysutils") + except ModuleNotFoundError: + logging.error("An error occurred while importing pathways-utils.") def _build_container(self) -> Nested[Any]: """Builds a config for a single container. diff --git a/axlearn/common/launch.py b/axlearn/common/launch.py index 55a86aed..2966d1d3 100644 --- a/axlearn/common/launch.py +++ b/axlearn/common/launch.py @@ -103,7 +103,9 @@ def setup(): logging.info("Devices: %s", devices) local_devices = jax.local_devices() logging.info("Local Devices: %s", local_devices) - if not devices or not all(device.platform == FLAGS.jax_backend for device in devices): + if FLAGS.jax_backend != "proxy" and ( + not devices or not all(device.platform == FLAGS.jax_backend for device in devices) + ): raise RuntimeError(f"Expected backend {FLAGS.jax_backend}. Got {devices}.") if FLAGS.data_dir: # TODO(ruoming): Get rid of --data_dir and use only env var DATA_DIR. diff --git a/axlearn/common/utils_spmd.py b/axlearn/common/utils_spmd.py index 142da070..baebe413 100644 --- a/axlearn/common/utils_spmd.py +++ b/axlearn/common/utils_spmd.py @@ -53,7 +53,8 @@ def setup( if initialization_timeout is not None: init_kwargs["initialization_timeout"] = initialization_timeout - if jax_backend == "tpu": + # TPU resources orchestrated by Pathways use 'proxy' as the JAX backend + if jax_backend in ("tpu", "proxy"): if not ( distributed_coordinator is None and num_processes is None and process_id is None ): @@ -115,5 +116,6 @@ def setup( f"({initialization_timeout} seconds)." ) else: - jax.distributed.initialize(**init_kwargs) - _jax_distributed_initialized = True + if jax_backend != "proxy": + jax.distributed.initialize(**init_kwargs) + _jax_distributed_initialized = True diff --git a/pyproject.toml b/pyproject.toml index 4596fe4f..17d809ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -154,6 +154,11 @@ audio = [ "levenshtein==0.25.1", ] +# Pathways utilities. +pathways = [ + "pathwaysutils@git+https://github.com/google/pathways-utils@v0.0.5", # for JAX+Pathways single-controller accelerator coordinator +] + [tool.flit.module] # This defines the import name. https://flit.pypa.io/en/stable/pyproject_toml.html#module-section name = "axlearn"