2024-09-26 nightly release (68180e6)
pytorchbot committed Sep 26, 2024
1 parent 0a9dee2 commit b934216
Showing 14 changed files with 871 additions and 208 deletions.
1 change: 1 addition & 0 deletions torchrec/distributed/embedding_types.py
@@ -53,6 +53,7 @@ class OptimType(Enum):
LION = "LION"
ADAMW = "ADAMW"
SHAMPOO_V2_MRS = "SHAMPOO_V2_MRS"
SHAMPOO_MRS = "SHAMPOO_MRS"


@unique
2 changes: 1 addition & 1 deletion torchrec/distributed/mc_modules.py
@@ -415,7 +415,7 @@ def _create_input_dists(
self._features_order.append(input_feature_names.index(f))
self._features_order = (
[]
if self._features_order == list(range(len(self._features_order)))
if self._features_order == list(range(len(input_feature_names)))
else self._features_order
)
self.register_buffer(
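The one-line fix above tightens the short-circuit that skips feature reordering: `_features_order` is now compared against the identity permutation over all input features rather than over only the indices it happens to contain. A minimal sketch of that check (illustrative helper, not torchrec code; assuming the same semantics as the diff):

from typing import List


def build_features_order(feature_names: List[str], input_feature_names: List[str]) -> List[int]:
    # Position of each consumed feature within the incoming feature list.
    features_order = [input_feature_names.index(f) for f in feature_names]
    # Skip the permutation only when it is the identity over every input feature.
    if features_order == list(range(len(input_feature_names))):
        return []
    return features_order


# With the old check, list(range(len(features_order))) == [0, 1], so the prefix
# ordering below would have been treated as the identity even though
# "feature_2" is also present in the input.
print(build_features_order(["feature_0", "feature_1"],
                           ["feature_0", "feature_1", "feature_2"]))  # -> [0, 1]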
17 changes: 11 additions & 6 deletions torchrec/distributed/tests/test_mc_embedding.py
@@ -108,6 +108,7 @@ def forward(


def _test_sharding_and_remapping( # noqa C901
output_keys: List[str],
tables: List[EmbeddingConfig],
rank: int,
world_size: int,
@@ -217,7 +218,7 @@ def _test_sharding_and_remapping( # noqa C901
), f"initial state {key} on {ctx.rank} does not match, got {tensor}, expect {final_state_per_rank[rank][postfix]}"

remapped_ids = [remapped_ids1, remapped_ids2]
for key in kjt_input.keys():
for key in output_keys:
for i, kjt_out in enumerate(kjt_out_per_iter):
assert torch.equal(
remapped_ids[i][key].values(),
@@ -699,15 +700,15 @@ def test_sharding_zch_mc_ec_remap(self, backend: str) -> None:

kjt_input_per_rank = [ # noqa
KeyedJaggedTensor.from_lengths_sync(
keys=["feature_0", "feature_1"],
keys=["feature_0", "feature_1", "feature_2"],
values=torch.LongTensor(
[1000, 2000, 1001, 2000, 2001, 2002],
[1000, 2000, 1001, 2000, 2001, 2002, 1, 1, 1],
),
lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]),
lengths=torch.LongTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]),
weights=None,
),
KeyedJaggedTensor.from_lengths_sync(
keys=["feature_0", "feature_1"],
keys=["feature_0", "feature_1", "feature_2"],
values=torch.LongTensor(
[
1000,
@@ -716,9 +717,12 @@ def test_sharding_zch_mc_ec_remap(self, backend: str) -> None:
2000,
2002,
2004,
2,
2,
2,
],
),
lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]),
lengths=torch.LongTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]),
weights=None,
),
]
@@ -814,6 +818,7 @@ def test_sharding_zch_mc_ec_remap(self, backend: str) -> None:

self._run_multi_process_test(
callable=_test_sharding_and_remapping,
output_keys=["feature_0", "feature_1"],
world_size=WORLD_SIZE,
tables=embedding_config,
kjt_input_per_rank=kjt_input_per_rank,
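For context on the fixture change above: `KeyedJaggedTensor.from_lengths_sync` builds the sparse input from flat, feature-major values plus per-position lengths, so the appended `feature_2` ids simply occupy the trailing block of `values`. A small standalone sketch (ids are arbitrary):

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["feature_0", "feature_1", "feature_2"],
    values=torch.LongTensor([1000, 2000, 1001, 2000, 2001, 2002, 1, 1, 1]),
    lengths=torch.LongTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]),
    weights=None,
)
# Feature-major layout with a local batch size of 3 and one id per position:
# the first three ids belong to "feature_0", the next three to "feature_1",
# and the trailing ones to "feature_2".
print(kjt["feature_1"].values())  # tensor([2000, 2001, 2002])

The new `output_keys` argument then restricts the remapping assertions to the features the sharded module actually owns, which appears to be why the multi-process test above passes only `["feature_0", "feature_1"]`.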
17 changes: 11 additions & 6 deletions torchrec/distributed/tests/test_mc_embeddingbag.py
@@ -158,6 +158,7 @@ def _test_sharding( # noqa C901


def _test_sharding_and_remapping( # noqa C901
output_keys: List[str],
tables: List[EmbeddingBagConfig],
rank: int,
world_size: int,
@@ -245,7 +246,7 @@ def _test_sharding_and_remapping( # noqa C901
loss2, remapped_ids2 = sharded_sparse_arch(kjt_input)
loss2.backward()
remapped_ids = [remapped_ids1, remapped_ids2]
for key in kjt_input.keys():
for key in output_keys:
for i, kjt_out in enumerate(kjt_out_per_iter):
assert torch.equal(
remapped_ids[i][key].values(),
@@ -351,15 +352,15 @@ def test_sharding_zch_mc_ebc(self, backend: str) -> None:

kjt_input_per_rank = [ # noqa
KeyedJaggedTensor.from_lengths_sync(
keys=["feature_0", "feature_1"],
keys=["feature_0", "feature_1", "feature_2"],
values=torch.LongTensor(
[1000, 2000, 1001, 2000, 2001, 2002],
[1000, 2000, 1001, 2000, 2001, 2002, 1, 1, 1],
),
lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]),
lengths=torch.LongTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]),
weights=None,
),
KeyedJaggedTensor.from_lengths_sync(
keys=["feature_0", "feature_1"],
keys=["feature_0", "feature_1", "feature_2"],
values=torch.LongTensor(
[
1000,
@@ -368,9 +369,12 @@ def test_sharding_zch_mc_ebc(self, backend: str) -> None:
2000,
2002,
2004,
1,
1,
1,
],
),
lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]),
lengths=torch.LongTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]),
weights=None,
),
]
@@ -421,6 +425,7 @@ def test_sharding_zch_mc_ebc(self, backend: str) -> None:

self._run_multi_process_test(
callable=_test_sharding_and_remapping,
output_keys=["feature_0", "feature_1"],
world_size=WORLD_SIZE,
tables=embedding_bag_config,
kjt_input_per_rank=kjt_input_per_rank,
1 change: 1 addition & 0 deletions torchrec/distributed/train_pipeline/train_pipelines.py
@@ -1193,6 +1193,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
dataloader_iter=dataloader_iter,
to_device_non_blocking=True,
memcpy_stream_priority=-1,
memcpy_stream=self._memcpy_stream,
)
self._batch_loader.start()

14 changes: 9 additions & 5 deletions torchrec/distributed/train_pipeline/utils.py
@@ -1235,17 +1235,21 @@ def __init__(
dataloader_iter: Iterator[In],
to_device_non_blocking: bool,
memcpy_stream_priority: int = 0,
memcpy_stream: Optional[torch.Stream] = None,
) -> None:
super().__init__()
self._stop: bool = False
self._dataloader_iter = dataloader_iter
self._buffer_empty_event: Event = Event()
self._buffer_filled_event: Event = Event()
self._memcpy_stream: Optional[torch.Stream] = (
torch.get_device_module(device).Stream(priority=memcpy_stream_priority)
if device.type in ["cuda", "mtia"]
else None
)
if memcpy_stream is None:
self._memcpy_stream: Optional[torch.Stream] = (
torch.get_device_module(device).Stream(priority=memcpy_stream_priority)
if device.type in ["cuda", "mtia"]
else None
)
else:
self._memcpy_stream = memcpy_stream
self._device = device
self._to_device_non_blocking = to_device_non_blocking
self._buffered: Optional[In] = None
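Read together with the `train_pipelines.py` change above, the constructor now accepts an already-created copy stream instead of allocating one per batch loader, falling back to its own CUDA/MTIA side stream only when none is supplied. A condensed sketch of that branch (the free-standing helper name is hypothetical; the enclosing class is elided):

from typing import Optional

import torch


def _resolve_memcpy_stream(
    device: torch.device,
    memcpy_stream: Optional[torch.Stream] = None,
    memcpy_stream_priority: int = 0,
) -> Optional[torch.Stream]:
    # Reuse the caller's stream when provided (e.g. the pipeline's self._memcpy_stream),
    # so restarting the batch loader does not allocate a fresh stream each time.
    if memcpy_stream is not None:
        return memcpy_stream
    # Otherwise create a side stream at the requested priority on devices that support streams.
    if device.type in ["cuda", "mtia"]:
        return torch.get_device_module(device).Stream(priority=memcpy_stream_priority)
    return None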
39 changes: 39 additions & 0 deletions torchrec/modules/embedding_configs.py
@@ -21,6 +21,15 @@

@unique
class PoolingType(Enum):
"""
Pooling type for embedding table.
Args:
SUM (str): sum pooling.
MEAN (str): mean pooling.
NONE (str): no pooling.
"""

SUM = "SUM"
MEAN = "MEAN"
NONE = "NONE"
@@ -153,6 +162,23 @@ def data_type_to_dtype(data_type: DataType) -> torch.dtype:

@dataclass
class BaseEmbeddingConfig:
"""
Base class for embedding configs.
Args:
num_embeddings (int): number of embeddings.
embedding_dim (int): embedding dimension.
name (str): name of the embedding table.
data_type (DataType): data type of the embedding table.
feature_names (List[str]): list of feature names.
weight_init_max (Optional[float]): max value for weight initialization.
weight_init_min (Optional[float]): min value for weight initialization.
num_embeddings_post_pruning (Optional[int]): number of embeddings after pruning for inference.
If None, no pruning is applied.
init_fn (Optional[Callable[[torch.Tensor], Optional[torch.Tensor]]]): init function for embedding weights.
need_pos (bool): whether table is position weighted.
"""

num_embeddings: int
embedding_dim: int
name: str = ""
@@ -204,11 +230,24 @@ class EmbeddingTableConfig(BaseEmbeddingConfig):

@dataclass
class EmbeddingBagConfig(BaseEmbeddingConfig):
"""
EmbeddingBagConfig is a dataclass that represents a single embedding table,
where outputs are meant to be pooled.
Args:
pooling (PoolingType): pooling type.
"""

pooling: PoolingType = PoolingType.SUM


@dataclass
class EmbeddingConfig(BaseEmbeddingConfig):
"""
EmbeddingConfig is a dataclass that represents a single embedding table.
"""

pass


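The new docstrings above describe the config dataclasses; a brief usage sketch of the fields they document (table names and sizes are arbitrary):

from torchrec.modules.embedding_configs import (
    EmbeddingBagConfig,
    EmbeddingConfig,
    PoolingType,
)

# A pooled table: lookups within each bag are reduced with the chosen PoolingType.
pooled_table = EmbeddingBagConfig(
    name="t1",
    num_embeddings=10_000,
    embedding_dim=64,
    feature_names=["feature_0"],
    pooling=PoolingType.SUM,
)

# An unpooled (sequence) table: per-id embeddings are returned without reduction.
sequence_table = EmbeddingConfig(
    name="t2",
    num_embeddings=10_000,
    embedding_dim=64,
    feature_names=["feature_1"],
)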
50 changes: 48 additions & 2 deletions torchrec/modules/embedding_modules.py
@@ -206,9 +206,11 @@ def __init__(

def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
"""
Args:
features (KeyedJaggedTensor): KJT of form [F X B X L].
Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor`
and returns a `KeyedTensor`, which is the result of pooling the embeddings for each feature.
Args:
features (KeyedJaggedTensor): Input KJT
Returns:
KeyedTensor
"""
@@ -240,16 +242,32 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
)

def is_weighted(self) -> bool:
"""
Returns:
bool: Whether the EmbeddingBagCollection is weighted.
"""
return self._is_weighted

def embedding_bag_configs(self) -> List[EmbeddingBagConfig]:
"""
Returns:
List[EmbeddingBagConfig]: The embedding bag configs.
"""
return self._embedding_bag_configs

@property
def device(self) -> torch.device:
"""
Returns:
torch.device: The compute device.
"""
return self._device

def reset_parameters(self) -> None:
"""
Reset the parameters of the EmbeddingBagCollection. Parameter values
are initialized based on the `init_fn` of each EmbeddingBagConfig if it exists.
"""
if (isinstance(self.device, torch.device) and self.device.type == "meta") or (
isinstance(self.device, str) and self.device == "meta"
):
@@ -407,6 +425,9 @@ def forward(
features: KeyedJaggedTensor,
) -> Dict[str, JaggedTensor]:
"""
Run the EmbeddingCollection forward pass. This method takes in a `KeyedJaggedTensor`
and returns a `Dict[str, JaggedTensor]`, which contains the individual (unpooled) embeddings for each feature.
Args:
features (KeyedJaggedTensor): KJT of form [F X B X L].
@@ -433,22 +454,47 @@ def forward(
return feature_embeddings

def need_indices(self) -> bool:
"""
Returns:
bool: Whether the EmbeddingCollection needs indices.
"""
return self._need_indices

def embedding_dim(self) -> int:
"""
Returns:
int: The embedding dimension.
"""
return self._embedding_dim

def embedding_configs(self) -> List[EmbeddingConfig]:
"""
Returns:
List[EmbeddingConfig]: The embedding configs.
"""
return self._embedding_configs

def embedding_names_by_table(self) -> List[List[str]]:
"""
Returns:
List[List[str]]: The embedding names by table.
"""
return self._embedding_names_by_table

@property
def device(self) -> torch.device:
"""
Returns:
torch.device: The compute device.
"""
return self._device

def reset_parameters(self) -> None:
"""
Reset the parameters of the EmbeddingCollection. Parameter values
are initialized based on the `init_fn` of each EmbeddingConfig if it exists.
"""

if (isinstance(self.device, torch.device) and self.device.type == "meta") or (
isinstance(self.device, str) and self.device == "meta"
):
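A short, self-contained sketch of the forward pass documented above, using an unsharded `EmbeddingBagCollection` on CPU (shapes and ids are illustrative); the `EmbeddingCollection` path is analogous but returns a `Dict[str, JaggedTensor]` of unpooled embeddings:

import torch
from torchrec.modules.embedding_configs import EmbeddingBagConfig, PoolingType
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="t1", num_embeddings=10, embedding_dim=4,
            feature_names=["feature_0"], pooling=PoolingType.SUM,
        ),
        EmbeddingBagConfig(
            name="t2", num_embeddings=10, embedding_dim=4,
            feature_names=["feature_1"], pooling=PoolingType.MEAN,
        ),
    ],
)

# Two features, batch size 2; values are flat ids, lengths give ids per position.
features = KeyedJaggedTensor.from_lengths_sync(
    keys=["feature_0", "feature_1"],
    values=torch.LongTensor([1, 2, 3, 4, 5, 6]),
    lengths=torch.LongTensor([2, 1, 1, 2]),
)

pooled = ebc(features)                      # KeyedTensor
print(pooled.keys())                        # ['feature_0', 'feature_1']
print(pooled.to_dict()["feature_0"].shape)  # torch.Size([2, 4]) -> (batch, embedding_dim)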
2 changes: 1 addition & 1 deletion torchrec/modules/mc_modules.py
@@ -64,7 +64,7 @@ def _mcc_lazy_init(
for f in feature_names:
features_order.append(input_feature_names.index(f))

if features_order == list(range(len(features_order))):
if features_order == list(range(len(input_feature_names))):
features_order = torch.jit.annotate(List[int], [])
created_feature_order = True
