New DataPartitionType DATA #567
# Data are fully replicated across all devices.
REPLICATED = "replicated"
# Data are partially partitioned across data axis
DATA = "data"
A high-level question: what is the purpose of this change? I see that we already have FULL partition support, which partitions on axis=0, the data axis. How is DATA different from FULL?
DATA replicates over the sequence dimension, so the spec is ("data", None), versus ("data", "model") for FULL.
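To make the difference between the two specs concrete, here is a minimal pure-Python sketch of the per-device shard shape each one produces. The mesh sizes, the batch shape, and the helper `shard_shape` are assumptions chosen for illustration; they are not taken from this PR or from any library API.

```python
from typing import Dict, Optional, Sequence, Tuple

# Hypothetical mesh: 4-way "data" axis and 2-way "model" axis (8 devices).
MESH_SHAPE: Dict[str, int] = {"data": 4, "model": 2}


def shard_shape(
    global_shape: Sequence[int],
    spec: Sequence[Optional[str]],
    mesh_shape: Dict[str, int] = MESH_SHAPE,
) -> Tuple[int, ...]:
    """Returns the per-device shard shape for an array partitioned by `spec`.

    A `None` entry means the dimension is replicated, i.e. every device
    along the other mesh axes holds that dimension in full.
    """
    if len(global_shape) != len(spec):
        raise ValueError("spec must have one entry per array dimension")
    shard = []
    for dim, axis in zip(global_shape, spec):
        num_parts = 1 if axis is None else mesh_shape[axis]
        if dim % num_parts != 0:
            raise ValueError(f"dimension {dim} is not divisible by {num_parts}")
        shard.append(dim // num_parts)
    return tuple(shard)


# Illustrative input batch of shape (batch, sequence length).
batch_shape = (32, 2048)

full_spec = ("data", "model")  # FULL: batch over "data", sequence over "model".
data_spec = ("data", None)     # DATA: batch over "data", sequence replicated.

print(shard_shape(batch_shape, full_spec))  # (8, 1024)
print(shard_shape(batch_shape, data_spec))  # (8, 2048)
```

Under these assumed sizes, DATA leaves each device with the whole sequence dimension of its batch slice, whereas FULL splits the sequence across the "model" axis as well.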
> Increases memory efficiency

Do you have measurements on how DATA improves memory efficiency? Thanks.
By replicating the sequence dimension over TP workers, we limit the collectives and dynamic-slices introduced by the SPMD partitioner. This lowers overall step time and also allows us to run sequence parallelism over TP workers.
Thanks. Do you have quantitative measurements?
No, we do not. It is more that when we inspect the HLO after the SPMD partitioning pass, we see much better sharding: fewer all-to-alls and fewer dynamic-slices on the right-hand side.
Increases memory efficiency during large-scale training: input batches and labels are sharded along the 'data' axis.
Added new input data sharding option DataPartitionType.DATA.