2024-10-01 nightly release (34e2472)

pytorchbot committed Oct 1, 2024
1 parent e268512 commit c22240b
Showing 13 changed files with 789 additions and 37 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build_dynamic_embedding_wheels.yml
@@ -20,7 +20,7 @@ jobs:
fail-fast: false
matrix:
os: [ ubuntu-latest ]
pyver: [ cp38, cp39, cp310 ]
pyver: [ cp39, cp310 ]
cuver: [ "11.8" ]

steps:
10 changes: 5 additions & 5 deletions docs/source/_static/css/custom.css
@@ -19,11 +19,11 @@
--sd-color-card-text: inherit;
--sd-color-card-header: transparent;
--sd-color-card-footer: transparent;
--sd-color-tabs-label-active: hsla(231, 99%, 66%, 1);
--sd-color-tabs-label-hover: hsla(231, 99%, 66%, 1);
--sd-color-tabs-label-inactive: hsl(0, 0%, 66%);
--sd-color-tabs-underline-active: hsla(231, 99%, 66%, 1);
--sd-color-tabs-underline-hover: rgba(178, 206, 245, 0.62);
--sd-color-tabs-label-active: #ee4c2c;
--sd-color-tabs-label-hover: #ee4c2c;
--sd-color-tabs-label-inactive: #6c6c6d;
--sd-color-tabs-underline-active: #ee4c2c;
--sd-color-tabs-underline-hover: #fabdbd;
--sd-color-tabs-underline-inactive: transparent;
--sd-color-tabs-overline: rgb(222, 222, 222);
--sd-color-tabs-underline: rgb(222, 222, 222);
43 changes: 31 additions & 12 deletions torchrec/distributed/batched_embedding_kernel.py
@@ -43,7 +43,7 @@
)
from fbgemm_gpu.tbe.ssd import ASSOC, SSDTableBatchedEmbeddingBags
from torch import nn
from torchrec.distributed.comm import get_local_rank
from torchrec.distributed.comm import get_local_rank, get_local_size
from torchrec.distributed.composable.table_batched_embedding_slice import (
TableBatchedEmbeddingSlice,
)
@@ -215,29 +215,33 @@ def get_optimizer_rowwise_shard_metadata_and_global_metadata(
table_global_metadata: ShardedTensorMetadata,
optimizer_state: torch.Tensor,
sharding_dim: int,
is_grid_sharded: bool = False,
) -> Tuple[Dict[ShardMetadata, ShardMetadata], ShardedTensorMetadata]:

table_global_shards_metadata: List[ShardMetadata] = (
table_global_metadata.shards_metadata
)

# column-wise sharding
# sort the metadata based on column offset and
# we construct the momentum tensor in row-wise sharded way
if sharding_dim == 1:
# column-wise sharding
# sort the metadata based on column offset and
# we construct the momentum tensor in row-wise sharded way
table_global_shards_metadata = sorted(
table_global_shards_metadata,
key=lambda shard: shard.shard_offsets[1],
)

table_shard_metadata_to_optimizer_shard_metadata = {}

rolling_offset = 0
for idx, table_shard_metadata in enumerate(table_global_shards_metadata):
offset = table_shard_metadata.shard_offsets[0]
# for column-wise sharding, we still create row-wise sharded metadata for optimizer
# manually create a row-wise offset

if sharding_dim == 1:
if is_grid_sharded:
# use a rolling offset to calculate the current shard offset, accounting for the uneven row-wise case across shards
offset = rolling_offset
rolling_offset += table_shard_metadata.shard_sizes[0]
elif sharding_dim == 1:
# for column-wise sharding, we still create row-wise sharded metadata for optimizer
# manually create a row-wise offset
offset = idx * table_shard_metadata.shard_sizes[0]

table_shard_metadata_to_optimizer_shard_metadata[
@@ -255,14 +259,22 @@
)
len_rw_shards = (
len(table_shard_metadata_to_optimizer_shard_metadata)
if sharding_dim == 1
if sharding_dim == 1 and not is_grid_sharded
else 1
)
# for grid sharding, the row dimension is replicated CW shard times
grid_shard_nodes = (
len(table_global_shards_metadata) // get_local_size()
if is_grid_sharded
else 1
)
rowwise_optimizer_st_metadata = ShardedTensorMetadata(
shards_metadata=list(
table_shard_metadata_to_optimizer_shard_metadata.values()
),
size=torch.Size([table_global_metadata.size[0] * len_rw_shards]),
size=torch.Size(
[table_global_metadata.size[0] * len_rw_shards * grid_shard_nodes]
),
tensor_properties=tensor_properties,
)

@@ -324,7 +336,6 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(

all_optimizer_states = emb_module.get_optimizer_state()
optimizer_states_keys_by_table: Dict[str, List[torch.Tensor]] = {}

for (
table_config,
optimizer_states,
@@ -408,6 +419,13 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
1 if table_config.local_cols != table_config.embedding_dim else 0
)

is_grid_sharded: bool = (
True
if table_config.local_cols != table_config.embedding_dim
and table_config.local_rows != table_config.num_embeddings
else False
)

if all(
opt_state is not None for opt_state in shard_params.optimizer_states
):
@@ -431,6 +449,7 @@ def get_sharded_optim_state(momentum_idx: int) -> ShardedTensor:
table_config.global_metadata,
shard_params.optimizer_states[0][momentum_idx - 1],
sharding_dim,
is_grid_sharded,
)
else:
(
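The change above derives row-wise optimizer shard offsets for grid-sharded tables with a rolling offset and scales the global row dimension by the number of grid shard node groups. The following is an illustrative sketch only, not part of the commit; the shard sizes, world size, and table rows are hypothetical.

    from typing import List

    def grid_rowwise_offsets(shard_row_sizes: List[int]) -> List[int]:
        # Each row-wise optimizer shard starts where the previous one ended,
        # which naturally handles uneven row-wise shards.
        offsets, rolling = [], 0
        for rows in shard_row_sizes:
            offsets.append(rolling)
            rolling += rows
        return offsets

    # Hypothetical uneven grid shards on one node: 100, 100 and 56 rows.
    print(grid_rowwise_offsets([100, 100, 56]))  # [0, 100, 200]

    # For grid sharding the table's row dimension is replicated once per group
    # of column-wise shards, e.g. 6 global shards with a local world size of 2:
    table_rows, num_global_shards, local_world_size = 256, 6, 2
    grid_shard_nodes = num_global_shards // local_world_size  # 3
    global_optimizer_rows = table_rows * grid_shard_nodes     # 768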
8 changes: 8 additions & 0 deletions torchrec/distributed/embeddingbag.py
@@ -47,6 +47,7 @@
)
from torchrec.distributed.sharding.cw_sharding import CwPooledEmbeddingSharding
from torchrec.distributed.sharding.dp_sharding import DpPooledEmbeddingSharding
from torchrec.distributed.sharding.grid_sharding import GridPooledEmbeddingSharding
from torchrec.distributed.sharding.rw_sharding import RwPooledEmbeddingSharding
from torchrec.distributed.sharding.tw_sharding import TwPooledEmbeddingSharding
from torchrec.distributed.sharding.twcw_sharding import TwCwPooledEmbeddingSharding
@@ -193,6 +194,13 @@ def create_embedding_bag_sharding(
permute_embeddings=permute_embeddings,
qcomm_codecs_registry=qcomm_codecs_registry,
)
elif sharding_type == ShardingType.GRID_SHARD.value:
return GridPooledEmbeddingSharding(
sharding_infos,
env,
device,
qcomm_codecs_registry=qcomm_codecs_registry,
)
else:
raise ValueError(f"Sharding type not supported {sharding_type}")

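The new GRID_SHARD branch above routes grid-sharded embedding bags to GridPooledEmbeddingSharding. As a hedged sketch (not part of the commit), one way this path could be requested is through per-table planner constraints; the table name "large_table" is hypothetical, and the test change further below notes that planner enumeration of grid sharding was still pending at this commit.

    from torchrec.distributed.planner.types import ParameterConstraints
    from torchrec.distributed.types import ShardingType

    # Pin one (hypothetical) table to grid sharding so that
    # create_embedding_bag_sharding dispatches to GridPooledEmbeddingSharding.
    constraints = {
        "large_table": ParameterConstraints(
            sharding_types=[ShardingType.GRID_SHARD.value],
        ),
    }
    # constraints can then be passed to EmbeddingShardingPlanner(constraints=constraints).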
2 changes: 1 addition & 1 deletion torchrec/distributed/planner/partitioners.py
@@ -150,7 +150,7 @@ def __lt__(self, other: "OrderedDeviceHardware") -> bool:


class GreedyPerfPartitioner(Partitioner):
"""Greedy Partitioner
"""Greedy Partitioner.
Args:
sort_by (SortBy): Sort sharding options by storage or perf in
41 changes: 41 additions & 0 deletions torchrec/distributed/planner/planners.py
@@ -141,6 +141,28 @@ class EmbeddingShardingPlanner(ShardingPlanner):
"""
Provides an optimized sharding plan for a given module with shardable parameters
according to the provided sharders, topology, and constraints.
Args:
topology (Optional[Topology]): the topology of the current process group.
batch_size (Optional[int]): the batch size of the model.
enumerator (Optional[Enumerator]): the enumerator to use
storage_reservation (Optional[StorageReservation]): the storage reservation to use
proposer (Optional[Union[Proposer, List[Proposer]]]): the proposer(s) to use
partitioner (Optional[Partitioner]): the partitioner to use
performance_model (Optional[PerfModel]): the performance model to use
stats (Optional[Union[Stats, List[Stats]]]): the stats to use
constraints (Optional[Dict[str, ParameterConstraints]]): per table constraints
for sharding.
debug (bool): whether to print debug information.
Example::
ebc = EmbeddingBagCollection(tables=eb_configs, device=torch.device("meta"))
planner = EmbeddingShardingPlanner()
plan = planner.plan(
module=ebc,
sharders=[EmbeddingBagCollectionSharder()],
)
"""

def __init__(
@@ -215,6 +237,14 @@ def collective_plan(
) -> ShardingPlan:
"""
Call self.plan(...) on rank 0 and broadcast the resulting plan.
Args:
module (nn.Module): the module to shard.
sharders (Optional[List[ModuleSharder[nn.Module]]]): the sharders to use for sharding
pg (Optional[dist.ProcessGroup]): the process group to use for collective operations
Returns:
ShardingPlan: the sharding plan for the module.
"""
if pg is None:
assert dist.is_initialized(), (
@@ -239,6 +269,17 @@ def plan(
module: nn.Module,
sharders: List[ModuleSharder[nn.Module]],
) -> ShardingPlan:
"""
Provides an optimized sharding plan for a given module with shardable parameters
according to the provided sharders, topology, and constraints.
Args:
module (nn.Module): the module to shard.
sharders (List[ModuleSharder[nn.Module]]): the sharders to use for sharding.
Returns:
ShardingPlan: the sharding plan for the module.
"""
self._num_proposals = 0
self._num_plans = 0
start_time = perf_counter()
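The docstrings added above cover plan() and collective_plan(). A minimal usage sketch, assuming torch.distributed is already initialized; the table config is hypothetical and not part of the commit.

    import torch
    import torch.distributed as dist
    from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
    from torchrec.distributed.planner import EmbeddingShardingPlanner
    from torchrec.modules.embedding_configs import EmbeddingBagConfig
    from torchrec.modules.embedding_modules import EmbeddingBagCollection

    # Hypothetical single-table model on the meta device.
    eb_configs = [
        EmbeddingBagConfig(
            name="t1", embedding_dim=64, num_embeddings=10_000, feature_names=["f1"]
        )
    ]
    ebc = EmbeddingBagCollection(tables=eb_configs, device=torch.device("meta"))

    planner = EmbeddingShardingPlanner()
    # Rank 0 computes the plan; the result is broadcast to every rank in the group.
    plan = planner.collective_plan(
        module=ebc,
        sharders=[EmbeddingBagCollectionSharder()],
        pg=dist.GroupMember.WORLD,
    )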
53 changes: 39 additions & 14 deletions torchrec/distributed/planner/shard_estimators.py
@@ -70,7 +70,13 @@ def _is_prefetch_pipelined(

class EmbeddingPerfEstimator(ShardEstimator):
"""
Embedding Wall Time Perf Estimator
Embedding Wall Time Perf Estimator. This estimator estimates the wall time
of a given sharding option.
Args:
topology (Topology): device topology.
constraints (Optional[Dict[str, ParameterConstraints]]): parameter constraints.
is_inference (bool): whether or not the estimator is used for inference.
"""

def __init__(
@@ -88,6 +94,13 @@ def estimate(
sharding_options: List[ShardingOption],
sharder_map: Optional[Dict[str, ModuleSharder[nn.Module]]] = None,
) -> None:
"""
Estimates the wall time of a given sharding option.
Args:
sharding_options (List[ShardingOption]): list of sharding options.
sharder_map (Optional[Dict[str, ModuleSharder[nn.Module]]]): sharder map.
"""
if not sharder_map:
assert not sharding_options, "sharder_map not provided for sharding_options"
return
@@ -298,6 +311,7 @@ def perf_func_emb_wall_time(
of device.
prefetch_pipeline (bool = False): whether prefetch pipeline is enabled.
expected_cache_fetches (float): number of expected cache fetches across global batch
uneven_sharding_perf_multiplier (float = 1.0): multiplier to account for uneven sharding perf
Returns:
List[float]: the list of perf for each shard.
@@ -870,19 +884,22 @@ class EmbeddingStorageEstimator(ShardEstimator):
Embedding Storage Usage Estimator
Args:
pipeline_type: The type of pipeline, if any. Will determine the input replication
factor during memory estimation.
run_embedding_at_peak_memory: If the embedding fwd/bwd will be execute when HBM
usage is at peak. When set to TRUE, any temporary memory allocation during
embedding forward/backward, as long as output sizes before output_dist will
be counted towards HBM storage cost. Otherwise they won't since they'll be
"hidden" by the real memory peak.
Only take effect if pipeline_type is set for backward compatibility (not affecting
models using old pipeline-agnostic formula)
Default to FALSE because this is typically FALSE for a RecSys since memory
peak happens at the end of dense forwrad / beginning of dense backward instead.
topology (Topology): device topology.
constraints (Optional[Dict[str, ParameterConstraints]]): parameter constraints.
pipeline_type (PipelineType): The type of pipeline, if any. Will determine the
input replication factor during memory estimation.
run_embedding_at_peak_memory (bool): If the embedding fwd/bwd will be executed when HBM
usage is at peak. When set to TRUE, any temporary memory allocation during
embedding forward/backward, as well as output sizes before output_dist, will
be counted towards HBM storage cost. Otherwise they won't, since they'll be
"hidden" by the real memory peak.
Only takes effect if pipeline_type is set, for backward compatibility (does not affect
models using the old pipeline-agnostic formula).
Defaults to False because this is typically False for RecSys, since the memory
peak happens at the end of dense forward / beginning of dense backward instead.
is_inference (bool): Whether the model is an inference model. Defaults to False.
"""

def __init__(
@@ -904,6 +921,14 @@ def estimate(
sharding_options: List[ShardingOption],
sharder_map: Optional[Dict[str, ModuleSharder[nn.Module]]] = None,
) -> None:
"""
Estimate the storage cost of each sharding option.
Args:
sharding_options (List[ShardingOption]): list of sharding options.
sharder_map (Optional[Dict[str, ModuleSharder[nn.Module]]]): map from module
type to sharder.
"""
if not sharder_map:
assert not sharding_options, "sharder_map not provided for sharding_options"
return
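The docstrings added above describe how EmbeddingPerfEstimator and EmbeddingStorageEstimator are constructed and invoked. A minimal wiring sketch under an assumed two-GPU topology and batch size (not part of the commit):

    from torchrec.distributed.planner.enumerators import EmbeddingEnumerator
    from torchrec.distributed.planner.planners import EmbeddingShardingPlanner
    from torchrec.distributed.planner.shard_estimators import (
        EmbeddingPerfEstimator,
        EmbeddingStorageEstimator,
    )
    from torchrec.distributed.planner.types import Topology

    topology = Topology(world_size=2, compute_device="cuda")  # assumed topology
    batch_size = 512                                          # assumed batch size

    enumerator = EmbeddingEnumerator(
        topology=topology,
        batch_size=batch_size,
        estimator=[
            EmbeddingPerfEstimator(topology=topology),     # wall-time estimates
            EmbeddingStorageEstimator(topology=topology),  # HBM/DDR storage estimates
        ],
    )
    planner = EmbeddingShardingPlanner(
        topology=topology, batch_size=batch_size, enumerator=enumerator
    )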
3 changes: 2 additions & 1 deletion torchrec/distributed/planner/tests/test_proposers.py
@@ -351,7 +351,8 @@ def test_grid_search_three_table(self) -> None:
So the total number of pruned options will be:
(num_sharding_types - 1) * 3 + 1 = 16
"""
num_pruned_options = (len(ShardingType) - 1) * 3 + 1
# NOTE: remove the -2 adjustment to the sharding type count once grid sharding is added to the planner
num_pruned_options = (len(ShardingType) - 2) * 3 + 1
self.grid_search_proposer.load(search_space)
for (
sharding_options
31 changes: 28 additions & 3 deletions torchrec/distributed/planner/types.py
@@ -720,7 +720,15 @@ def load(
self,
search_space: List[ShardingOption],
enumerator: Optional[Enumerator] = None,
) -> None: ...
) -> None:
"""
Load search space into proposer.
Args:
search_space (List[ShardingOption]): search space to load.
enumerator (Enumerator): enumerator used to generate search space.
"""
...

@abc.abstractmethod
def feedback(
Expand All @@ -729,10 +737,27 @@ def feedback(
plan: Optional[List[ShardingOption]] = None,
perf_rating: Optional[float] = None,
storage_constraint: Optional[Topology] = None,
) -> None: ...
) -> None:
"""
Provide feedback to proposer.
Args:
partitionable (bool): whether the plan is partitionable.
plan (Optional[List[ShardingOption]]): plan to provide feedback on.
perf_rating (Optional[float]): performance rating of the plan.
storage_constraint (Optional[Topology]): storage constraint of the plan.
"""
...

@abc.abstractmethod
def propose(self) -> Optional[List[ShardingOption]]: ...
def propose(self) -> Optional[List[ShardingOption]]:
"""
Propose a sharding plan.
Returns:
Optional[List[ShardingOption]]: proposed plan.
"""
...


class Partitioner(abc.ABC):
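The docstrings added above spell out the Proposer load/propose/feedback contract. A minimal custom proposer sketch against that interface (not part of the commit):

    from typing import List, Optional

    from torchrec.distributed.planner.types import (
        Enumerator,
        Proposer,
        ShardingOption,
        Topology,
    )

    class OneShotProposer(Proposer):
        """Proposes the whole enumerated search space once, then stops."""

        def __init__(self) -> None:
            self._proposal: Optional[List[ShardingOption]] = None

        def load(
            self,
            search_space: List[ShardingOption],
            enumerator: Optional[Enumerator] = None,
        ) -> None:
            # Keep the full search space as a single proposal.
            self._proposal = list(search_space)

        def propose(self) -> Optional[List[ShardingOption]]:
            return self._proposal

        def feedback(
            self,
            partitionable: bool,
            plan: Optional[List[ShardingOption]] = None,
            perf_rating: Optional[float] = None,
            storage_constraint: Optional[Topology] = None,
        ) -> None:
            # After one round of feedback, stop proposing.
            self._proposal = None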
(Diffs for the remaining changed files are not shown here.)