update grid sharding doc strings (#2488)
Summary:
Pull Request resolved: #2488

tsia - updated docstrings to be more useful/accurate

Differential Revision: D64423503

fbshipit-source-id: 7f1a92b259fb815571d5238a9238d447543cdb02
iamzainhuda authored and facebook-github-bot committed Oct 16, 2024
1 parent afd5726 commit d317c0b
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions torchrec/distributed/sharding/grid_sharding.py
@@ -115,7 +115,8 @@ def __init__(

def _init_combined_embeddings(self) -> None:
"""
- similar to CW sharding, but this time each CW shard is on a node and not rank
+ Initializes combined embeddings, similar to the CW sharding implementation,
+ but in this case each CW shard is treated on a per-node basis rather than per rank.
"""
embedding_names = []
for grouped_embedding_configs in self._grouped_embedding_configs_per_node:
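For intuition, here is a minimal standalone sketch of the distinction this docstring draws (the two-node layout and all names below are invented for illustration and are not torchrec API): plain CW sharding maps each column shard to a single rank, while grid sharding maps it to a node and row-shards it across that node's ranks.

num_cw_shards = 2   # hypothetical: one column shard per node
local_size = 4      # hypothetical: ranks per node

# Plain CW sharding: column shard i lives entirely on rank i.
cw_placement = {shard: [shard] for shard in range(num_cw_shards)}

# Grid sharding: column shard i is assigned to node i, then split
# row-wise across all local ranks of that node.
grid_placement = {
    shard: [shard * local_size + r for r in range(local_size)]
    for shard in range(num_cw_shards)
}

print(cw_placement)    # {0: [0], 1: [1]}
print(grid_placement)  # {0: [0, 1, 2, 3], 1: [4, 5, 6, 7]}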
@@ -179,6 +180,17 @@ def _shard(
self,
sharding_infos: List[EmbeddingShardingInfo],
) -> List[List[ShardedEmbeddingTable]]:
"""
Shards the embedding tables.
This method takes the sharding infos and returns a list of lists of
sharded embedding tables, where each inner list represents the tables
for a specific rank.
Args:
sharding_infos (List[EmbeddingShardingInfo]): The sharding infos.
Returns:
List[List[ShardedEmbeddingTable]]: The sharded embedding tables.
"""
world_size = self._world_size
tables_per_rank: List[List[ShardedEmbeddingTable]] = [
[] for i in range(world_size)
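The line above creates one empty bucket per rank, and the rest of the method appends each table's shards into the bucket of its assigned rank. A reduced sketch of that grouping pattern, with plain tuples standing in for ShardedEmbeddingTable (the helper name and tuple shape are made up for illustration):

from typing import List, Tuple

# Stand-in for ShardedEmbeddingTable: (table_name, local_rows, local_cols).
FakeShard = Tuple[str, int, int]

def shard_by_rank(
    assignments: List[Tuple[str, int, int, int]],  # (name, rank, rows, cols)
    world_size: int,
) -> List[List[FakeShard]]:
    # One (initially empty) list of tables per rank, as in _shard above.
    tables_per_rank: List[List[FakeShard]] = [[] for _ in range(world_size)]
    for name, rank, rows, cols in assignments:
        tables_per_rank[rank].append((name, rows, cols))
    return tables_per_rank

# Two row shards of table "t0" landing on ranks 0 and 1:
print(shard_by_rank([("t0", 0, 128, 64), ("t0", 1, 128, 64)], world_size=2))
# [[('t0', 128, 64)], [('t0', 128, 64)]]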
@@ -198,7 +210,7 @@
),
)

- # expectation is planner CW shards across a node, so each CW shard will have local_size num row shards
+ # Expectation is that the planner CW-shards across a node, so each CW shard will have local_size number of row shards
# pyre-fixme [6]
for i, rank in enumerate(info.param_sharding.ranks):
tables_per_rank[rank].append(
@@ -212,7 +224,6 @@
pooling=info.embedding_config.pooling,
is_weighted=info.embedding_config.is_weighted,
has_feature_processor=info.embedding_config.has_feature_processor,
- # sharding by row and col
local_rows=shards[i].shard_sizes[0],
local_cols=shards[i].shard_sizes[1],
compute_kernel=EmbeddingComputeKernel(
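Putting numbers on the comment above (all sizes are invented for illustration): with a 1024 x 512 table cut by the planner into four CW shards and a node of local_size = 8 ranks, each CW shard is split into local_size row shards, so every rank holds a 128 x 128 slice.

rows, cols = 1024, 512     # hypothetical full table
num_cw_shards = 4          # hypothetical column-wise cuts chosen by the planner
local_size = 8             # hypothetical ranks per node

local_cols = cols // num_cw_shards  # width of one CW shard -> 128
local_rows = rows // local_size     # each CW shard yields local_size row shards -> 128

# These two values are what shard_sizes[0] / shard_sizes[1] carry into
# local_rows / local_cols in the diff above.
print([local_rows, local_cols])  # [128, 128]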
@@ -420,7 +431,7 @@ class GridPooledEmbeddingSharding(
]
):
"""
- Shards embedding bags table-wise then row-wise.
+ Shards embedding bags into column-wise shards, then shards each CW shard table-wise row-wise (TWRW) within a node.
"""

def create_input_dist(
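To make the updated class docstring concrete, a small sketch of the resulting grid under an assumed contiguous rank layout (two nodes with four ranks each; not actual planner output): each column of the printed grid is one node's CW shard, and each row is one TWRW row shard within the nodes.

local_size, num_nodes = 4, 2  # assumed cluster shape

# rank = node * local_size + row_shard, assuming ranks are contiguous per node
for row_shard in range(local_size):
    print([node * local_size + row_shard for node in range(num_nodes)])
# [0, 4]
# [1, 5]
# [2, 6]
# [3, 7]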
