From afd5726e7ca64bf526f92462706c85c668169d3f Mon Sep 17 00:00:00 2001 From: Wang Zhou Date: Tue, 15 Oct 2024 22:17:26 -0700 Subject: [PATCH] Back out "Add `iter` singular value into TBE optimizer state" (#2487) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/2487 Backout torchrec changes in D63909559 to unblock MVAI Reviewed By: dragonxlwang Differential Revision: D64406709 fbshipit-source-id: 59ad4e25c567ba55d08447283d5bb7ee14f564d2 --- .../distributed/batched_embedding_kernel.py | 65 +++---------------- 1 file changed, 10 insertions(+), 55 deletions(-) diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py index 8500e21fa..0da8df7d8 100644 --- a/torchrec/distributed/batched_embedding_kernel.py +++ b/torchrec/distributed/batched_embedding_kernel.py @@ -211,42 +211,6 @@ class ShardParams: local_metadata: List[ShardMetadata] embedding_weights: List[torch.Tensor] - def get_optimizer_single_value_shard_metadata_and_global_metadata( - table_global_metadata: ShardedTensorMetadata, - optimizer_state: torch.Tensor, - ) -> Tuple[Dict[ShardMetadata, ShardMetadata], ShardedTensorMetadata]: - table_global_shards_metadata: List[ShardMetadata] = ( - table_global_metadata.shards_metadata - ) - - table_shard_metadata_to_optimizer_shard_metadata = {} - for offset, table_shard_metadata in enumerate(table_global_shards_metadata): - table_shard_metadata_to_optimizer_shard_metadata[ - table_shard_metadata - ] = ShardMetadata( - shard_sizes=[1], # single value optimizer state - shard_offsets=[offset], # offset increases by 1 for each shard - placement=table_shard_metadata.placement, - ) - - tensor_properties = TensorProperties( - dtype=optimizer_state.dtype, - layout=optimizer_state.layout, - requires_grad=False, - ) - single_value_optimizer_st_metadata = ShardedTensorMetadata( - shards_metadata=list( - table_shard_metadata_to_optimizer_shard_metadata.values() - ), - size=torch.Size([len(table_global_shards_metadata)]), - tensor_properties=tensor_properties, - ) - - return ( - table_shard_metadata_to_optimizer_shard_metadata, - single_value_optimizer_st_metadata, - ) - def get_optimizer_rowwise_shard_metadata_and_global_metadata( table_global_metadata: ShardedTensorMetadata, optimizer_state: torch.Tensor, @@ -392,10 +356,7 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata( if optimizer_states: optimizer_state_values = tuple(optimizer_states.values()) for optimizer_state_value in optimizer_state_values: - assert ( - table_config.local_rows == optimizer_state_value.size(0) - or optimizer_state_value.nelement() == 1 # single value state - ) + assert table_config.local_rows == optimizer_state_value.size(0) optimizer_states_keys_by_table[table_config.name] = list( optimizer_states.keys() ) @@ -474,35 +435,29 @@ def get_sharded_optim_state(momentum_idx: int) -> ShardedTensor: momentum_local_shards: List[Shard] = [] optimizer_sharded_tensor_metadata: ShardedTensorMetadata - optim_state = shard_params.optimizer_states[0][momentum_idx - 1] # pyre-ignore[16] - if optim_state.nelement() == 1: - # single value state: one value per table - ( - table_shard_metadata_to_optimizer_shard_metadata, - optimizer_sharded_tensor_metadata, - ) = get_optimizer_single_value_shard_metadata_and_global_metadata( - table_config.global_metadata, - optim_state, - ) - elif optim_state.dim() == 1: - # rowwise state: param.shape[0] == state.shape[0], state.shape[1] == 1 + is_rowwise_optimizer_state: bool = ( + # pyre-ignore + shard_params.optimizer_states[0][momentum_idx - 1].dim() + == 1 + ) + + if is_rowwise_optimizer_state: ( table_shard_metadata_to_optimizer_shard_metadata, optimizer_sharded_tensor_metadata, ) = get_optimizer_rowwise_shard_metadata_and_global_metadata( table_config.global_metadata, - optim_state, + shard_params.optimizer_states[0][momentum_idx - 1], sharding_dim, is_grid_sharded, ) else: - # pointwise state: param.shape == state.shape ( table_shard_metadata_to_optimizer_shard_metadata, optimizer_sharded_tensor_metadata, ) = get_optimizer_pointwise_shard_metadata_and_global_metadata( table_config.global_metadata, - optim_state, + shard_params.optimizer_states[0][momentum_idx - 1], ) for optimizer_state, table_shard_local_metadata in zip(