Question about Gemma tensor parallel sharding policy #1464

Closed · AIGideon opened this issue Feb 23, 2024 · 5 comments
Labels: Gemma (Gemma model specific issues)

@AIGideon

Thanks for the Gemma model implementation. I noticed that the layout_map returned by GemmaBackbone.get_layout_map() seems to use a sharding policy that is exactly opposite to the typical TP sharding policy used for the Transformer architecture.

Typical (see the NumPy sketch after this list):

  • embedding: embedding matrix is sharded along vocab_size axis (for matrix of shape [vocab_size, hidden_dim], along axis=0);
  • attention:
    [figure: parallelism-tp-parallel_self_attention]
    • query|key|value dense kernels are sharded along the column (for kernel of shape [hidden_dim, hidden_dim], along axis=1);
    • output dense kernel is sharded along the row (for kernel of shape [hidden_dim, hidden_dim], along axis=0);
  • feedforward:
    [figure: parallelism-tp-parallel_shard_processing]
    • gating dense (the first dense) kernel is sharded along the column (for kernel of shape [hidden_dim, intermediate_dim], along axis=1);
    • output dense (the second dense) kernel is sharded along the row (for kernel of shape [intermediate_dim, hidden_dim], along axis=0);
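
For concreteness, a minimal NumPy sketch of this column-then-row pattern (made-up shapes, two simulated devices), using the feedforward matmuls; the attention QKV/output projections follow the same pattern:

```python
# Minimal NumPy sketch of Megatron-style tensor parallelism for an MLP block,
# simulated on two "devices": first kernel split by columns, second by rows,
# so a single all-reduce per block recovers the full result.
import numpy as np

rng = np.random.default_rng(0)
hidden_dim, intermediate_dim, n_dev = 8, 16, 2

x = rng.normal(size=(4, hidden_dim))                    # activations, replicated
w1 = rng.normal(size=(hidden_dim, intermediate_dim))    # gating dense (first dense)
w2 = rng.normal(size=(intermediate_dim, hidden_dim))    # output dense (second dense)

w1_shards = np.split(w1, n_dev, axis=1)  # column-parallel: shard along axis=1
w2_shards = np.split(w2, n_dev, axis=0)  # row-parallel: shard along axis=0

# Each device computes its partial result locally; no communication needed yet.
partials = [np.maximum(x @ a, 0.0) @ b for a, b in zip(w1_shards, w2_shards)]

# One all-reduce (a sum across devices) at the end of the block.
y_parallel = sum(partials)
y_reference = np.maximum(x @ w1, 0.0) @ w2
assert np.allclose(y_parallel, y_reference)
```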

Gemma:

  • embedding: layout_map["token_embedding/embeddings"] = (None, model_dim), seems to be sharded along hidden_dim axis;
  • attention:
    • query|key|value dense: layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (None, model_dim, None), seems to be sharded along the row except the num_heads axis.
    • output dense: layout_map["decoder_block.*attention_output.*kernel"] = (None, None, model_dim), seems to be sharded along the column except the num_heads axis.
  • feedforward:
    • the first dense: layout_map["decoder_block.*ffw_gating.*kernel"] = (model_dim, None), seems to be sharded along the row.
    • the second dense: layout_map["decoder_block.*ffw_linear.*kernel"] = (None, model_dim), seems to be sharded along the column.

Is my understanding correct? If they are opposite, can you please explain the reason?
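
For reference, a sketch of how this layout map would typically be applied, roughly following the get_layout_map() docstring (the 1x8 mesh and the gemma_2b_en preset are just examples):

```python
import keras
import keras_nlp

# 1 x 8 device mesh: "batch" axis unsharded, "model" axis across 8 accelerators.
devices = keras.distribution.list_devices()
device_mesh = keras.distribution.DeviceMesh(
    shape=(1, 8), axis_names=("batch", "model"), devices=devices
)

# The layout map whose entries are quoted above,
# e.g. layout_map["token_embedding/embeddings"] = (None, "model").
layout_map = keras_nlp.models.GemmaBackbone.get_layout_map(
    device_mesh, model_parallel_dim_name="model"
)

distribution = keras.distribution.ModelParallel(
    device_mesh, layout_map, batch_dim_name="batch"
)
keras.distribution.set_distribution(distribution)

# Any Gemma preset works the same way; this one is just an example.
model = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
```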

@mattdangerw
Member

@qlzh727 for thoughts.

@qlzh727
Member

qlzh727 commented Mar 5, 2024

Thanks for reporting the issue. Can you share more references for the "typical" sharding/layout here? I would like to take a closer look at that.

@AIGideon
Author

AIGideon commented Mar 6, 2024

@qlzh727 Thanks for the reply. The typical tensor parallel sharding policy I described comes from the Megatron-LM implementation and the HuggingFace Transformers documentation:

@SuryanarayanaY added the Gemma (Gemma model specific issues) label on Mar 6, 2024
@qlzh727
Member

qlzh727 commented Mar 6, 2024

Thanks for the information.

I think these are, in general, just different ways to shard the tensors/weights, suited to different conditions.

In your approach, the QKV matmuls run without an all-gather and the collective happens afterwards (at the QK dot product and softmax), because the Q/K/V outputs are sharded. The current Keras implementation instead does the collective at the QKV matmul (since the contracting dimension is sharded) and avoids the collective afterwards. Which is better also depends on the cost of the collectives (network connection) versus local computation speed, as well as whether the model is used only for prediction or also needs fine-tuning and weight updates.
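
As a rough NumPy illustration of where the collective lands under each layout (made-up shapes, two simulated devices; this is only a sketch, not the actual Keras code path):

```python
import numpy as np

rng = np.random.default_rng(0)
hidden_dim, n_dev = 8, 2
x = rng.normal(size=(4, hidden_dim))               # activations
w_qkv = rng.normal(size=(hidden_dim, hidden_dim))  # stand-in for a q/k/v kernel

# Current Keras/Gemma layout: the contracting (input) dimension of the kernel
# is sharded, so each device only holds a partial sum and an all-reduce is
# needed right at the q/k/v matmul.
x_shards = np.split(x, n_dev, axis=1)
w_row_shards = np.split(w_qkv, n_dev, axis=0)
partials = [xs @ ws for xs, ws in zip(x_shards, w_row_shards)]
q_row_sharded = sum(partials)                      # <- the all-reduce

# "Typical" Megatron layout: the output (column) dimension is sharded, so each
# device produces complete columns of q locally and the collective is deferred
# to later in the block.
w_col_shards = np.split(w_qkv, n_dev, axis=1)
q_col_sharded = np.concatenate([x @ ws for ws in w_col_shards], axis=1)

assert np.allclose(q_row_sharded, x @ w_qkv)
assert np.allclose(q_col_sharded, x @ w_qkv)
```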

I ran some benchmarks for this and the results are shown below. I think your setting does have an advantage for the fine-tuning use case. I tested this on a TPU v3-8; feel free to provide more benchmark results with GPU testing as well.

(Smaller values are better.)

Baseline (current setting):
  generate: 1342 ms per 100 tokens
  finetune with LoRA: 125 ms/step

Your setting:
  generate: 1501 ms per 100 tokens
  finetune with LoRA: 77 ms/step
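
If anyone wants to reproduce or extend this on GPU, a rough sketch of how one could time both paths is below (the preset, LoRA rank, sequence length, and toy data are placeholders, not the exact configuration used for the numbers above):

```python
import time
import keras
import keras_nlp

# Assumes a ModelParallel distribution with the layout map has already been
# set via keras.distribution.set_distribution(...), as shown earlier.
model = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")  # placeholder preset

# --- generation ---
model.generate("What is Keras?", max_length=100)   # warm-up / compile
start = time.time()
model.generate("What is Keras?", max_length=100)
print(f"generate: {(time.time() - start) * 1000:.0f} ms for a 100-token budget")

# --- LoRA fine-tuning ---
model.backbone.enable_lora(rank=4)                 # placeholder rank
model.preprocessor.sequence_length = 128           # placeholder sequence length
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.AdamW(learning_rate=5e-5),
)
data = ["Keras is a deep learning framework."] * 32  # toy data
start = time.time()
model.fit(data, epochs=1, batch_size=1)
print(f"finetune: {(time.time() - start) * 1000 / len(data):.0f} ms/step (rough)")
```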

@qlzh727
Member

qlzh727 commented Mar 15, 2024

Should be addressed by #1491

@qlzh727 closed this as completed on Mar 15, 2024