Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PhiMoE #33363

Merged
merged 31 commits into from
Oct 4, 2024
Merged

PhiMoE #33363

merged 31 commits into from
Oct 4, 2024

Conversation

garg-amit
Copy link
Contributor

@garg-amit garg-amit commented Sep 6, 2024

What does this PR do?

Integrates PhiMoE into transformers. https://huggingface.co/microsoft/Phi-3.5-MoE-instruct

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ArthurZucker @gante

@garg-amit
Copy link
Contributor Author

@ArthurZucker @gante can I please get a review?

@merryHunter
Copy link

Hi, it seems to be a very important and awaited PR!:) Other frameworks are willing to integrate MoE too, like in litgpt Lightning-AI/litgpt#1686.

@ArthurZucker
Copy link
Collaborator

We are very much willing to integrate it as well 🤗 just came back from the torch conf, was a bit OO because of it 😢

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!
Let's go with camel cased classes, If we want to be compile compatible we need to have a script conversion and use the formulation from gpt fast moe with a version implemented here: https://github.com/huggingface/transformers/pull/30793/files#diff-733ab0a772c69f78b1d8ed361e6ae1fda7243652887aed0bab5d3ecf07794c01R789

Lot's of stuff seems similar to phi3 so we can probably copy from it!

docs/source/en/perf_infer_gpu_one.md Outdated Show resolved Hide resolved
src/transformers/__init__.py Outdated Show resolved Hide resolved
src/transformers/models/phimoe/configuration_phimoe.py Outdated Show resolved Hide resolved
src/transformers/models/phimoe/configuration_phimoe.py Outdated Show resolved Hide resolved
src/transformers/models/phimoe/modeling_phimoe.py Outdated Show resolved Hide resolved
src/transformers/models/phimoe/modeling_phimoe.py Outdated Show resolved Hide resolved
src/transformers/models/phimoe/modeling_phimoe.py Outdated Show resolved Hide resolved
@ArthurZucker
Copy link
Collaborator

TLDR, overall the mixer needs to be properly documented and written to be more understandable!

@garg-amit
Copy link
Contributor Author

@ArthurZucker Thanks for reviewing the PR. I’ve refactored the code according to your suggestions, and it’s ready for another look. Also, the failing test case appears to be unrelated to this PR. Please let me know if it needs to be addressed.

@ArthurZucker
Copy link
Collaborator

Reviewing! 🤗

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, the only thing needed to merge:

  1. The Copied from need a capital letter
  2. The core part needs a tad bit more doc as I said, why do we need a specific gradient computation (had to go through the paper to see that indeed you need a special gradient approx)
  3. That part of the code is IMO less readable than the rest, but fine for now!
    THanks and sorry for the late revies!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you no longer need this complicated structred! See the __init__ for Albert for example!
You need to define a __all__ in the modeling and config and that's it

return torch.cat((-x2, x1), dim=-1)


# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb

Comment on lines 319 to 322
self.rotary_emb = PhimoeRotaryEmbedding(
config=self.config,
)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO you can already put this outside the Attention layer, and remove the copied from mixtral to pass in the position embedding!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, moved it to the PhimoeModel class

return attn_output, attn_weights, past_key_value


# copied from transformers.models.mixtral.modeling_mixtral.MixtralFlashAttention2 with Mixtral->Phimoe
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# copied from transformers.models.mixtral.modeling_mixtral.MixtralFlashAttention2 with Mixtral->Phimoe
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralFlashAttention2 with Mixtral->Phimoe

}


# copied from transformers.models.mixtral.modeling_mixtral.MixtralBlockSparseTop2MLP with Mixtral->Phimoe
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# copied from transformers.models.mixtral.modeling_mixtral.MixtralBlockSparseTop2MLP with Mixtral->Phimoe
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralBlockSparseTop2MLP with Mixtral->Phimoe

Returns:
Tuple[torch.Tensor, torch.Tensor]: Multiplier and selected experts tensors.
"""
assert top_k == 2
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also let's raise an error rather than an assert!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed


routing_weights, selected_experts = sparsemixer(
router_logits,
top_k=2,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if it's hardcoded we can also just not put it!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ArthurZucker I’ve removed top_k from here and instead created it as a keyword argument.

config.hidden_size, eps=config.rms_norm_eps, elementwise_affine=True
)

# copied from transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer.forward
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# copied from transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer.forward
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer.forward

@garg-amit
Copy link
Contributor Author

@ArthurZucker Thanks for reviewing! I've addressed the comments and moved PhimoeRotaryEmbedding out of the PhimoeAttention class.

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work! Thanks for integrating this new model 🔥

kv_seq_len = hidden_states.shape[-2]
if past_key_values is not None:
kv_seq_len += past_key_values.get_usable_length(kv_seq_len)
position_embeddings = self.rotary_emb(hidden_states, seq_len=kv_seq_len)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pretty sure you should be using cache positions here! cache_position[0]!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's the last nit!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion! I've updated it to cache_position[-1]+1 as cache_position[0] would return 0 when the kv cache is empty.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indeed! 🤗

@ArthurZucker ArthurZucker merged commit e377553 into huggingface:main Oct 4, 2024
24 checks passed
@ArthurZucker
Copy link
Collaborator

Thanks everyone and @garg-amit for bearing with me! Congrats on the model release again 🤗

dataKim1201 pushed a commit to dataKim1201/transformers that referenced this pull request Oct 7, 2024
* onboard phimoe model

* removed debug code

* added unit tests

* updated docs

* formatted

* fixed unit tests

* fixed test case

* fixed format

* refactored code

* fixed expected outputs in the integration tests

* Added a warning msg

* Addressed comments

* Addressed comments

* fixed test cases

* added paper link

* Addressed comments

* Refactored PhimoeForCausalLM forward fn

* Refactored PhimoeRotaryEmbedding class

* fixed test cases

* fixed testcase

* fixed test case

* Addressed comments

* fixed test cases

* fixed testcases

* Used cache position instead to get the seq len
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants