Add embedding, rms_norm, rope #2517

Open · wants to merge 2 commits into base: main

Changes from all commits
1 change: 1 addition & 0 deletions torchbenchmark/operators/embedding/__init__.py
@@ -0,0 +1 @@
from .operator import Operator
53 changes: 53 additions & 0 deletions torchbenchmark/operators/embedding/operator.py
@@ -0,0 +1,53 @@
import argparse
from typing import Callable, Generator, List, Optional

import torch
from torch.nn import Embedding

from torchbenchmark.util.triton_op import BenchmarkOperator, register_benchmark

try:
    from liger_kernel.transformers.experimental.embedding import LigerEmbedding
except ModuleNotFoundError:
    LigerEmbedding = None

# Reference: https://github.com/linkedin/Liger-Kernel/
# blob/main/benchmark/scripts/benchmark_embedding.py


class Operator(BenchmarkOperator):
    def __init__(
        self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
    ):
        super().__init__(tb_args, extra_args)
        # they are generated later
        self.baseline_op = None
        self.liger_op = None
        self.use_cuda_graphs = False

    def get_input_iter(self) -> Generator:
        for B, T, D in [(32, 512, 768), (8, 2048, 4096)]:
            for V in [2**i for i in range(10, 18)]:
                _input = torch.randint(0, V, (B, T), device=self.device)
                yield V, D, _input

    @register_benchmark(baseline=True)
    def torch_embedding(self, V, D, input) -> Callable:
        self.baseline_op = Embedding(V, D).to(self.device).to(self.dtype)
        return lambda: self.baseline_op(input)

    @register_benchmark()
    def liger_embedding(self, V, D, input) -> Callable:
        self.liger_op = LigerEmbedding(V, D).to(self.device).to(self.dtype)
        return lambda: self.liger_op(input)

    @register_benchmark()
    def inductor_embedding(self, V, D, input) -> Callable:
        self.baseline_op = Embedding(V, D).to(self.device).to(self.dtype)
        compiled = torch.compile(self.baseline_op, dynamic=False)
        return lambda: compiled(input)

    def get_bwd_fn(self, fwd_fn: Callable) -> Callable:
        y = fwd_fn()
        do = torch.randn_like(y)
        return lambda: y.backward(do)
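
A minimal standalone sketch of the eager vs. inductor comparison that `torch_embedding` and `inductor_embedding` register, outside the harness. This is not part of the PR; it assumes a CUDA device and hard-codes one (B, T, D, V) configuration from `get_input_iter`:

```python
import torch
from torch.nn import Embedding

# One (B, T, D, V) configuration yielded by get_input_iter.
B, T, D, V = 32, 512, 768, 2**12
device = "cuda"

tokens = torch.randint(0, V, (B, T), device=device)
emb = Embedding(V, D).to(device)
compiled_emb = torch.compile(emb, dynamic=False)  # what inductor_embedding benchmarks

# Both paths perform the same lookup, so the outputs should match.
torch.testing.assert_close(emb(tokens), compiled_emb(tokens))
```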
1 change: 1 addition & 0 deletions torchbenchmark/operators/rms_norm/__init__.py
@@ -0,0 +1 @@
from .operator import Operator
70 changes: 70 additions & 0 deletions torchbenchmark/operators/rms_norm/operator.py
@@ -0,0 +1,70 @@
import argparse
from typing import Callable, Generator, List, Optional

import torch

from torchbenchmark.util.triton_op import BenchmarkOperator, register_benchmark

try:
    from liger_kernel.transformers.rms_norm import LigerRMSNorm
except ModuleNotFoundError:
    LigerRMSNorm = None

# Reference: https://github.com/linkedin/Liger-Kernel/
# blob/main/benchmark/scripts/benchmark_rms_norm.py


class LlamaRMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


class Operator(BenchmarkOperator):
    def __init__(
        self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
    ):
        super().__init__(tb_args, extra_args)
        self.M = 2048
        self.eps = 1e-6
        # they are generated later
        self.llama_rms_op = None
        self.liger_rms_op = None
        self.use_cuda_graphs = False

    def get_input_iter(self) -> Generator:
        for H in [2**i for i in range(10, 16)]:
            x_shape = (self.M, H)
            _input = torch.randn(x_shape, dtype=self.dtype, device=self.device)
            yield H, _input

    @register_benchmark(baseline=True)
    def llama_rms(self, H, input) -> Callable:
        self.llama_rms_op = LlamaRMSNorm(hidden_size=H, eps=self.eps).to(self.device)
        return lambda: self.llama_rms_op(input)

    @register_benchmark()
    def liger_rms(self, H, input) -> Callable:
        self.liger_rms_op = LigerRMSNorm(hidden_size=H, eps=self.eps).to(self.device)
        return lambda: self.liger_rms_op(input)

    @register_benchmark()
    def inductor_rms(self, H, input) -> Callable:
        compiled = torch.compile(self.llama_rms_op, dynamic=False)
        return lambda: compiled(input)

    def get_bwd_fn(self, fwd_fn: Callable) -> Callable:
        y = fwd_fn()
        do = torch.randn_like(y)
        return lambda: y.backward(do, retain_graph=True)
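
The `LlamaRMSNorm` baseline above computes y = w * x / sqrt(mean(x^2) + eps) with the reduction done in float32. A minimal sketch of that equivalence, not part of the PR, assuming a CUDA device and a torchbenchmark checkout on PYTHONPATH (float32 input so the at-init weight of ones drops out exactly):

```python
import torch
from torchbenchmark.operators.rms_norm.operator import LlamaRMSNorm

M, H, eps = 2048, 1024, 1e-6  # M and eps match the Operator defaults above
x = torch.randn(M, H, dtype=torch.float32, device="cuda")

module_out = LlamaRMSNorm(hidden_size=H, eps=eps).to("cuda")(x)

# Same formula written out directly: normalize by the root mean square, then scale.
x32 = x.to(torch.float32)
manual_out = x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)
torch.testing.assert_close(module_out, manual_out)  # weight is all ones at init
```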
1 change: 1 addition & 0 deletions torchbenchmark/operators/rope/__init__.py
@@ -0,0 +1 @@
from .operator import Operator
104 changes: 104 additions & 0 deletions torchbenchmark/operators/rope/operator.py
@@ -0,0 +1,104 @@
import argparse
from typing import Callable, Generator, List, Optional

import torch
from transformers.models.llama.modeling_llama import (
    apply_rotary_pos_emb,
    LlamaRotaryEmbedding,
)

from torchbenchmark.util.triton_op import BenchmarkOperator, register_benchmark

try:
    from liger_kernel.transformers.rope import liger_rotary_pos_emb
except ModuleNotFoundError:
    liger_rotary_pos_emb = None

# Reference: https://github.com/linkedin/Liger-Kernel/
# blob/main/benchmark/scripts/benchmark_rope.py


class Operator(BenchmarkOperator):
    def __init__(
        self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
    ):
        super().__init__(tb_args, extra_args)
        # they are generated later
        self.baseline_op = None
        self.liger_op = None
        self.use_cuda_graphs = False
        self.num_q_heads = 32
        self.num_kv_heads = 8

    def get_input_iter(self) -> Generator:
        hidden_size = 8192
        for seq_length in [2**i for i in range(10, 15)]:
            yield hidden_size, seq_length

        seq_length = 2048
        for hidden_size in [32 * (2**i) for i in range(4, 10, 2)]:
            yield hidden_size, seq_length

    def prepare_input(self, hidden_size, seq_length):
        head_dim = hidden_size // self.num_q_heads
        rotary_emb = LlamaRotaryEmbedding(head_dim, device=self.device)
        q = torch.randn(
            (1, seq_length, self.num_q_heads, head_dim),
            device=self.device,
            requires_grad=True,
            dtype=self.dtype,
        ).transpose(1, 2)
        k = torch.randn(
            (1, seq_length, self.num_kv_heads, head_dim),
            device=self.device,
            requires_grad=True,
            dtype=self.dtype,
        ).transpose(1, 2)
        dq, dk = torch.randn_like(
            q, device=self.device, dtype=self.dtype
        ), torch.randn_like(k, device=self.device)
        pos_ids = torch.arange(
            seq_length, device=self.device, dtype=torch.long
        ).unsqueeze(0)
        cos, sin = rotary_emb(k, pos_ids)
        # save q,k to self for grad_to_none
        self.q = q
        self.k = k
        # save dq,dk to self for backward
        self.dq = dq
        self.dk = dk
        return q, k, cos, sin, pos_ids

    @register_benchmark(baseline=True)
    def apply_rotary_pos_emb(self, hidden_size, seq_length) -> Callable:
        q, k, cos, sin, pos_ids = self.prepare_input(hidden_size, seq_length)
        return lambda: apply_rotary_pos_emb(q, k, cos, sin, pos_ids)

    @register_benchmark()
    def liger_rotary_pos_emb(self, hidden_size, seq_length) -> Callable:
        q, k, cos, sin, pos_ids = self.prepare_input(hidden_size, seq_length)
        return lambda: liger_rotary_pos_emb(q, k, cos, sin, pos_ids)

    @register_benchmark()
    def inductor_rotary_pos_emb_full_op(self, hidden_size, seq_length) -> Callable:
        q, k, cos, sin, pos_ids = self.prepare_input(hidden_size, seq_length)
        head_dim = hidden_size // self.num_q_heads
        compiled = torch.compile(
            LlamaRotaryEmbedding(head_dim, device=self.device), dynamic=False
        )
        cos, sin = compiled(k, pos_ids)
        compiled_func = torch.compile(apply_rotary_pos_emb, dynamic=False)
        return lambda: compiled_func(q, k, cos, sin, pos_ids)

    def get_bwd_fn(self, fwd_fn: Callable) -> Callable:
        q_out, k_out = fwd_fn()
        return lambda: torch.autograd.grad(
            (q_out, k_out),
            (self.q, self.k),
            (self.dq, self.dk),
            allow_unused=True,
            retain_graph=True,
        )

    def get_grad_to_none(self, args) -> List[torch.Tensor]:
        return [self.q, self.k]
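
A standalone sketch of the tensor shapes `prepare_input` builds and the baseline `apply_rotary_pos_emb` call, not part of the PR. It assumes a CUDA device, the `transformers` LLaMA modules, and bfloat16, and mirrors one (hidden_size, seq_length) pair yielded by `get_input_iter`:

```python
import torch
from transformers.models.llama.modeling_llama import (
    apply_rotary_pos_emb,
    LlamaRotaryEmbedding,
)

# One configuration from get_input_iter; head counts match the Operator defaults.
hidden_size, seq_length = 8192, 2048
num_q_heads, num_kv_heads = 32, 8
head_dim = hidden_size // num_q_heads  # 256
device, dtype = "cuda", torch.bfloat16

# prepare_input creates (1, seq, heads, head_dim) tensors and transposes dims 1 and 2;
# here the transposed layout is built directly.
q = torch.randn(1, num_q_heads, seq_length, head_dim, device=device, dtype=dtype)
k = torch.randn(1, num_kv_heads, seq_length, head_dim, device=device, dtype=dtype)
pos_ids = torch.arange(seq_length, device=device, dtype=torch.long).unsqueeze(0)

rotary_emb = LlamaRotaryEmbedding(head_dim, device=device)  # same call as prepare_input
cos, sin = rotary_emb(k, pos_ids)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, pos_ids)
assert q_rot.shape == q.shape and k_rot.shape == k.shape  # RoPE preserves shapes
```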
2 changes: 1 addition & 1 deletion torchbenchmark/operators_collection/liger/__init__.py
@@ -1,4 +1,4 @@
liger_operators = ["FusedLinearCrossEntropy"]
liger_operators = ["FusedLinearCrossEntropy", "rope", "rms_norm", "embedding"]


def get_operators():
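
A hypothetical sanity check that the three new operators are exposed by the liger collection, assuming a torchbenchmark checkout importable on PYTHONPATH:

```python
from torchbenchmark.operators_collection.liger import liger_operators

assert {"embedding", "rms_norm", "rope"} <= set(liger_operators)
```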