
initial grok #169

Merged — 52 commits merged into main from grokstar on Sep 26, 2024

Conversation

@dan-garvey (Member) commented Sep 5, 2024

Initial grok work, also does some refactoring

@archana-ramalingam (Collaborator)
Posted a comment here, doesn't show up on the main page.

@dan-garvey (Member, Author)

> Posted a comment here, doesn't show up on the main page.

lets chat at meeting

@dan-garvey (Member, Author):

I did say I didn't mind if we did this, but looking at it now, all the args are just values from the config anyway. I think separating them just makes our code harder to follow.

Not a big deal either way, but I'd personally prefer we just drop this commit. @KyleHerndon, what do you think?

@KyleHerndon (Contributor):

I think this is the right way to set things up. The bad argument is that llama.cpp has a similar setup, and we're referencing them for accuracy baselines. The better argument goes something like:

The KV cache is not ontologically related to the config, it just (currently) exclusively uses parameters from it. In the near future, we will want things like sharding, at which point the KV Cache might need additional args (number of devices, for example, which is not a ModelConfig parameter, but something like an ExecutionConfig parameter).

This change doesn't feel related to grok so much as a code refactor, which I think is fine to include in a bigger patch for efficiency, but it might make things harder to follow when looking at commit history. Otherwise, I find the code just as easy to follow, if not slightly easier, because I think it is better organized.
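Kyle's distinction can be sketched with a minimal, hypothetical example (the class and field names below are invented for illustration, not taken from sharktank): when the cache takes explicit parameters rather than a whole config object, it can combine values drawn from a model config with values that belong to an execution config.

```python
from dataclasses import dataclass


@dataclass
class ModelConfig:
    # Properties of the model itself (hypothetical fields).
    block_count: int = 32
    attn_head_count: int = 8
    attn_head_dim: int = 128


@dataclass
class ExecutionConfig:
    # Properties of how we run it, e.g. for sharding (hypothetical).
    device_count: int = 1


class KVCache:
    # Takes explicit parameters instead of a ModelConfig, so it can also
    # accept values (like device_count) that are not model properties.
    def __init__(self, *, block_count, head_count, head_dim, device_count=1):
        self.block_count = block_count
        self.head_count = head_count
        self.head_dim = head_dim
        self.device_count = device_count


model_cfg = ModelConfig()
exec_cfg = ExecutionConfig(device_count=4)
cache = KVCache(
    block_count=model_cfg.block_count,
    head_count=model_cfg.attn_head_count,
    head_dim=model_cfg.attn_head_dim,
    device_count=exec_cfg.device_count,
)
```

With a config-coupled constructor, adding `device_count` later would mean either growing `ModelConfig` with a non-model field or threading a second config through the cache.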

@archana-ramalingam (Collaborator):

I don't mind either way, but do agree with Kyle that this keeps it better organized. We can consult Rob or Stella to see if there is a better way to do this.

@dan-garvey (Member, Author):

Not necessary; if you two agree, that's more than enough for me.

@dan-garvey (Member, Author)
Looks good to me. @KyleHerndon @archana-ramalingam, any final changes you think are needed?

```diff
@@ -233,7 +233,6 @@ def main():

     device = torch.device(args.device) if args.device else None
     activation_dtype = getattr(torch, args.activation_dtype)
     attention_dtype = getattr(torch, args.attention_dtype)
```
@dan-garvey (Member, Author):
nice catch

@dan-garvey dan-garvey enabled auto-merge (squash) September 25, 2024 19:51
@archana-ramalingam (Collaborator) left a review comment:

I tested the Llama and Grok models again; LGTM.

@sogartar (Contributor)

@dan-garvey, the error seems to be related to calling `__mul__` on the `DefaultPrimitiveTensor`.

This fails:

```python
output = weight * to(output, weight.dtype)
```

but this is fine:

```python
output = elementwise(torch.mul, weight, to(output, weight.dtype))
```

`__mul__`'s implementation is just:

```python
    def __mul__(self, rhs):
        from ..ops import elementwise

        return elementwise(torch.mul, self, rhs)
```
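For background on why the two spellings can behave differently under tracing: `weight * x` is compiled to a dedicated multiply bytecode that the interpreter resolves through `__mul__`/`__rmul__`, while the explicit `elementwise(...)` call is an ordinary function call. A torch-free sketch of that dispatch (the `Wrapped` class here is invented for illustration):

```python
class Wrapped:
    # Minimal stand-in for a tensor-like wrapper with operator overloading.
    def __init__(self, value):
        self.value = value

    def __mul__(self, rhs):
        # `a * b` routes here when `a` is a Wrapped instance.
        return Wrapped(self.value * rhs)

    def __rmul__(self, lhs):
        # `b * a` falls back here when `b`'s type doesn't know about Wrapped.
        return Wrapped(lhs * self.value)


a = Wrapped(3)
print((a * 4).value)   # 12, via __mul__
print((4 * a).value)   # 12, via __rmul__
```

A tracer that special-cases the multiply bytecode sees the operator form and the plain-call form through different code paths, which is consistent with only the `*` spelling failing here.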

@sogartar (Contributor)

When calling `torch.export.export` with `strict=False`, the error is not present, so it is related to the export checks.

@sogartar (Contributor) commented Sep 26, 2024

Instead of calling the `*` operator, I changed the code to call a custom member function on `weight`:

```python
output = weight.my_mul(to(output, weight.dtype))
```

```python
class InferenceTensor(ABC):

    ...

    def my_mul(self, rhs):
        from ..ops import elementwise

        return elementwise(torch.mul, self, rhs)
```

This is essentially the same code, but PyTorch does something special with binary operators. `+` also suffers from the same problem.

Also, when running with the env vars

```shell
TORCH_LOGS="+dynamo"
TORCHDYNAMO_VERBOSE=1
```

the tracer reports:

```
V0926 06:19:31.456000 140060607524864 torch/_dynamo/symbolic_convert.py:798] [0/0] [__trace_bytecode] TRACE LOAD_FAST results []
V0926 06:19:31.456000 140060607524864 torch/_dynamo/symbolic_convert.py:798] [0/0] [__trace_bytecode] TRACE LOAD_CONST 0 [TupleVariable()]
V0926 06:19:31.456000 140060607524864 torch/_dynamo/symbolic_convert.py:798] [0/0] [__trace_bytecode] TRACE BINARY_SUBSCR None [TupleVariable(), ConstantVariable()]
V0926 06:19:31.456000 140060607524864 torch/_dynamo/symbolic_convert.py:798] [0/0] [__trace_bytecode] TRACE RETURN_VALUE None [TensorVariable()]
V0926 06:19:31.456000 140060607524864 torch/_dynamo/symbolic_convert.py:2807] [0/0] DONE INLINING <code object __call__ at 0x7f618b39f470, file "/home/bpetkant/ws/sharktank/repo/sharktank/sharktank/ops/_registry.py", line 196>
V0926 06:19:31.456000 140060607524864 torch/_dynamo/symbolic_convert.py:798] [0/0] [__trace_bytecode] TRACE BINARY_MULTIPLY None [UserDefinedObjectVariable(), TensorVariable()]
V0926 06:19:31.456000 140060607524864 torch/_dynamo/symbolic_convert.py:814] [0/0] empty checkpoint
V0926 06:19:31.456000 140060607524864 torch/_dynamo/symbolic_convert.py:2796] [0/0] FAILED INLINING <code object rms_norm_default at 0x7f618b2126b0, file "/home/bpetkant/ws/sharktank/repo/sharktank/sharktank/ops/default_impls.py", line 308>
```

It is doing something special with `BINARY_MULTIPLY`.
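The opcode in question can be seen directly by disassembling a bare multiplication with the stdlib `dis` module (note that on Python 3.11+ the specialized arithmetic opcodes were folded into a generic `BINARY_OP`):

```python
import dis

# List the opcodes emitted for a bare `a * b` expression.
ops = [ins.opname for ins in dis.get_instructions(lambda a, b: a * b)]
print(ops)

# Python <= 3.10 emits BINARY_MULTIPLY; 3.11+ emits BINARY_OP instead.
assert "BINARY_MULTIPLY" in ops or "BINARY_OP" in ops
```

This is the instruction Dynamo is intercepting in the trace above, which is why replacing `*` with a named method call (an ordinary `CALL` sequence) avoids the failure.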

@sogartar (Contributor)

I opened an issue with PyTorch; I think it may be a bug on their side.

@dan-garvey dan-garvey merged commit 9f3f70f into main Sep 26, 2024
7 of 8 checks passed
@dan-garvey dan-garvey deleted the grokstar branch September 26, 2024 17:14