[llama] Remove unnecessary model attribute assignment on 'freqs_cis' #1766
Conversation
CI failure seems to be a flake, not your fault; I'd retry triggering the build.
The upstream model code has a similar issue: https://github.com/facebookresearch/llama/blob/main/llama/model.py#L226. Our convention is that before we merge this PR, the PR author must open a PR fixing the upstream model code as well, and reference that upstream PR link in the code of the torchbench model.
```diff
@@ -224,8 +224,8 @@ def forward(self, tokens: torch.Tensor, start_pos: int):
         h = self.tok_embeddings(tokens)
-        self.freqs_cis = self.freqs_cis.to(h.device)
-        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
+        freqs_cis = self.freqs_cis.to(h.device)
```
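For illustration, a minimal sketch of the pattern this PR fixes, using a hypothetical module (`RotaryModule` and its shapes are invented here, not the actual llama code): moving a precomputed tensor to the input's device should go through a local variable rather than reassigning the model attribute inside `forward()`, so the module's state is not mutated on every call.

```python
import torch
import torch.nn as nn

class RotaryModule(nn.Module):  # hypothetical module for illustration
    def __init__(self, max_seq_len: int = 16, dim: int = 4):
        super().__init__()
        # Precomputed table, analogous to freqs_cis in the llama model.
        self.freqs_cis = torch.randn(max_seq_len, dim)

    def forward(self, tokens: torch.Tensor, start_pos: int) -> torch.Tensor:
        seqlen = tokens.shape[1]
        # Local variable: no assignment to self.freqs_cis in forward().
        freqs_cis = self.freqs_cis.to(tokens.device)
        return freqs_cis[start_pos : start_pos + seqlen]

m = RotaryModule()
tokens = torch.zeros(1, 3, dtype=torch.long)
out = m(tokens, start_pos=2)  # slice of shape (seqlen, dim) == (3, 4)
```

Keeping `forward()` free of attribute reassignment avoids surprising side effects across calls and plays better with tracing and compilation tooling.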
Can we add a reference to meta-llama/llama#349 here?
Done
LGTM!
@xuzhao9 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Fixes #1767, more details there.
Upstream PR link: meta-llama/llama#349