Question about Gemma tensor parallel sharding policy #1464
@qlzh727 for thoughts.
Thanks for reporting the issue. Can you share more references for the "typical" sharding/layout here? I would like to take a closer look at that.
@qlzh727 Thanks for the reply. The typical tensor parallel sharding policy I described comes from the Megatron-LM implementation and the HuggingFace Transformers documentation.
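For concreteness, here is a minimal sketch of the Megatron-LM-style layout described above, written against the Keras 3 `keras.distribution` API. The mesh shape, axis names, and variable-path regexes are illustrative placeholders (not the actual Gemma variable paths), and the kernels are assumed to be plain 2-D matrices.

```python
import keras

# Hypothetical mesh: one "batch" (data parallel) dim, one "model" (tensor parallel) dim.
devices = keras.distribution.list_devices()
mesh = keras.distribution.DeviceMesh(
    shape=(1, len(devices)), axis_names=("batch", "model"), devices=devices
)

layout_map = keras.distribution.LayoutMap(mesh)
# Token embedding [vocab_size, hidden_dim]: shard axis 0 (the vocab axis).
layout_map["token_embedding/embeddings"] = ("model", None)
# Q/K/V kernels [hidden_dim, hidden_dim]: column-parallel, shard axis 1.
layout_map["decoder.*attention.*(query|key|value).*kernel"] = (None, "model")
# Attention output kernel [hidden_dim, hidden_dim]: row-parallel, shard axis 0.
layout_map["decoder.*attention_output.*kernel"] = ("model", None)
# MLP up-projection [hidden_dim, intermediate_dim]: column-parallel, shard axis 1.
layout_map["decoder.*ffw_gating.*kernel"] = (None, "model")
# MLP down-projection [intermediate_dim, hidden_dim]: row-parallel, shard axis 0.
layout_map["decoder.*ffw_linear.*kernel"] = ("model", None)
```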
Thanks for the information. I think in general these are just different ways to shard the tensors/weights, suited to different conditions. In your approach, the QKV matmul is done without an all-gather, because the Q/K/V weights are sharded along the output dimension, and the collective happens afterwards (at the QK dot product and softmax). The current Keras implementation instead does the collective at the QKV matmul (since the contracting dimension is sharded) and avoids the collective afterwards. The better choice also depends on the cost of collectives (network interconnect) versus local computation speed, and on whether the model is only used for prediction or also needs fine-tuning and weight updates. I ran some benchmarks for this on a TPU v3-8, and your setting does have an advantage for the fine-tuning use case. Feel free to provide more benchmark results with GPU testing as well.
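To make the collective-placement point concrete, here is a toy JAX sketch (not Gemma code) of the two layouts. It assumes several local devices (e.g. a TPU v3-8); the exact collectives are ultimately chosen by the XLA GSPMD partitioner.

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), axis_names=("model",))
d = 512
x = jnp.ones((8, d))   # activations
w = jnp.ones((d, d))   # a projection kernel

# (a) Shard the output dimension of the kernel ("column parallel").
#     x is replicated, so each device computes its slice of y = x @ w locally;
#     any collective is deferred to whatever consumes y.
w_col = jax.device_put(w, NamedSharding(mesh, P(None, "model")))
y_col = jax.jit(jnp.matmul)(x, w_col)

# (b) Shard the contracting dimension ("row parallel"). With x sharded to match,
#     each device produces a partial sum of y, and the usual GSPMD lowering
#     inserts an all-reduce at this matmul.
x_row = jax.device_put(x, NamedSharding(mesh, P(None, "model")))
w_row = jax.device_put(w, NamedSharding(mesh, P("model", None)))
y_row = jax.jit(jnp.matmul)(x_row, w_row)

print("col-parallel output sharding:", y_col.sharding)
print("row-parallel output sharding:", y_row.sharding)
```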
Should be addressed by #1491
Thanks for the Gemma model implementation. I found that the `layout_map` in `GemmaBackbone.get_layout_map()` seems to show a sharding policy that is completely opposite to the typical TP sharding policy used in the Transformer architecture.

Typical:

- token embedding (`[vocab_size, hidden_dim]`, sharded along axis=0);
- Q/K/V projections (`[hidden_dim, hidden_dim]`, sharded along axis=1);
- attention output projection (`[hidden_dim, hidden_dim]`, sharded along axis=0);
- MLP up-projection (`[hidden_dim, intermediate_dim]`, sharded along axis=1);
- MLP down-projection (`[intermediate_dim, hidden_dim]`, sharded along axis=0).

Gemma:

- `layout_map["token_embedding/embeddings"] = (None, model_dim)`, which seems to be sharded along the hidden_dim axis;
- `layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (None, model_dim, None)`, which seems to be sharded along the row, excluding the num_heads axis;
- `layout_map["decoder_block.*attention_output.*kernel"] = (None, None, model_dim)`, which seems to be sharded along the column, excluding the num_heads axis;
- `layout_map["decoder_block.*ffw_gating.*kernel"] = (model_dim, None)`, which seems to be sharded along the row;
- `layout_map["decoder_block.*ffw_linear.*kernel"] = (None, model_dim)`, which seems to be sharded along the column.

Is my understanding correct? If they are opposite, can you please explain the reason?
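For context, here is a rough sketch of how this layout map gets consumed via the Keras 3 distribution API. The mesh shape, the `get_layout_map` arguments, and the `ModelParallel` constructor signature are assumptions that may vary across Keras / KerasNLP versions.

```python
import keras
import keras_nlp

# Assumed 1x8 mesh for a TPU v3-8: "batch" for data parallel, "model" for tensor parallel.
devices = keras.distribution.list_devices()
device_mesh = keras.distribution.DeviceMesh(
    shape=(1, 8), axis_names=("batch", "model"), devices=devices
)

# The layout map under discussion, as produced by the backbone itself.
layout_map = keras_nlp.models.GemmaBackbone.get_layout_map(device_mesh)

# Weights created after this point are laid out according to layout_map.
keras.distribution.set_distribution(
    keras.distribution.ModelParallel(device_mesh, layout_map, batch_dim_name="batch")
)
gemma = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
```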