Skip to content

Commit

Permalink
Benchmark int4 gemm implementations (#2261)
Browse files Browse the repository at this point in the history
Summary:
Focusing on llama2-70b inference, this compares tinygemm (_weight_int4pack_mm) to a Triton implementation.

Note that these are not numerically equivalent now, as the Triton implementation does not apply scale and zero point.  TODO!

Pull Request resolved: #2261

Test Plan:
```
pytorch run_benchmark.py triton --op int4_gemm
```

Performance in TFLOPS (although latency or maybe b/w would be better for the decoding sizes)
```
int4-gemm-performance:
       B       m       n       k   tinygemm      triton
0    1.0     1.0  1280.0  8192.0   1.318632    0.146058
1    1.0     1.0  7168.0  8192.0   3.097060    0.763632
2    1.0     1.0  8192.0  1024.0   1.533006    0.279546
3    1.0     1.0  8192.0  3584.0   2.702516    0.842907
4    1.0  4096.0  1280.0  8192.0  29.048470  160.250400
5    1.0  4096.0  7168.0  8192.0  29.133376  260.258231
6    1.0  4096.0  8192.0  1024.0  27.168206  244.825143
7    1.0  4096.0  8192.0  3584.0  28.823074  251.058619
8    4.0     1.0  1280.0  8192.0   5.002748    0.582154
9    4.0     1.0  7168.0  8192.0  12.042711    3.029316
10   4.0     1.0  8192.0  1024.0   5.991863    1.152598
11   4.0     1.0  8192.0  3584.0  10.309034    3.343978
12   4.0  4096.0  1280.0  8192.0  29.119370  254.743029
13   4.0  4096.0  7168.0  8192.0  29.114149  261.838071
14   4.0  4096.0  8192.0  1024.0  27.197883  262.376217
15   4.0  4096.0  8192.0  3584.0  28.852035  255.000142
16  16.0     1.0  1280.0  8192.0  12.834468    2.328100
17  16.0     1.0  7168.0  8192.0  24.755590   12.119764
18  16.0     1.0  8192.0  1024.0  15.650388    4.454917
19  16.0     1.0  8192.0  3584.0  21.683995   13.388112
20  16.0  4096.0  1280.0  8192.0  29.148757  260.059109
21  16.0  4096.0  7168.0  8192.0  29.150796  269.369526
22  16.0  4096.0  8192.0  1024.0  27.245403  272.193003
23  16.0  4096.0  8192.0  3584.0  29.095535  263.722304
24  64.0     1.0  1280.0  8192.0  20.799920    9.287653
25  64.0     1.0  7168.0  8192.0  27.319690   38.018942
26  64.0     1.0  8192.0  1024.0  22.919694   17.567766
27  64.0     1.0  8192.0  3584.0  26.492333   42.643613
28  64.0  4096.0  1280.0  8192.0  29.424380  265.626026
29  64.0  4096.0  7168.0  8192.0  29.391893  260.692446
30  64.0  4096.0  8192.0  1024.0  27.452366  269.505851
31  64.0  4096.0  8192.0  3584.0  29.098774  252.047059
```

Reviewed By: xuzhao9

Differential Revision: D57217321

Pulled By: bertmaher

fbshipit-source-id: 3cc24d2cf57c8277189799eb1a61bb9a1d8ca5e3
  • Loading branch information
bertmaher authored and facebook-github-bot committed May 16, 2024
1 parent c9f6193 commit 43621bc
Show file tree
Hide file tree
Showing 3 changed files with 312 additions and 0 deletions.
1 change: 1 addition & 0 deletions torchbenchmark/operators/int4_gemm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .int4_gemm import Operator
145 changes: 145 additions & 0 deletions torchbenchmark/operators/int4_gemm/int4_gemm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
"""
Compute a bf16 (activation) x int4 (weight) gemm.
Inspired by [gpt-fast](https://github.com/pytorch-labs/gpt-fast)
ATen kernels from tinygemm
Triton implementation by @jlebar: https://gist.github.com/jlebar/3435b2c00deea53258887ce37231e5e2
"""

import argparse
import os
import statistics
import torch
import triton.ops
import triton.language as tl

from typing import Any

from torchbenchmark.util.triton_op import (
BenchmarkOperator,
BenchmarkOperatorMetrics,
register_benchmark,
register_metric,
)

from .kernel import pack_2xint4, matmul, matmul_kernel


class Operator(BenchmarkOperator):
DEFAULT_METRICS = ["tflops", "gbps", "latency"]

def __init__(self, mode, device, extra_args):
super().__init__(mode=mode, device=device, extra_args=extra_args)
# `Group size` and `inner K tiles` are defaults from gpt-fast.
self.group_size = 32
self.inner_k_tiles = 8

def get_input_iter(self):
def args(B, L, Dout, Din):
x = torch.randn(B, L, Din, device=self.device, dtype=torch.bfloat16)
w = torch.randint(-8, 7, (Din, Dout), device=self.device, dtype=torch.int32)
scales_and_zeros = torch.randn(
Din // self.group_size,
Dout,
2,
device=self.device,
dtype=torch.bfloat16,
)
return (x, w, scales_and_zeros)

# LLama-2 shapes w/ 8-way tensor parallelism.
name_to_shapes_70b = {
"attn.wqkv": (8192, 1280),
"attn.w0": (1024, 8192),
"ffn.w13": (8192, 7168),
"ffn.w2": (3584, 8192),
}
for seq_len in (1, 4096):
for bsz in (1, 4, 16, 64):
for name, (k, n) in name_to_shapes_70b.items():
yield args(bsz, seq_len, n, k)

def get_x_val(self, example_inputs) -> float:
x, w, scales_and_zeros = example_inputs
B, m, k = x.size()
_, n = w.size()
return (B, m, n, k)

@register_benchmark(baseline=True)
def tinygemm(self, x, w, scales_and_zeros):
x = x.reshape(-1, x.size(-1))
w_int4 = torch.ops.aten._convert_weight_to_int4pack(
w.T.contiguous(), self.inner_k_tiles
)
return lambda: torch.ops.aten._weight_int4pack_mm(
x, w_int4, self.group_size, scales_and_zeros
)

@register_benchmark()
def triton(self, x, w, scales_and_zeros):
x = x.reshape(-1, x.size(-1))
w_int4 = pack_2xint4(w).T.contiguous().T
return lambda: matmul(x, w_int4)

@register_metric()
def best_config(self, fn, inputs, metrics):
if "triton" in str(fn):
return str(matmul_kernel.best_config)
return ""

@register_metric()
def gbps(self, fn, example_inputs: Any, metrics: BenchmarkOperatorMetrics) -> float:
def nbytes(t):
return t.numel() * t.element_size()

x, w, scale_and_zero = example_inputs
c = fn()

gb = (sum(nbytes(t) for t in (x, scale_and_zero, c)) + nbytes(w) // 8) / 1e9
return list(map(lambda ms: gb / ms * 1e3, metrics.latency))

@register_metric()
def tflops(
self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
) -> float:
a, b, _ = example_inputs
B, m, k = a.size()
m = B * m
_, n = b.size()
flops = 2 * m * n * k
return [flops / x / 1e12 * 1e3 for x in metrics.latency]

def plot(self):
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=[
"B",
"m",
"n",
"k",
], # argument names to use as an x-axis for the plot
x_vals=self.output.x_vals, # different possible values for `x_name`
line_arg="provider", # argument name whose value corresponds to a different line in the plot
line_vals=[
"tinygemm",
"triton",
], # possible values for `line_arg``
line_names=[
"tinygemm",
"triton",
], # label name for the lines
styles=[("blue", "-"), ("green", "-")],
ylabel="tflops", # label name for the y-axis
plot_name="int4-gemm-performance", # name for the plot. Used also as a file name for saving the plot.
args={}, # values for function arguments not in `x_names` and `y_name`
)
)
def _plot(B, m, n, k, provider):
tflops = self.output.get_y_vals((B, m, n, k), provider, "tflops")
return tflops

save_path = "/tmp/int4_gemm"

if not os.path.exists(save_path):
os.mkdir(save_path)

_plot.run(show_plots=True, print_data=True, save_path=save_path)
166 changes: 166 additions & 0 deletions torchbenchmark/operators/int4_gemm/kernel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
"""
Triton implementation by @jlebar: https://gist.github.com/jlebar/3435b2c00deea53258887ce37231e5e2
"""

import torch
import triton
import triton.language as tl

AUTOTUNE_CONFIGS = [
triton.Config(
{
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
},
num_stages=4,
num_warps=8,
),
]


@triton.autotune(configs=AUTOTUNE_CONFIGS, key=["M", "N", "K"])
@triton.jit
def matmul_kernel(
# Pointers to matrices
a_ptr,
b_ptr,
c_ptr,
# Matrix dimensions.
M,
N,
K,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
# by to get the element one row down (A has M rows).
#
# We assume `b` is packed with 2 `int4` elements per K, i.e. it's a
# (K//2)xNx(2xint4) matrix, represented in Triton as (K//2)xNxi8. If K
# is the minor dimension, then stride_bk should logically be 0.5. But
# we don't want a fractional stride! So let the given stride be the
# stride per 2xint4.
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
"""Kernel for computing the matmul C = A x B.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
"""
tl.device_assert(K % BLOCK_SIZE_K == 0)

# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
# See above `L2 Cache Optimizations` section for details.
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

# ----------------------------------------------------------
# Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction
# and accumulate
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# `b_ptrs` is a block of [BLOCK_SIZE_K // 2, BLOCK_SIZE_N] pointers
# See above `Pointer Arithmetic` section for details
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_ak = tl.arange(0, BLOCK_SIZE_K)
offs_bk = tl.arange(0, BLOCK_SIZE_K // 2)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_ak[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_bk[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(a_ptrs, mask=offs_ak[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs)
tl.static_assert(b.dtype == tl.int8)

# Unpack `b` into an fp16 matrix, taking care to sign-extend b_lo. Use
# _4_i8 because the literal "4" is considered an i32, which causes the
# shift operands to be widened to i32.
_4_i8 = tl.full((1,), 4, dtype=tl.int8)
b_lo = (b << _4_i8) >> _4_i8
b_hi = b >> _4_i8
# Workaround: Convert before the join() so that Triton can load the data
# after the join using ldmatrix.
b_f16 = (
tl.join(b_lo.to(tl.bfloat16), b_hi.to(tl.bfloat16))
.permute(0, 2, 1)
.reshape(BLOCK_SIZE_K, BLOCK_SIZE_N)
)

accumulator += tl.dot(a, b_f16)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk // 2

c = accumulator.to(tl.bfloat16)

offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)


def matmul(a, b):
assert a.shape[1] == b.shape[0] * 2, "Incompatible dimensions"
assert a.is_contiguous(), "Matrix A must be contiguous"
M, K = a.shape
_, N = b.shape

c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
grid = lambda META: (
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
)
matmul_kernel[grid](
a,
b,
c,
M,
N,
K,
a.stride(0),
a.stride(1),
b.stride(0),
b.stride(1),
c.stride(0),
c.stride(1),
)
return c


def pack_2xint4(t):
# Packs a KxNxfp16 matrix into a (K//2)xNx(2xint4) matrix.
t = t.to(torch.int8).reshape(t.shape[0] // 2, 2, t.shape[1]).permute(1, 0, 2)
return (t[0] & 0xF) | (t[1] << 4)

0 comments on commit 43621bc

Please sign in to comment.