Commit

Back out "Add iter singular value into TBE optimizer state" (#2487)
Summary:
Pull Request resolved: #2487

Back out the torchrec changes from D63909559 to unblock MVAI.

Reviewed By: dragonxlwang

Differential Revision: D64406709

fbshipit-source-id: 59ad4e25c567ba55d08447283d5bb7ee14f564d2
Wang Zhou authored and facebook-github-bot committed Oct 16, 2024
1 parent b6e784e commit afd5726
Showing 1 changed file with 10 additions and 55 deletions.
torchrec/distributed/batched_embedding_kernel.py (65 changes: 10 additions & 55 deletions)
@@ -211,42 +211,6 @@ class ShardParams:
     local_metadata: List[ShardMetadata]
     embedding_weights: List[torch.Tensor]
 
-def get_optimizer_single_value_shard_metadata_and_global_metadata(
-    table_global_metadata: ShardedTensorMetadata,
-    optimizer_state: torch.Tensor,
-) -> Tuple[Dict[ShardMetadata, ShardMetadata], ShardedTensorMetadata]:
-    table_global_shards_metadata: List[ShardMetadata] = (
-        table_global_metadata.shards_metadata
-    )
-
-    table_shard_metadata_to_optimizer_shard_metadata = {}
-    for offset, table_shard_metadata in enumerate(table_global_shards_metadata):
-        table_shard_metadata_to_optimizer_shard_metadata[
-            table_shard_metadata
-        ] = ShardMetadata(
-            shard_sizes=[1],  # single value optimizer state
-            shard_offsets=[offset],  # offset increases by 1 for each shard
-            placement=table_shard_metadata.placement,
-        )
-
-    tensor_properties = TensorProperties(
-        dtype=optimizer_state.dtype,
-        layout=optimizer_state.layout,
-        requires_grad=False,
-    )
-    single_value_optimizer_st_metadata = ShardedTensorMetadata(
-        shards_metadata=list(
-            table_shard_metadata_to_optimizer_shard_metadata.values()
-        ),
-        size=torch.Size([len(table_global_shards_metadata)]),
-        tensor_properties=tensor_properties,
-    )
-
-    return (
-        table_shard_metadata_to_optimizer_shard_metadata,
-        single_value_optimizer_st_metadata,
-    )
-
 def get_optimizer_rowwise_shard_metadata_and_global_metadata(
     table_global_metadata: ShardedTensorMetadata,
     optimizer_state: torch.Tensor,
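
Note: the helper removed above mapped every shard of a table to a size-1 shard of a 1-D optimizer state tensor, one value per table shard. The following is a minimal sketch of that layout, using torch's sharded-tensor metadata types; the two table shards and their shapes are hypothetical.

# Sketch only: a table split row-wise across two ranks (hypothetical shapes).
from torch.distributed._shard.metadata import ShardMetadata

table_shards = [
    ShardMetadata(shard_offsets=[0, 0], shard_sizes=[50, 16], placement="rank:0/cuda:0"),
    ShardMetadata(shard_offsets=[50, 0], shard_sizes=[50, 16], placement="rank:1/cuda:1"),
]

# One single-value optimizer shard per table shard: shard_sizes=[1] and
# shard_offsets=[i], placed on the same device as the table shard it belongs to.
single_value_shards = {
    shard: ShardMetadata(
        shard_sizes=[1],
        shard_offsets=[i],
        placement=shard.placement,
    )
    for i, shard in enumerate(table_shards)
}
# The global state is then a 1-D tensor of size len(table_shards), i.e. torch.Size([2]).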
@@ -392,10 +356,7 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
             if optimizer_states:
                 optimizer_state_values = tuple(optimizer_states.values())
                 for optimizer_state_value in optimizer_state_values:
-                    assert (
-                        table_config.local_rows == optimizer_state_value.size(0)
-                        or optimizer_state_value.nelement() == 1  # single value state
-                    )
+                    assert table_config.local_rows == optimizer_state_value.size(0)
                 optimizer_states_keys_by_table[table_config.name] = list(
                     optimizer_states.keys()
                 )
@@ -474,35 +435,29 @@ def get_sharded_optim_state(momentum_idx: int) -> ShardedTensor:
             momentum_local_shards: List[Shard] = []
             optimizer_sharded_tensor_metadata: ShardedTensorMetadata
 
-            optim_state = shard_params.optimizer_states[0][momentum_idx - 1]  # pyre-ignore[16]
-            if optim_state.nelement() == 1:
-                # single value state: one value per table
-                (
-                    table_shard_metadata_to_optimizer_shard_metadata,
-                    optimizer_sharded_tensor_metadata,
-                ) = get_optimizer_single_value_shard_metadata_and_global_metadata(
-                    table_config.global_metadata,
-                    optim_state,
-                )
-            elif optim_state.dim() == 1:
-                # rowwise state: param.shape[0] == state.shape[0], state.shape[1] == 1
+            is_rowwise_optimizer_state: bool = (
+                # pyre-ignore
+                shard_params.optimizer_states[0][momentum_idx - 1].dim()
+                == 1
+            )
+
+            if is_rowwise_optimizer_state:
                 (
                     table_shard_metadata_to_optimizer_shard_metadata,
                     optimizer_sharded_tensor_metadata,
                 ) = get_optimizer_rowwise_shard_metadata_and_global_metadata(
                     table_config.global_metadata,
-                    optim_state,
+                    shard_params.optimizer_states[0][momentum_idx - 1],
                     sharding_dim,
                     is_grid_sharded,
                 )
             else:
-                # pointwise state: param.shape == state.shape
                 (
                     table_shard_metadata_to_optimizer_shard_metadata,
                     optimizer_sharded_tensor_metadata,
                 ) = get_optimizer_pointwise_shard_metadata_and_global_metadata(
                     table_config.global_metadata,
-                    optim_state,
+                    shard_params.optimizer_states[0][momentum_idx - 1],
                 )
 
             for optimizer_state, table_shard_local_metadata in zip(
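
Note: with the single-value branch gone, the restored dispatch is a two-way split on the state tensor's dimensionality. A condensed, hypothetical sketch of that control flow, with the metadata helpers reduced to labels:

import torch

def classify_optimizer_state(optim_state: torch.Tensor) -> str:
    if optim_state.dim() == 1:
        # rowwise state: state.shape[0] == param.shape[0]
        return "rowwise"
    # pointwise state: state.shape == param.shape
    return "pointwise"

print(classify_optimizer_state(torch.zeros(100)))      # rowwise
print(classify_optimizer_state(torch.zeros(100, 16)))  # pointwise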
