-
Notifications
You must be signed in to change notification settings - Fork 489
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add
iter
singular value into TBE optimizer state (#3228)
Summary: X-link: pytorch/torchrec#2474 Pull Request resolved: #3228 X-link: facebookresearch/FBGEMM#326 When the optimizer states for sharded embedding tables are tracked in TorchRec, they are assumed to be either point-wise (same shape as the embedding table, for example, Adam's exp_avg), or row-wise (same length as the embedding hashsize, for example, rowwise_adagrad's momentum/sum). However, there may be other formats, a single value for each table. Specifically, for Adam/Partial_rowwise_adam/Lamb/Partial_rowwise_lamb and GWD, the `iter` number is a single value tensor, which **cannot be tracked and checkpointed properly** (this also means that there is a bug in Adam/Partial_rowwise_adam/Lamb/Partial_rowwise_lamb usages!) Here we support tracking and checkpointing single-value states, by constructing ShardMetadata with rowwise-sharding and replicating the single-value for each Sharded param (this is similar to how the rowwise state for colume-wise sharded tables are concatenated along row-dim). By doing so, single-value `iter` can be properly checkpointed just like other states, ensuring correct reloading of states and continuous training. This diff checkpoints `iter` for rowwise_adagrad with GWD. The next diff would checkpoint `iter` for Adam/Partial_rowwise_adam/Lamb/Partial_rowwise_lamb. Reviewed By: iamzainhuda, spcyppt Differential Revision: D63909559 fbshipit-source-id: e14c1dc3e8f87bfc4cc95f2321b358526719d88f
- Loading branch information
1 parent
7ae6a75
commit f9f0600
Showing
2 changed files
with
28 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters