Add Softmax kernel in Triton. Use softmax kernel and argmax in Llama generation.py. + Small changes #11

Open · wants to merge 12 commits into base: main

Conversation

catherinelee274 commented Sep 24, 2024

  • Update .gitignore
  • Add README.md
  • Add README.md for pytests
  • Add fused softmax kernel and use it in generation.py (a minimal sketch of such a kernel follows this list)
  • Add pytest for softmax
  • Use triton argmax in generation.py
  • Call dist.destroy_process_group() to remove a warning during benchmarking
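For context, a fused row-wise softmax in Triton typically looks like the following. This is a minimal sketch (closely following the standard Triton fused-softmax tutorial pattern); the names and signatures are illustrative assumptions, not necessarily what kernels/fused_softmax.py in this PR contains:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def softmax_kernel(out_ptr, in_ptr, in_row_stride, out_row_stride,
                   n_cols, BLOCK_SIZE: tl.constexpr):
    # Each program instance normalizes one row.
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    row = tl.load(in_ptr + row_idx * in_row_stride + col_offsets,
                  mask=mask, other=-float("inf"))
    # Subtract the row max for numerical stability, then exponentiate and normalize.
    row = row - tl.max(row, axis=0)
    num = tl.exp(row)
    out = num / tl.sum(num, axis=0)
    tl.store(out_ptr + row_idx * out_row_stride + col_offsets, out, mask=mask)


def triton_softmax(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Only the 2-D, last-dim case is handled, matching the limitation
    # discussed in the review below.
    assert x.dim() == 2 and dim in (-1, 1)
    n_rows, n_cols = x.shape
    out = torch.empty_like(x)
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    softmax_kernel[(n_rows,)](out, x, x.stride(0), out.stride(0),
                              n_cols, BLOCK_SIZE=BLOCK_SIZE)
    return out
```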

Results from calling `python3 main.py llama_chat_completion --benchmark --ckpt_dir <model_checkpoint_path> --tokenizer_path <model_tokenizer_path>`:

With no changes:

| kernel | kernel_path | triton | non_triton | triton-non_triton |
|---|---|---|---|---|
| chat_completion | chat_completion | 23.363035631999992 | 23.20719621399985 | 0.15583941800014145 |
| chat_completion | chat_completion.chat_completion | 15.086765727000056 | 15.037501877000068 | 0.04926384999998845 |
| generate | chat_completion.chat_completion.generate | 15.085606602000098 | 15.036371463999785 | 0.04923513800031287 |
| softmax | chat_completion.chat_completion.generate.softmax | 3.286413995026158e-05 | 3.2326578084264846e-05 | 5.375618659967339e-07 |
| transformer_forward | chat_completion.chat_completion.generate.transformer_forward | 0.02763666868558919 | 0.0275613940689651 | 7.527461662408877e-05 |
| RMSNorm | chat_completion.chat_completion.generate.transformer_forward.RMSNorm | 5.827654970598593e-05 | 5.889426978790332e-05 | -6.177200819173836e-07 |
| transform_block_forward | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward | 0.0008483634264688104 | 0.0008461694928391175 | 2.1939336296929223e-06 |
| RMSNorm | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.RMSNorm | 5.88654680211981e-05 | 5.911337566830836e-05 | -2.4790764711026176e-07 |
| attention_forward | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward | 0.000491074871831931 | 0.0004896370724516646 | 1.437799380266374e-06 |
| apply_rotary_emb | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.apply_rotary_emb | 8.806513197629571e-05 | 8.784457036059364e-05 | 2.2056161570207199e-07 |
| attention | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention | 0.00017238609349734332 | 0.00017186252022342258 | 5.235732739207336e-07 |
| matmul | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention.matmul | 4.327956193072774e-05 | 4.3119774561757945e-05 | 1.597873689697973e-07 |
| softmax | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention.softmax | 2.88178547136412e-05 | 2.8733377409596218e-05 | 8.4477304044983e-08 |
| feed_forward_forward | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.feed_forward_forward | 0.00012714412176862338 | 0.00012642639452434758 | 7.177272442757999e-07 |
| precompute_freqs_cis | chat_completion.precompute_freqs_cis | 0.0003592040002331487 | 0.00034214400011478574 | 1.706000011836295e-05 |

With just softmax:

| kernel | kernel_path | triton | non_triton | triton-non_triton |
|---|---|---|---|---|
| chat_completion | chat_completion | 23.529098738999892 | | |
| chat_completion | chat_completion.chat_completion | 15.229127281999808 | 23.885352947999763 | -8.656225665999955 |
| generate | chat_completion.chat_completion.generate | 15.228001847000087 | 15.719125695999992 | -0.491123848999905 |
| softmax | chat_completion.chat_completion.generate.softmax | 3.3104419885907815e-05 | 15.717982729999676 | -15.71794962557979 |
| transformer_forward | chat_completion.chat_completion.generate.transformer_forward | 0.028188293081122473 | 0.028127436847877902 | 6.0856233244570984e-05 |
| RMSNorm | chat_completion.chat_completion.generate.transformer_forward.RMSNorm | 5.8095263698646935e-05 | 5.832987626748923e-05 | -2.3461256884229476e-07 |
| transform_block_forward | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward | 0.0008650784662146882 | 0.0008631522671130063 | 1.9261991016819406e-06 |
| RMSNorm | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.RMSNorm | 5.8953483520847516e-05 | 5.920526296552413e-05 | -2.517794446766161e-07 |
| attention_forward | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward | 0.000502045900099974 | 0.0005004565054525663 | 1.5893946474076093e-06 |
| apply_rotary_emb | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.apply_rotary_emb | 8.903067526494468e-05 | 8.906119377373757e-05 | -3.05185087928929e-08 |
| attention | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention | 0.00017652477433859322 | 0.00017632209013441347 | 2.0268420417975867e-07 |
| matmul | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention.matmul | 4.470880419631436e-05 | 4.464734606365697e-05 | 6.145813265738515e-08 |
| softmax | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention.softmax | 2.94319926492972e-05 | 2.9444132737926863e-05 | -1.2140088629663967e-08 |
| feed_forward_forward | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.feed_forward_forward | 0.00013109435198944578 | 0.00013054154848626072 | 5.528035031850662e-07 |
| precompute_freqs_cis | chat_completion.precompute_freqs_cis | 0.0003498339997349831 | 0.0003586540001379035 | -8.820000402920414e-06 |

With softmax and argmax:

| kernel | kernel_path | triton | non_triton | triton-non_triton |
|---|---|---|---|---|
| chat_completion | chat_completion | 23.316155643000002 | | |
| chat_completion | chat_completion.chat_completion | 15.026287104999938 | 23.7322098059999 | -8.705922700999963 |
| generate | chat_completion.chat_completion.generate | 15.025166173999878 | 15.588987290000205 | -0.5638211160003266 |
| softmax | chat_completion.chat_completion.generate.softmax | 3.194667139574176e-05 | 15.587871648000146 | -15.58783970132875 |
| transformer_forward | chat_completion.chat_completion.generate.transformer_forward | 0.02752115360446059 | 0.027770376318450807 | -0.000249222713990218 |
| RMSNorm | chat_completion.chat_completion.generate.transformer_forward.RMSNorm | 5.7826257591456e-05 | 5.821012373749101e-05 | -3.838661460350075e-07 |
| transform_block_forward | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward | 0.0008447297065781728 | 0.000852438977687303 | -7.709271109130212e-06 |
| RMSNorm | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.RMSNorm | 5.789487043606993e-05 | 5.838420017865719e-05 | -4.893297425872609e-07 |
| attention_forward | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward | 0.0004889343329106125 | 0.0004935158465387692 | -4.581513628156746e-06 |
| apply_rotary_emb | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.apply_rotary_emb | 8.793432188105586e-05 | 8.822047401022832e-05 | -2.861521291724629e-07 |
| attention | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention | 0.00017178861625323266 | 0.00017370242444190314 | -1.9138081886704806e-06 |
| matmul | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention.matmul | 4.3013189750603796e-05 | 4.373048187156608e-05 | -7.172921209622871e-07 |
| softmax | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention.softmax | 2.8741690918662413e-05 | 2.8887915059976048e-05 | -1.462241413136349e-07 |
| feed_forward_forward | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.feed_forward_forward | 0.00012659588773971398 | 0.00012815826540009848 | -1.5623776603845016e-06 |
| precompute_freqs_cis | chat_completion.precompute_freqs_cis | 0.0003504739997879369 | 0.0003422939998927177 | 8.179999895219225e-06 |

catherinelee274 changed the title from "Add Softmax kernel in triton. Use softmax and argmax in llama generation." to "Add Softmax kernel in Triton. Use softmax kernel and argmax in Llama generation.py. + Small changes" on Sep 24, 2024
catherinelee274 marked this pull request as ready for review on September 24, 2024, 08:08
```python
import pandas as pd


def compare_benchmarks(benchmarks: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
```
Collaborator: any reason you are deleting this?

Author: added back

Two further review threads on models/llama/llama/generation.py were marked outdated and resolved.
```python
import pytest
from kernels.fused_softmax import triton_softmax

@pytest.mark.parametrize("input_size", [(1024, 1024), (512, 512), (2048, 512)])
```
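The body of the test is not shown in the thread; a plausible shape for it, assuming triton_softmax mirrors torch.softmax on 2-D CUDA tensors, would be:

```python
import torch

@pytest.mark.parametrize("input_size", [(1024, 1024), (512, 512), (2048, 512)])
def test_softmax(input_size):
    x = torch.randn(input_size, device="cuda")  # the kernel assumes a CUDA tensor
    expected = torch.softmax(x, dim=-1)
    actual = triton_softmax(x, dim=-1)
    assert torch.allclose(actual, expected, atol=1e-6)
```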
Collaborator: Thanks for adding tests!

Author: FYI, this might not be ideal because we are not calling softmax from triton.ops like the other tests; I ran into issues doing it that way.

```diff
@@ -70,14 +71,15 @@ def attention(self, xq, keys, values, head_dim, mask):

     @Profiler.profiling_decorator("softmax")
     def softmax(self, x, dim):
-        if self.use_triton:
-            return F.softmax(x, dim=-1)
+        if self.use_triton and len(x) == 2:
```
Collaborator: It looks like you're trying to check the number of dimensions here, right? len(x) returns the size of the first dimension (x.shape[0]), not the number of dimensions. I think you want x.dim() or x.ndim.
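To make the distinction concrete, on a plain torch tensor:

```python
import torch

x = torch.randn(4, 8)
len(x)      # 4  -> size of the first dimension, i.e. x.shape[0]
x.numel()   # 32 -> total number of elements
x.dim()     # 2  -> number of dimensions (x.ndim is equivalent)
```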

Author: done

```diff
-        if self.use_triton:
-            return F.softmax(x, dim=-1)
+        if self.use_triton and len(x) == 2:
+            return triton_softmax(x, dim=-1)
```
Collaborator: Why are we passing dim=-1 to these calls when we receive dim as an argument? Let's pass it through properly instead of overriding it. (Also, does the fused Triton kernel actually handle dim != -1 correctly?)

Author: Currently it does not handle dim != -1. I'm looking into it (seeing how llama.cpp handles this); let me know if you have any pointers.
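One way to pass dim through while keeping the Triton path only for the supported case. This is an assumed sketch of the method inside the class from generation.py (where F and Profiler are already in scope), not the PR's final code:

```python
@Profiler.profiling_decorator("softmax")
def softmax(self, x, dim):
    # The fused kernel only supports 2-D inputs reduced over the last dim,
    # so fall back to F.softmax everywhere else and respect the caller's dim.
    if self.use_triton and x.dim() == 2 and dim in (-1, 1):
        return triton_softmax(x, dim=-1)
    return F.softmax(x, dim=dim)
```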

```python
        else:
            return F.softmax(x, dim=-1)

    @Profiler.profiling_decorator("argmax")
    def argmax(self, x, dim):
        if self.use_triton:
            # TODO: change
```
Collaborator: Instead of adding a TODO to the code here, would you mind creating an issue to track it?

Author: removed
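For the record, a row-wise argmax can be written along the same lines as the softmax sketch earlier. This assumes tl.argmax is available in the installed Triton version and is illustrative only, not this PR's code:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def argmax_kernel(out_ptr, in_ptr, row_stride, n_cols, BLOCK_SIZE: tl.constexpr):
    # One program instance reduces one row to the index of its maximum.
    row_idx = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols
    row = tl.load(in_ptr + row_idx * row_stride + cols,
                  mask=mask, other=-float("inf"))
    tl.store(out_ptr + row_idx, tl.argmax(row, axis=0))


def triton_argmax_rows(x: torch.Tensor) -> torch.Tensor:
    n_rows, n_cols = x.shape
    out = torch.empty(n_rows, dtype=torch.int32, device=x.device)
    argmax_kernel[(n_rows,)](out, x, x.stride(0), n_cols,
                             BLOCK_SIZE=triton.next_power_of_2(n_cols))
    return out
```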

Another review thread on models/llama/llama/generation.py was marked outdated and resolved.
- Rename certain functions to conform with the naming scheme
- The current Triton softmax does not handle more than 2 dimensions; this will need investigation (probably by looking at how llama.cpp does it; see the sketch below)
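A common workaround for the more-than-2-dimensions / dim != -1 limitation (an assumption on my part, not something this PR implements) is to move the target dimension last and flatten to 2-D before calling the row-wise kernel:

```python
import torch

def softmax_any_dim(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Move the reduction dim to the last position, collapse the rest into rows,
    # run the 2-D row-wise kernel, then restore the original layout.
    x_moved = x.movedim(dim, -1).contiguous()
    flat = x_moved.flatten(0, -2)
    out = triton_softmax(flat, dim=-1)
    return out.reshape(x_moved.shape).movedim(-1, dim)
```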