
Support Gemma2 in torchtitan #594

Open
pansershrek opened this issue Oct 1, 2024 · 3 comments
Labels: enhancement (New feature or request)

Comments

@pansershrek

Are there any plans to support Gemma2 in torchtitan? I tried to use torchtitan to finetune a Gemma2 model, but got stuck on the following problem: how do you parallelize the tied layers in the Gemma2 model? Maybe somebody knows the solution to this problem 😄

@awgu
Contributor

awgu commented Oct 1, 2024

If you apply `fully_shard` to each transformer block and then to the root module, this should work for the tied embedding and final linear. The root module will manage both.
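
A minimal sketch of that wrapping order, assuming `model.layers` is an `nn.ModuleList` of transformer blocks and the embedding weight is tied to the final linear; the `fully_shard` import path and the `dp_mesh` argument are illustrative and depend on the PyTorch version and job setup:

```python
import torch.nn as nn
# Recent PyTorch exposes this as torch.distributed.fsdp.fully_shard;
# slightly older releases use torch.distributed._composable.fsdp.
from torch.distributed._composable.fsdp import fully_shard


def apply_fsdp(model: nn.Module, dp_mesh) -> nn.Module:
    # Wrap each transformer block as its own FSDP unit.
    for block in model.layers:
        fully_shard(block, mesh=dp_mesh)
    # Wrap the root module last: parameters not claimed by any child unit
    # (including the weight shared by the embedding and the final linear)
    # fall into the root's FSDP unit, so the tied weight is sharded and
    # all-gathered in one place instead of being wrapped twice.
    fully_shard(model, mesh=dp_mesh)
    return model
```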

@pansershrek
Author

I want to shard the output embedding layer. I use the same strategy as in Llama, but training gets stuck after the first batch:

```python
ColwiseParallel(
    input_layouts=Shard(1),
    output_layouts=Shard(-1) if loss_parallel else Replicate(),
    use_local_output=not loss_parallel,
)
```
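
For context, a sketch of how such a plan is typically handed to `parallelize_module` in a Llama-style setup; `tp_mesh`, `loss_parallel`, and the `"output"` module name are assumptions carried over from the Llama example rather than anything Gemma2-specific:

```python
# Newer PyTorch exposes Shard/Replicate under torch.distributed.tensor;
# older releases use torch.distributed._tensor.
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module


def parallelize_output(model, tp_mesh, loss_parallel: bool):
    # Column-shard the final projection. With loss parallelism the logits
    # stay sharded on the vocab dim (Shard(-1)) and feed a parallel
    # cross-entropy; otherwise they are replicated back to full tensors.
    parallelize_module(
        model,
        tp_mesh,
        {
            "output": ColwiseParallel(
                input_layouts=Shard(1),
                output_layouts=Shard(-1) if loss_parallel else Replicate(),
                use_local_output=not loss_parallel,
            ),
        },
    )
    return model
```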

@awgu
Contributor

awgu commented Oct 1, 2024

Do you want to train with 2D parallelism (FSDP + TP)? With TP only?

@yf225 added the enhancement (New feature or request) label Oct 4, 2024