From aa7c88686ea6fcfa36c81daa01224ff48dd4979d Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Thu, 10 Oct 2024 11:34:46 +0000 Subject: [PATCH] 2024-10-10 nightly release (669a5eb7caac7b181a193ad39ab1d8e02b3f1a8b) --- torchrec/distributed/sharding_plan.py | 41 +++++++---- .../distributed/tests/test_sharding_plan.py | 72 +++++++++++++++++++ 2 files changed, 100 insertions(+), 13 deletions(-) diff --git a/torchrec/distributed/sharding_plan.py b/torchrec/distributed/sharding_plan.py index f5c245baa..6eb3468e8 100644 --- a/torchrec/distributed/sharding_plan.py +++ b/torchrec/distributed/sharding_plan.py @@ -11,7 +11,7 @@ import math import warnings -from typing import Callable, cast, Dict, List, Optional, Tuple, Type +from typing import Callable, cast, Dict, List, Optional, Tuple, Type, Union import torch from torch import distributed as dist, nn @@ -70,11 +70,12 @@ def placement( # TODO: Consolidate placement and placement_helper into one function. -def placement_helper(device_type: str, index: int = 0) -> str: +def placement_helper(device_type: str, index: int = 0, rank: int = 0) -> str: if device_type == "cpu": return f"rank:0/{device_type}" # cpu only use rank 0 - return f"rank:{index}/{device_type}:{index}" + result = f"rank:{rank}/{device_type}:{index}" + return result def calculate_shard_sizes_and_offsets( @@ -442,7 +443,7 @@ def _parameter_sharding_generator( local_size, device_type, sharder, - placements=([placement_helper(device, rank)] if device else None), + placements=([placement_helper(device, rank, rank)] if device else None), compute_kernel=compute_kernel, ) @@ -450,7 +451,7 @@ def _parameter_sharding_generator( def row_wise( - sizes_placement: Optional[Tuple[List[int], str]] = None + sizes_placement: Optional[Tuple[List[int], Union[str, List[str]]]] = None ) -> ParameterShardingGenerator: """ Returns a generator of ParameterShardingPlan for `ShardingType::ROW_WISE` for construct_module_sharding_plan. 
@@ -470,6 +471,12 @@ def row_wise(
     )
     """
 
+    if sizes_placement is not None and isinstance(sizes_placement[1], list):
+        assert len(sizes_placement[0]) == len(sizes_placement[1]), (
+            "sizes_placement and device per placement (in case of sharding "
+            "across HBM and CPU host) must have the same length"
+        )
+
     def _parameter_sharding_generator(
         param: nn.Parameter,
         local_size: int,
@@ -507,6 +514,21 @@ def _parameter_sharding_generator(
             f"Cannot fit tensor of {rows, cols} into sizes_ranks_placements = {sizes_placement}"
         )
 
+    index: int = 0
+    placements: List[str] = []
+    if sizes_placement is not None:
+        device_type = ""
+        for i in range(len(sizes_placement[0])):
+            if isinstance(sizes_placement[1], list):
+                device_type = sizes_placement[1][i]
+                placements.append(placement_helper(device_type, index, i))
+            else:
+                device_type = str(sizes_placement[1])
+                placements.append(placement_helper(device_type, index, i))
+
+            if device_type == "cuda":
+                index += 1
+
     return _get_parameter_sharding(
         param,
         ShardingType.ROW_WISE.value,
@@ -514,14 +536,7 @@ def _parameter_sharding_generator(
         local_size,
         device_type,
         sharder,
-        placements=(
-            [
-                placement_helper(sizes_placement[1], i)
-                for i in range(len(sizes_placement[0]))
-            ]
-            if sizes_placement
-            else None
-        ),
+        placements=placements if sizes_placement else None,
         compute_kernel=(
             EmbeddingComputeKernel.QUANT.value if sizes_placement else None
         ),
diff --git a/torchrec/distributed/tests/test_sharding_plan.py b/torchrec/distributed/tests/test_sharding_plan.py
index 93f24ff69..b51453391 100644
--- a/torchrec/distributed/tests/test_sharding_plan.py
+++ b/torchrec/distributed/tests/test_sharding_plan.py
@@ -623,6 +623,78 @@ def test_table_wise_set_device(self) -> None:
             "cpu",
         )
 
+    def test_row_wise_set_heterogenous_device(self) -> None:
+        embedding_bag_config = [
+            EmbeddingBagConfig(
+                name=f"table_{idx}",
+                feature_names=[f"feature_{idx}"],
+                embedding_dim=64,
+                num_embeddings=4096,
+            )
+            for idx in range(2)
+        ]
+        module_sharding_plan = construct_module_sharding_plan(
+            EmbeddingBagCollection(tables=embedding_bag_config),
+            per_param_sharding={
+                "table_0": row_wise(
+                    sizes_placement=([2048, 1024, 1024], ["cpu", "cuda", "cuda"])
+                ),
+                "table_1": row_wise(
+                    sizes_placement=([2048, 1024, 1024], ["cpu", "cpu", "cpu"])
+                ),
+            },
+            local_size=1,
+            world_size=2,
+            device_type="cuda",
+        )
+
+        # Make sure the per_param_sharding settings override the default device_type
+        device_table_0_shard_0 = (
+            # pyre-ignore[16]
+            module_sharding_plan["table_0"]
+            .sharding_spec.shards[0]
+            .placement
+        )
+        self.assertEqual(
+            device_table_0_shard_0.device().type,
+            "cpu",
+        )
+        # cpu always has rank 0
+        self.assertEqual(
+            device_table_0_shard_0.rank(),
+            0,
+        )
+        for i in range(1, 3):
+            device_table_0_shard_i = (
+                module_sharding_plan["table_0"].sharding_spec.shards[i].placement
+            )
+            self.assertEqual(
+                device_table_0_shard_i.device().type,
+                "cuda",
+            )
+            # the first rank is assigned to cpu, so device index = rank - 1
+            self.assertEqual(
+                device_table_0_shard_i.device().index,
+                i - 1,
+            )
+            self.assertEqual(
+                device_table_0_shard_i.rank(),
+                i,
+            )
+        for i in range(3):
+            device_table_1_shard_i = (
+                module_sharding_plan["table_1"].sharding_spec.shards[i].placement
+            )
+            self.assertEqual(
+                device_table_1_shard_i.device().type,
+                "cpu",
+            )
+            # cpu always has rank 0
+            self.assertEqual(
+                device_table_1_shard_i.rank(),
+                0,
+            )
+
     def test_column_wise(self) -> None:
         embedding_bag_config = [
             EmbeddingBagConfig(