diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py
index 593dba0769b..78712640674 100644
--- a/test/spmd/test_xla_distributed_checkpoint.py
+++ b/test/spmd/test_xla_distributed_checkpoint.py
@@ -46,13 +46,13 @@ class DistributedCheckpointTestBase(test_xla_sharding_base.XlaShardingTest):
   def setUpClass(cls):
     super().setUpClass()
 
-  def _get_sharded_model(self, mesh_shape=None):
+  def _get_sharded_model(self, mesh_shape=None, pspec=(0, 1)):
     # Return a sharded SimpleLinear model with fc1.weight sharded and
     # fc2.weight explicitly replicated
     mesh_shape = mesh_shape or (1, self.n_devices)
     model = self.SimpleLinear().to(xm.xla_device())
     mesh = self._get_mesh(mesh_shape)
-    xs.mark_sharding(model.fc1.weight, mesh, (0, 1))
+    xs.mark_sharding(model.fc1.weight, mesh, pspec)
     xs.mark_sharding(model.fc2.weight, mesh, (None, None))
     return model
 
@@ -184,6 +184,20 @@ def test_resharding_transpose_device_mesh(self):
         save_planner=SPMDSavePlanner(),
         load_planner=SPMDLoadPlanner())
 
+  @unittest.skipIf(xr.global_runtime_device_count() == 1,
+                   "Multiple devices needed to change mesh")
+  def test_in_host_partial_replication(self):
+    dim = self.n_devices // 2
+    model1 = self._get_sharded_model(
+        mesh_shape=(dim, self.n_devices // dim), pspec=(None, 0))
+    model2 = self._get_sharded_model(
+        mesh_shape=(self.n_devices // dim, dim), pspec=(None, 0))
+    self._save_and_restore(
+        model1,
+        model2,
+        save_planner=SPMDSavePlanner(),
+        load_planner=SPMDLoadPlanner())
+
   @unittest.skipIf(xr.global_runtime_device_count() == 1,
                    "Multiple devices needed to change mesh")
   def test_padded_tensor(self):
diff --git a/torch_xla/experimental/distributed_checkpoint/_helpers.py b/torch_xla/experimental/distributed_checkpoint/_helpers.py
index 16a4b2181ee..52c91a9b68b 100644
--- a/torch_xla/experimental/distributed_checkpoint/_helpers.py
+++ b/torch_xla/experimental/distributed_checkpoint/_helpers.py
@@ -151,38 +151,6 @@ def flat_copy(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
   return flattened, mappings
 
 
-# TODO(jonbolin): Take a dependency on the upstream implementation when the APIs
-# are stable.
-# https://github.com/pytorch/pytorch/blob/d1cecd9c32ba700c27f2b0716bf2cbef41469495/torch/distributed/checkpoint/_dedup_tensors.py#L29
-def dedup_tensors(all_plans: List[SavePlan]) -> List[SavePlan]:
-  all_plans = list(all_plans)
-  key_to_plan: Dict[MetadataIndex, List[int]] = {}
-  for plan_idx, plan in enumerate(all_plans):
-    for write_item in plan.items:
-      key_to_plan.setdefault(write_item.index, []).append(plan_idx)
-
-  replicated_items = {k: v for k, v in key_to_plan.items() if len(v) > 1}
-
-  # Remove duplicates by always keeping the first entry.
-  # Compute the per-rank remove set.
-  plan_to_keys: Dict[int, List[MetadataIndex]] = {}
-  for key, plans in replicated_items.items():
-    for plan_idx in plans[1:]:
-      plan_to_keys.setdefault(plan_idx, []).append(key)
-
-  for plan_idx, keys in plan_to_keys.items():
-    key_set = set(keys)
-    # rewrite items and remove elements
-    new_items = [
-        write_item for write_item in all_plans[plan_idx].items
-        if write_item.index not in key_set
-    ]
-    all_plans[plan_idx] = dataclasses.replace(
-        all_plans[plan_idx], items=new_items)
-
-  return all_plans
-
-
 # TODO(jonbolin): Take a dependency on the upstream implementation when the APIs
 # are stable
 # https://github.com/pytorch/pytorch/blob/d1cecd9c32ba700c27f2b0716bf2cbef41469495/torch/distributed/_shard/_utils.py#L7
diff --git a/torch_xla/experimental/distributed_checkpoint/planners.py b/torch_xla/experimental/distributed_checkpoint/planners.py
index 32fe987a97d..8d43194119e 100644
--- a/torch_xla/experimental/distributed_checkpoint/planners.py
+++ b/torch_xla/experimental/distributed_checkpoint/planners.py
@@ -36,9 +36,10 @@
 from torch.utils._pytree import tree_map
 from torch_xla.distributed.spmd import XLAShardedTensor, XLAShard
 from torch_xla.experimental.distributed_checkpoint._helpers import (
-    FLATTEN_MAPPING, flatten_state_dict, dedup_tensors, _is_sharded_tensor,
-    set_element, narrow_tensor_by_index, _unwrap_xla_sharded_tensor, _CpuShards)
+    FLATTEN_MAPPING, flatten_state_dict, _is_sharded_tensor, set_element,
+    narrow_tensor_by_index, _unwrap_xla_sharded_tensor, _CpuShards)
 from typing import Any, Dict, List, Tuple, Union
+from torch.distributed.checkpoint._dedup_save_plans import dedup_save_plans
 
 
 class SPMDSavePlanner(SavePlanner):
@@ -107,7 +108,7 @@ def create_local_plan(self) -> SavePlan:
   def create_global_plan(
       self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]:
     # Deduplicate write items across plans
-    all_plans = dedup_tensors(all_plans)
+    all_plans = dedup_save_plans(all_plans)
     global_plan, metadata = create_default_global_save_plan(
         all_plans, rewrite_index_hints=False)
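
Note (illustration only, not part of the patch): the hunks above drop the local dedup_tensors helper in favor of the upstream dedup_save_plans and let the test request a custom partition spec. A hedged sketch of how the affected planners are typically wired into torch.distributed.checkpoint follows; the checkpoint path and the sharded `model` object are assumptions made for the sketch, not taken from the diff.

# Hedged usage sketch. CHECKPOINT_DIR and `model` are assumed names; the
# dist_cp.save/load entry points and FileSystemWriter/Reader are the standard
# torch.distributed.checkpoint APIs.
import torch.distributed.checkpoint as dist_cp
from torch_xla.experimental.distributed_checkpoint import (SPMDLoadPlanner,
                                                            SPMDSavePlanner)

CHECKPOINT_DIR = "/tmp/spmd_ckpt"  # assumed path

# Save: SPMDSavePlanner resolves XLAShardedTensors into local shards, and
# dedup_save_plans (used in create_global_plan above) drops duplicate writes.
state_dict = model.state_dict()
dist_cp.save(
    state_dict=state_dict,
    storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
    planner=SPMDSavePlanner(),
)

# Load: SPMDLoadPlanner restores the shards in place into state_dict.
dist_cp.load(
    state_dict=state_dict,
    storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
    planner=SPMDLoadPlanner(),
)
model.load_state_dict(state_dict)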