Add iter singular value into TBE optimizer state (#3228)
Summary:
X-link: pytorch/torchrec#2474

Pull Request resolved: #3228

X-link: facebookresearch/FBGEMM#326

When the optimizer states for sharded embedding tables are tracked in TorchRec, they are assumed to be either point-wise (same shape as the embedding table, for example, Adam's exp_avg) or row-wise (same length as the embedding hash size, for example, rowwise_adagrad's momentum/sum). However, other formats exist, such as a single value per table. Specifically, for Adam/Partial_rowwise_adam/Lamb/Partial_rowwise_lamb and GWD, the `iter` counter is a single-value tensor, which **cannot be tracked and checkpointed properly** (this also means that current Adam/Partial_rowwise_adam/Lamb/Partial_rowwise_lamb usages are buggy!)
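
For reference, a minimal sketch of the three state layouts (shapes only; the sizes are made up for illustration):

```python
import torch

hash_size, dim = 1000, 64                       # illustrative sizes
weights = torch.zeros(hash_size, dim)           # the embedding table

# Point-wise state: same shape as the table (e.g., Adam's exp_avg).
exp_avg = torch.zeros_like(weights)             # [hash_size, dim]

# Row-wise state: one value per row (e.g., rowwise_adagrad's momentum/sum).
momentum = torch.zeros(hash_size)               # [hash_size]

# Single-value state: one value per table (e.g., the `iter` counter).
iter_state = torch.zeros(1, dtype=torch.int64)  # [1]
```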

Here we support tracking and checkpointing single-value states by constructing ShardMetadata with row-wise sharding and replicating the single value for each sharded param (similar to how the row-wise state of column-wise sharded tables is concatenated along the row dimension).
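
As a rough illustration (not the actual TorchRec code; the shard count, placements, and import path are assumptions), the replicated single-value state can be described as a row-wise sharded tensor of length `num_shards`, with each shard holding one copy:

```python
from torch.distributed._shard.metadata import ShardMetadata

num_shards = 4  # assumed: the table is column-wise sharded across 4 ranks

# Each shard holds one replica of the single-value `iter` state; stacking the
# replicas along the row dimension yields a length-`num_shards` tensor whose
# i-th row lives on rank i.
iter_shard_metadata = [
    ShardMetadata(
        shard_offsets=[i],
        shard_sizes=[1],
        placement=f"rank:{i}/cuda:{i}",
    )
    for i in range(num_shards)
]
```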

By doing so, single-value `iter` can be properly checkpointed just like other states, ensuring correct reloading of states and continuous training.
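
Conceptually (a simplified sketch with assumed names; the real logic is in the `forward()` diff below), the change keeps a tensor copy of the counter for checkpointing and a plain-int copy for local computation, re-syncing the int from the tensor after a checkpoint is loaded:

```python
import torch


class IterCounterSketch(torch.nn.Module):
    """Simplified sketch of the `iter` bookkeeping; not the actual TBE class."""

    def __init__(self) -> None:
        super().__init__()
        # Tensor copy: surfaced to TorchRec (e.g., via get_optimizer_state()
        # in the diff below) so it is checkpointed and reloaded.
        self.register_buffer(
            "iter", torch.zeros(1, dtype=torch.int64), persistent=False
        )
        # Plain-int copy used for local computation (avoids .item() calls in
        # the hot path).
        self.iter_int: int = 0

    def step(self) -> int:
        if self.iter_int == 0:
            # Sync with state loaded from a checkpoint.
            self.iter_int = int(self.iter.item())
        self.iter_int += 1  # used for local computation
        self.iter.add_(1)   # used for checkpointing
        return self.iter_int
```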

This diff checkpoints `iter` for rowwise_adagrad with GWD. The next diff will checkpoint `iter` for Adam/Partial_rowwise_adam/Lamb/Partial_rowwise_lamb.

Reviewed By: iamzainhuda, spcyppt

Differential Revision: D63909559

fbshipit-source-id: e14c1dc3e8f87bfc4cc95f2321b358526719d88f
Wang Zhou authored and facebook-github-bot committed Oct 11, 2024
1 parent 7ae6a75 commit f9f0600
Showing 2 changed files with 28 additions and 32 deletions.
@@ -1147,6 +1147,7 @@ def __init__( # noqa C901
torch.zeros(1, dtype=torch.int64, device=self.current_device),
persistent=False,
)
self.iter_int: int = 0

cache_state = construct_cache_state(rows, locations, self.feature_table_map)

@@ -1791,10 +1792,12 @@ def forward( # noqa: C901
offsets=self.momentum2_offsets,
placements=self.momentum2_placements,
)
# Ensure iter is always on CPU so the increment doesn't synchronize.
if not self.iter.is_cpu:
self.iter = self.iter.cpu()
self.iter[0] += 1
# Sync with loaded state
if self.iter_int == 0:
self.iter_int = int(self.iter.item())
# Increment the iteration counter
self.iter_int += 1 # used for local computation
self.iter.add_(1) # used for checkpointing

if self.optimizer == OptimType.ADAM:
return self._report_io_size_count(
@@ -1804,9 +1807,7 @@ def forward( # noqa: C901
self.optimizer_args,
momentum1,
momentum2,
# pyre-fixme[6]: Expected `int` for 5th param but got `Union[float,
# int]`.
self.iter.item(),
self.iter_int,
),
)
if self.optimizer == OptimType.PARTIAL_ROWWISE_ADAM:
@@ -1817,9 +1818,7 @@ def forward( # noqa: C901
self.optimizer_args,
momentum1,
momentum2,
# pyre-fixme[6]: Expected `int` for 5th param but got `Union[float,
# int]`.
self.iter.item(),
self.iter_int,
),
)
if self.optimizer == OptimType.LAMB:
@@ -1830,9 +1829,7 @@ def forward( # noqa: C901
self.optimizer_args,
momentum1,
momentum2,
# pyre-fixme[6]: Expected `int` for 5th param but got `Union[float,
# int]`.
self.iter.item(),
self.iter_int,
),
)
if self.optimizer == OptimType.PARTIAL_ROWWISE_LAMB:
@@ -1843,9 +1840,7 @@ def forward( # noqa: C901
self.optimizer_args,
momentum1,
momentum2,
# pyre-fixme[6]: Expected `int` for 5th param but got `Union[float,
# int]`.
self.iter.item(),
self.iter_int,
),
)

@@ -1883,7 +1878,7 @@ def forward( # noqa: C901
if self._used_rowwise_adagrad_with_counter:
if (
self._max_counter_update_freq > 0
and self.iter.item() % self._max_counter_update_freq == 0
and self.iter_int % self._max_counter_update_freq == 0
):
row_counter_dev = self.row_counter_dev.detach()
if row_counter_dev.numel() > 0:
@@ -1901,24 +1896,21 @@ def forward( # noqa: C901
momentum1,
prev_iter,
row_counter,
int(
self.iter.item()
), # Cast to int to suppress pyre type error
self.iter_int,
self.max_counter.item(),
),
)
elif self._used_rowwise_adagrad_with_global_weight_decay:
iter_ = int(self.iter.item())
apply_global_weight_decay = (
iter_ >= self.gwd_start_iter and self.training
self.iter_int >= self.gwd_start_iter and self.training
)
return self._report_io_size_count(
"fwd_output",
invokers.lookup_rowwise_adagrad.invoke(
common_args,
self.optimizer_args,
momentum1,
iter=iter_,
iter=self.iter_int,
apply_global_weight_decay=apply_global_weight_decay,
prev_iter_dev=self.prev_iter_dev,
gwd_lower_bound=self.gwd_lower_bound,
@@ -1943,14 +1935,14 @@ def ensemble_and_swap(self, ensemble_mode: Dict[str, float]) -> None:
Returns:
Sparse embedding weights and optimizer states will be updated in-place.
"""
should_ema = self.iter.item() % int(ensemble_mode["step_ema"]) == 0
should_swap = self.iter.item() % int(ensemble_mode["step_swap"]) == 0
should_ema = self.iter_int % int(ensemble_mode["step_ema"]) == 0
should_swap = self.iter_int % int(ensemble_mode["step_swap"]) == 0
if should_ema or should_swap:
weights = self.split_embedding_weights()
states = self.split_optimizer_states()
coef_ema = (
0.0
if self.iter.item() <= int(ensemble_mode["step_start"])
if self.iter_int <= int(ensemble_mode["step_start"])
else ensemble_mode["step_ema_coef"]
)
for i in range(len(self.embedding_specs)):
@@ -2337,7 +2329,7 @@ def get_optimizer_buffer(self, state: str) -> torch.Tensor:
for name, buffer in self.named_buffers():
if name == state:
return buffer
return torch.tensor(0)
raise ValueError(f"Optimizer buffer {state} not found")

@torch.jit.export
def get_optimizer_state(self) -> List[Dict[str, torch.Tensor]]:
@@ -2355,7 +2347,11 @@ def get_optimizer_state(self) -> List[Dict[str, torch.Tensor]]:
{"sum": states[0], "prev_iter": states[1], "row_counter": states[2]}
if self._used_rowwise_adagrad_with_counter
else (
{"sum": states[0], "prev_iter": states[1]}
{
"sum": states[0],
"prev_iter": states[1],
"iter": self.iter,
}
if self._used_rowwise_adagrad_with_global_weight_decay
else {"sum": states[0]}
)
@@ -2583,9 +2579,9 @@ def set_optimizer_step(self, step: int) -> None:
Sets the optimizer step.
Args:
step (int): The setp value to set to
step (int): The step value to set to
"""
self.log(f"set_optimizer_step from {self.iter[0]} to {step}")
self.log(f"set_optimizer_step from {self.iter[0]=} to {step=}")
if self.optimizer == OptimType.NONE:
raise NotImplementedError(
f"Setting optimizer step is not supported for {self.optimizer}"
@@ -331,8 +331,8 @@ def execute_global_weight_decay( # noqa C901
gwd_lower_bound,
)
if i != 1:
tbe.step = i - 1 # step will be incremented when forward is called
tbe.iter = torch.Tensor([tbe.step])
tbe.iter_int = i - 1 # step will be incremented when forward is called
tbe.iter = torch.Tensor([tbe.iter_int])

# Run forward pass
output = tbe(
