2024-10-10 nightly release (669a5eb)
pytorchbot committed Oct 10, 2024
1 parent 9509ca0 commit aa7c886
Showing 2 changed files with 100 additions and 13 deletions.
41 changes: 28 additions & 13 deletions torchrec/distributed/sharding_plan.py
@@ -11,7 +11,7 @@
 
 import math
 import warnings
-from typing import Callable, cast, Dict, List, Optional, Tuple, Type
+from typing import Callable, cast, Dict, List, Optional, Tuple, Type, Union
 
 import torch
 from torch import distributed as dist, nn
@@ -70,11 +70,12 @@ def placement(
 
 
 # TODO: Consolidate placement and placement_helper into one function.
-def placement_helper(device_type: str, index: int = 0) -> str:
+def placement_helper(device_type: str, index: int = 0, rank: int = 0) -> str:
     if device_type == "cpu":
         return f"rank:0/{device_type}"  # cpu only use rank 0
 
-    return f"rank:{index}/{device_type}:{index}"
+    result = f"rank:{rank}/{device_type}:{index}"
+    return result
 
 
 def calculate_shard_sizes_and_offsets(
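
Note: the change above gives placement_helper an explicit rank, decoupling the owning rank from the device index. A minimal standalone sketch of the updated behavior, mirroring the helper in this diff rather than importing it from torchrec:

def placement_helper(device_type: str, index: int = 0, rank: int = 0) -> str:
    # cpu shards are always expressed as rank 0 on the host
    if device_type == "cpu":
        return f"rank:0/{device_type}"
    # otherwise the owning rank and the device index may differ
    return f"rank:{rank}/{device_type}:{index}"

print(placement_helper("cpu", index=1, rank=2))   # rank:0/cpu
print(placement_helper("cuda", index=1, rank=2))  # rank:2/cuda:1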
@@ -442,15 +443,15 @@ def _parameter_sharding_generator(
             local_size,
             device_type,
             sharder,
-            placements=([placement_helper(device, rank)] if device else None),
+            placements=([placement_helper(device, rank, rank)] if device else None),
             compute_kernel=compute_kernel,
         )
 
     return _parameter_sharding_generator
 
 
 def row_wise(
-    sizes_placement: Optional[Tuple[List[int], str]] = None
+    sizes_placement: Optional[Tuple[List[int], Union[str, List[str]]]] = None
 ) -> ParameterShardingGenerator:
     """
     Returns a generator of ParameterShardingPlan for `ShardingType::ROW_WISE` for construct_module_sharding_plan.
@@ -470,6 +471,12 @@
         )
     """
 
+    if sizes_placement is not None and isinstance(sizes_placement[1], list):
+        assert len(sizes_placement[0]) == len(sizes_placement[1]), (
+            "sizes_placement and device per placement (in case of sharding "
+            "across HBM and CPU host) must have the same length"
+        )
+
     def _parameter_sharding_generator(
         param: nn.Parameter,
         local_size: int,
@@ -507,21 +514,29 @@ def _parameter_sharding_generator(
                     f"Cannot fit tensor of {rows, cols} into sizes_ranks_placements = {sizes_placement}"
                 )
 
+        index: int = 0
+        placements: List[str] = []
+        if sizes_placement is not None:
+            device_type = ""
+            for i in range(len(sizes_placement[0])):
+                if isinstance(sizes_placement[1], list):
+                    device_type = sizes_placement[1][i]
+                    placements.append(placement_helper(device_type, index, i))
+                else:
+                    device_type = str(sizes_placement[1])
+                    placements.append(placement_helper(device_type, index, i))
+
+                if device_type == "cuda":
+                    index += 1
+
         return _get_parameter_sharding(
             param,
             ShardingType.ROW_WISE.value,
             size_offset_ranks,
             local_size,
             device_type,
             sharder,
-            placements=(
-                [
-                    placement_helper(sizes_placement[1], i)
-                    for i in range(len(sizes_placement[0]))
-                ]
-                if sizes_placement
-                else None
-            ),
+            placements=placements if sizes_placement else None,
             compute_kernel=(
                 EmbeddingComputeKernel.QUANT.value if sizes_placement else None
             ),
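
The new placement loop in row_wise() is the core of this change: each shard's rank follows its position in sizes_placement, while the cuda device index only advances on cuda shards, so a leading cpu shard does not consume a GPU. A self-contained sketch of that derivation for the per-shard device-list form (hypothetical helper name; logic copied from the loop above):

from typing import List


def derive_row_wise_placements(devices: List[str]) -> List[str]:
    placements: List[str] = []
    index = 0  # next cuda device index to hand out
    for rank, device_type in enumerate(devices):
        if device_type == "cpu":
            # cpu shards always map to rank 0 on the host
            placements.append("rank:0/cpu")
        else:
            placements.append(f"rank:{rank}/{device_type}:{index}")
        if device_type == "cuda":
            index += 1
    return placements


print(derive_row_wise_placements(["cpu", "cuda", "cuda"]))
# ['rank:0/cpu', 'rank:1/cuda:0', 'rank:2/cuda:1']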
72 changes: 72 additions & 0 deletions torchrec/distributed/tests/test_sharding_plan.py
@@ -623,6 +623,78 @@ def test_table_wise_set_device(self) -> None:
             "cpu",
         )
 
+    def test_row_wise_set_heterogenous_device(self) -> None:
+        embedding_bag_config = [
+            EmbeddingBagConfig(
+                name=f"table_{idx}",
+                feature_names=[f"feature_{idx}"],
+                embedding_dim=64,
+                num_embeddings=4096,
+            )
+            for idx in range(2)
+        ]
+        module_sharding_plan = construct_module_sharding_plan(
+            EmbeddingBagCollection(tables=embedding_bag_config),
+            per_param_sharding={
+                "table_0": row_wise(
+                    sizes_placement=([2048, 1024, 1024], ["cpu", "cuda", "cuda"])
+                ),
+                "table_1": row_wise(
+                    sizes_placement=([2048, 1024, 1024], ["cpu", "cpu", "cpu"])
+                ),
+            },
+            local_size=1,
+            world_size=2,
+            device_type="cuda",
+        )
+
+        # Make sure per_param_sharding setting override the default device_type
+        device_table_0_shard_0 = (
+            # pyre-ignore[16]
+            module_sharding_plan["table_0"]
+            .sharding_spec.shards[0]
+            .placement
+        )
+        self.assertEqual(
+            device_table_0_shard_0.device().type,
+            "cpu",
+        )
+        # cpu always has rank 0
+        self.assertEqual(
+            device_table_0_shard_0.rank(),
+            0,
+        )
+        for i in range(1, 3):
+            device_table_0_shard_i = (
+                module_sharding_plan["table_0"].sharding_spec.shards[i].placement
+            )
+            self.assertEqual(
+                device_table_0_shard_i.device().type,
+                "cuda",
+            )
+            # first rank is assigned to cpu so index = rank - 1
+            self.assertEqual(
+                device_table_0_shard_i.device().index,
+                i - 1,
+            )
+            self.assertEqual(
+                device_table_0_shard_i.rank(),
+                i,
+            )
+        for i in range(3):
+            device_table_1_shard_i = (
+                module_sharding_plan["table_1"].sharding_spec.shards[i].placement
+            )
+            self.assertEqual(
+                device_table_1_shard_i.device().type,
+                "cpu",
+            )
+            # cpu always has rank 0
+            self.assertEqual(
+                device_table_1_shard_i.rank(),
+                0,
+            )
+
     def test_column_wise(self) -> None:
         embedding_bag_config = [
             EmbeddingBagConfig(
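
As a usage note, and assuming row_wise is imported from torchrec.distributed.sharding_plan as in the file above, both the pre-existing single-device form and the new per-shard list form of sizes_placement should be accepted after this change; a hedged sketch:

from torchrec.distributed.sharding_plan import row_wise

# Pre-existing form: one device string applied to every row-wise shard.
all_cpu = row_wise(sizes_placement=([2048, 1024, 1024], "cpu"))

# New form: one device per shard, mixing host memory ("cpu") and GPUs ("cuda").
mixed = row_wise(sizes_placement=([2048, 1024, 1024], ["cpu", "cuda", "cuda"]))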
