Commit acd77e3

Add special einsum cases that lower to batch matmul (#262)
KyleHerndon authored Oct 14, 2024
1 parent 355761b commit acd77e3
Showing 4 changed files with 120 additions and 23 deletions.
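The special cases added here rewrite a few fixed einsum patterns as batched matmuls. As a standalone sketch of the core idea (illustrative only, not part of the diff; assumes nothing beyond torch), "mk,menk->men" can be computed by flattening the e and n weight dimensions and batching over m:

    import torch

    # Illustrative sizes.
    m, e, n, k = 4, 3, 5, 8
    inputs = torch.randn(m, k)
    weights = torch.randn(m, e, n, k)

    # "mk,menk->men" as a batch matmul over m:
    # [m, 1, k] @ [m, k, e*n] -> [m, 1, e*n] -> [m, e, n]
    lowered = (inputs.unsqueeze(1) @ weights.view(m, e * n, k).mT).view(m, e, n)

    reference = torch.einsum("mk,menk->men", inputs, weights)
    assert torch.allclose(lowered, reference, atol=1e-5)

The same reshaping trick covers the other two patterns the commit handles ("mek,menk->men" and "me,men->men"), which is what lets them reach the quantized batch-matmul kernels touched below.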
40 changes: 26 additions & 14 deletions sharktank/sharktank/kernels/mmt_block_scaled_offset_q4.py
@@ -37,38 +37,43 @@ def select(self, ksel: KernelSelection):
m_desc = ksel.arg_tensor(3) # Shape [N, K // BLOCK_SIZE, 1]

# a arg
*batch_dims, a_m, a_k = a_desc.t.shape
*a_batch_dims, a_m, a_k = a_desc.t.shape
torch._check(
a_desc.t.dtype.is_floating_point,
lambda: f"mmt_block_scaled_offset_q4_unsigned arg 'a': Expected floating point (got {a_desc.t.dtype})",
)
torch._check(
len(batch_dims) == 1,
len(a_batch_dims) == 1,
lambda: f"mmt_block_scaled_offset_q4_unsigned arg 'a': Expected 3d tensor (got {a_desc.t.shape})",
)

# qs arg
qs_n, qs_group0, qs_bs_div_2, *rest = qs_desc.t.shape
*qs_batch_dims, qs_n, qs_group0, qs_bs_div_2 = qs_desc.t.shape
torch._check(
len(rest) == 0 and (qs_group0 * qs_bs_div_2 * 2) == a_k,
(
len(qs_batch_dims) == 0
or len(qs_batch_dims) == 1
and qs_batch_dims == a_batch_dims
)
and (qs_group0 * qs_bs_div_2 * 2) == a_k,
lambda: f"mmt_block_scaled_offset_q4_unsigned arg 'qs': Incorrect shape (got {qs_desc.t.shape})",
)
block_size = qs_bs_div_2 * 2

# d arg
d_n, d_group0, d_one, *rest = d_desc.t.shape
*d_batch_dims, d_n, d_group0, d_one = d_desc.t.shape
torch._check(
len(rest) == 0
d_batch_dims == qs_batch_dims
and (d_group0 * block_size) == a_k
and d_one == 1
and d_n == qs_n,
lambda: f"mmt_block_scaled_offset_q4_unsigned arg 'd': Incorrect shape (got {d_desc.t.shape})",
)

# m arg
m_n, m_group0, m_one, *rest = m_desc.t.shape
*m_batch_dims, m_n, m_group0, m_one = m_desc.t.shape
torch._check(
len(rest) == 0
m_batch_dims == qs_batch_dims
and (m_group0 * block_size) == a_k
and m_one == 1
and m_n == qs_n,
@@ -81,12 +86,17 @@ def select(self, ksel: KernelSelection):

# Specialize on K, N, BS
a_desc.specialize_dims(-1)
qs_desc.specialize_all_dims()
d_desc.specialize_all_dims()
m_desc.specialize_all_dims()
if len(qs_batch_dims) == 0:
qs_desc.specialize_all_dims()
d_desc.specialize_all_dims()
m_desc.specialize_all_dims()
else:
qs_desc.specialize_dims(1, 2, 3)
d_desc.specialize_dims(1, 2, 3)
m_desc.specialize_dims(1, 2, 3)

# Shape batch..., m, n
c_desc = ksel.return_new_tensor(batch_dims + [a_m, d_n], dtype=a_desc.t.dtype)
c_desc = ksel.return_new_tensor(a_batch_dims + [a_m, d_n], dtype=a_desc.t.dtype)
c_desc.specialize_dims(-1)

def generate(self, ksel: KernelSelection, kb: KernelBuilder):
@@ -99,13 +109,14 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder):

rank = a_tensor_type.rank
k = a_tensor_type.get_dim_size(rank - 1)
n, group0, bs_i8 = qs_tensor_type.shape
*qs_batch_dims, n, group0, bs_i8 = qs_tensor_type.shape
batched_rhs = len(qs_batch_dims) == 1
bs = bs_i8 * 2 # 2 nibbles per byte.
a_type_str = str(a_tensor_type.element_type)
scale_type_str = str(d_tensor_type.element_type)

template_file = "mmt_block_scaled_offset_q4_unsigned.mlir"
target_function_name = f"sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{n}_{k}_{bs}_{a_type_str}"
target_function_name = f"sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{n}_{k}_{bs}_{a_type_str}_{batched_rhs}"

target_function = inline_template_function(
kb,
Expand All @@ -118,5 +129,6 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
group0=group0,
a_type=a_type_str,
scale_type=scale_type_str,
batched_rhs=batched_rhs,
)
kb.yield_results(*call_function(target_function, *kb.arg_bindings))
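For orientation, a small standalone sketch (not part of the diff) of the shape contract the updated select() now enforces, with BS the quantization block size (two 4-bit values per stored int8):

    import torch

    def check_shapes(a, qs, d, m):
        *a_batch, a_m, a_k = a.shape                   # a:  [B, M, K]
        *qs_batch, n, group0, bs_div_2 = qs.shape      # qs: [N, K//BS, BS//2] or [B, N, K//BS, BS//2]
        assert qs_batch == [] or qs_batch == a_batch   # RHS batch dims, if any, must match the LHS
        assert group0 * bs_div_2 * 2 == a_k            # groups * block size == K
        assert list(d.shape) == qs_batch + [n, group0, 1]  # per-block scales
        assert list(m.shape) == qs_batch + [n, group0, 1]  # per-block offsets

When the RHS carries a batch dimension, only the trailing N/group/block dims are specialized, so the batch dimension stays dynamic.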
30 changes: 26 additions & 4 deletions sharktank/sharktank/kernels/templates/mmt_block_scaled_offset_q4_unsigned.mlir
@@ -12,17 +12,25 @@
!accum_type = {{accum_type}}
!a_tensor_type = tensor<?x?x{{k}}x!a_type>
!aexp_tensor_type = tensor<?x?x{{group0}}x{{bs}}x!a_type>
{% if batched_rhs %}
!qs_raw_tensor_type = tensor<?x{{n}}x{{group0}}x{{bs_i8}}xi8>
!qs_tensor_type = tensor<?x{{n}}x{{group0}}x{{bs}}x!lowp_type>
!d_tensor_type = tensor<?x{{n}}x{{group0}}x1x!scale_type>
!m_tensor_type = tensor<?x{{n}}x{{group0}}x1x!scale_type>
!b_grouped_tensor_type = tensor<?x{{n}}x{{group0}}x{{bs}}x!a_type>
{% else %}
!qs_raw_tensor_type = tensor<{{n}}x{{group0}}x{{bs_i8}}xi8>
!qs_tensor_type = tensor<{{n}}x{{group0}}x{{bs}}x!lowp_type>
!d_tensor_type = tensor<{{n}}x{{group0}}x1x!scale_type>
!m_tensor_type = tensor<{{n}}x{{group0}}x1x!scale_type>
!b_grouped_tensor_type = tensor<{{n}}x{{group0}}x{{bs}}x!a_type>
{% endif %}
!accum_tensor_type = tensor<?x?x{{n}}x!accum_type>
!c_tensor_type = tensor<?x?x{{n}}x!a_type>
!b_grouped_tensor_type = tensor<{{n}}x{{group0}}x{{bs}}x!a_type>

module {

util.func private @sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{{n}}_{{k}}_{{bs}}_{{a_type}}(
util.func private @sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{{n}}_{{k}}_{{bs}}_{{a_type}}_{{batched_rhs}}(
%a: !a_tensor_type, %d: !d_tensor_type, %qs_raw: !qs_raw_tensor_type, %m: !m_tensor_type)
-> !c_tensor_type {
%zero = arith.constant 0.0: !accum_type
@@ -32,17 +40,31 @@ util.func private @sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{{n}}_{{k}}_
%m_dim = tensor.dim %a, %c1 : !a_tensor_type

// Cast qs_raw from i8 to lowp type.
{% if batched_rhs %}
%qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type{ %batch0_dim } -> !qs_tensor_type{ %batch0_dim }
%b_grouped = tensor.empty(%batch0_dim) : !b_grouped_tensor_type
{% else %}
%qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type -> !qs_tensor_type
%b_grouped = tensor.empty() : !b_grouped_tensor_type
{% endif %}

// Dequantize.
%b_grouped = tensor.empty() : !b_grouped_tensor_type
%b_grouped_dequant = linalg.generic {
{% if batched_rhs %}
indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
iterator_types = ["parallel", "parallel", "parallel", "parallel"] }
{% else %}
indexing_maps = [
affine_map<(d0, d1, d2) -> (d0, d1, 0)>,
affine_map<(d0, d1, d2) -> (d0, d1, 0)>,
affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
iterator_types = ["parallel", "parallel", "parallel"] }
{% endif %}
ins(%d, %m, %qs : !d_tensor_type, !m_tensor_type, !qs_tensor_type)
outs(%b_grouped : !b_grouped_tensor_type) {
^bb0(%d_element: !scale_type, %m_element: !scale_type, %q_element: !lowp_type, %out: !a_type):
@@ -70,7 +92,7 @@ util.func private @sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{{n}}_{{k}}_
indexing_maps = [
// d0 = b, d1 = m, d2 = n, d3 = group0 (r), d4 = block (r)
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> ({% if batched_rhs %}d0,{% endif %} d2, d3, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>],
iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"] }
ins(%aexp, %b_grouped_dequant : !aexp_tensor_type, !b_grouped_tensor_type)
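As a reference point, a rough pure-torch sketch of what the batched-RHS path of this template computes. The dequantization body of the linalg.generic is elided above, so the formula value = d * q + m is an assumption (the usual affine/offset scheme the kernel name suggests), and the nibble unpacking order is illustrative:

    import torch

    def dequant_block_scaled_offset_q4(qs_raw, d, m):
        # qs_raw: [..., N, K//BS, BS//2] uint8, two 4-bit values per byte
        # d, m:   [..., N, K//BS, 1] per-block scales and offsets
        lo = (qs_raw & 0x0F).to(d.dtype)
        hi = (qs_raw >> 4).to(d.dtype)
        q = torch.stack((lo, hi), dim=-1).flatten(-2)   # [..., N, K//BS, BS]
        return d * q + m                                # assumed dequant formula

    B, M, N, K, BS = 2, 3, 4, 64, 32
    a = torch.randn(B, M, K)
    qs_raw = torch.randint(0, 256, (B, N, K // BS, BS // 2)).to(torch.uint8)
    d = torch.rand(B, N, K // BS, 1)
    m = torch.rand(B, N, K // BS, 1)

    b = dequant_block_scaled_offset_q4(qs_raw, d, m).reshape(B, N, K)
    c = a @ b.mT                                        # [B, M, N], the transposed ("mmt") matmul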
71 changes: 67 additions & 4 deletions sharktank/sharktank/ops/default_impls.py
@@ -69,9 +69,56 @@ def conv2d_default(
conv2d.override(Tensor, Tensor, auto_dequant=True)(conv2d_default)

# Einsum
@einsum_2args.override(AllOfType(Tensor, PrimitiveTensor))
def einsum_2args(x, y, einsum_str):
return torch.einsum(einsum_str, unbox_tensor(x), unbox_tensor(y))
def mk_menk_men(inputs, weights):
# batch dims: m, lhs pdims: none, lhs rdims: k, rhs pdims: en, rhs rdims: k
inputs = inputs.unsqueeze(1)
weights_shape = weights.shape
weights = weights.view(
weights_shape[0], weights_shape[1] * weights_shape[2], weights_shape[3]
)
result = matmul(inputs, weights, transpose_rhs=True)
result = result.view(weights_shape[0], weights_shape[1], weights_shape[2])
return result


def mek_menk_men(inputs, weights):
# batch dims: me, lhs pdims: none, lhs rdims: k, rhs pdims: n, rhs rdims: k
inputs_shape = inputs.shape
inputs = inputs.view(inputs_shape[0] * inputs_shape[1], 1, inputs_shape[2])
weights_shape = weights.shape
weights = weights.view(
weights_shape[0] * weights_shape[1], weights_shape[2], weights_shape[3]
)
result = matmul(inputs, weights, transpose_rhs=True)
result = result.view(weights_shape[0], weights_shape[1], weights_shape[2])
return result


def me_men_men(inputs, weights):
# batch dims: me, lhs pdims: none, lhs rdims: none, rhs pdims: n, rhs rdims: none
inputs_shape = inputs.shape
inputs = inputs.view(inputs_shape[0] * inputs_shape[1], 1, 1)
weights_shape = weights.shape
weights = weights.view(weights_shape[0] * weights_shape[1], weights_shape[2], 1)
result = matmul(inputs, weights, transpose_rhs=True)
result = result.view(weights_shape[0], weights_shape[1], weights_shape[2])
return result


@einsum_2args.override(AllOfType(Tensor, PrimitiveTensor, QuantizedTensor))
def einsum_2args(input0, input1, einsum_str):
# Special optimized einsum kernels that lower to batch matmul
if einsum_str == "mk,menk->men":
return mk_menk_men(input0, input1)
elif einsum_str == "mek,menk->men":
return mek_menk_men(input0, input1)
elif einsum_str == "me,men->men":
return me_men_men(input0, input1)
# Default non-QuantizedTensor einsum
if not isinstance(input1, QuantizedTensor):
return torch.einsum(einsum_str, unbox_tensor(input0), unbox_tensor(input1))
# Fallback to other kernels
return NotImplemented
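Registered this way, the override also catches QuantizedTensor operands, returning NotImplemented for unrecognized einsum strings so another registration can handle them. A hedged usage sketch; the public entry point is assumed here to be sharktank.ops.einsum_2args, as the override decorator suggests:

    import torch
    from sharktank import ops

    # Dense tensors take the batch-matmul path for the recognized patterns:
    out = ops.einsum_2args(torch.randn(4, 8), torch.randn(4, 3, 5, 8), "mk,menk->men")
    print(out.shape)  # torch.Size([4, 3, 5])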


# Elementwise
@@ -307,7 +354,7 @@ def matmul_default(lhs, rhs, *, transpose_rhs: bool) -> Tensor:
lhs = unbox_tensor(lhs)
rhs = unbox_tensor(rhs)
if transpose_rhs:
rhs = rhs.T
rhs = rhs.mT
return torch.matmul(lhs, rhs.to(lhs.dtype))
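One line of context for the rhs.T to rhs.mT change above: .T reverses all dimensions and is deprecated for tensors with more than two dims, while .mT swaps only the last two, which is what a batched matmul needs. A tiny standalone illustration:

    import torch

    rhs = torch.randn(2, 5, 3)   # batched weights [B, N, K]
    print(rhs.mT.shape)          # torch.Size([2, 3, 5]) -- last two dims swapped, batch kept
    print(rhs.T.shape)           # torch.Size([3, 5, 2]) -- all dims reversed (deprecated for ndim > 2)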


@@ -433,3 +480,19 @@ def unsqueeze_default(tensor: Union[Tensor, PrimitiveTensor], dim: int) -> Tensor:
@view.override(Tensor)
def view_default(tensor: Union[Tensor, PrimitiveTensor], shape: List[int]) -> Tensor:
return unbox_tensor(tensor).view(*shape)


@view.override(QuantizedTensor)
def view_QuantizedTensor(tensor: QuantizedTensor, shape):
unpacked = tensor.unpack()
if not isinstance(unpacked, BlockScaledI4Layout):
return NotImplemented
bs = 16
shape = list(shape)
new_d = unpacked._d.view(shape[:-1] + [shape[-1] // 32, 1])
qs_shape = shape[:-1] + [shape[-1] // 32, 16]
new_qs = unpacked._qs.view(qs_shape)
if unpacked.m is not None:
new_m = unpacked.m.view(shape[:-1] + [shape[-1] // 32, 1])
layout = BlockScaledI4Layout(shape=shape, d=new_d, qs=new_qs, m=new_m)
return PlanarQuantizedTensor(shape=shape, layout=layout)
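A worked example of the shape arithmetic in view_QuantizedTensor above, for the BlockScaledI4Layout it handles (32-element blocks packed into 16 bytes); the numbers are illustrative only:

    # Requested view shape: [2, 2, 1024] for a tensor currently shaped [4, 1024]
    # new_d / new_m : shape[:-1] + [1024 // 32, 1]  -> [2, 2, 32, 1]
    # new_qs        : shape[:-1] + [1024 // 32, 16] -> [2, 2, 32, 16]
    # Only the leading dims are reshaped; the last dim must stay a multiple of
    # the 32-element block size so no block is ever split.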
2 changes: 1 addition & 1 deletion sharktank/sharktank/types/tensors.py
@@ -378,7 +378,7 @@ def unsqueeze(self, dim: int) -> "AnyTensor":
def view(self, *args: Union[List[List[int]], List[int]]) -> "AnyTensor":
from ..ops import view

if all(isinstance(a, int) for a in args):
if all(isinstance(a, int) or isinstance(a, torch.SymInt) for a in args):
shape = args
else:
assert len(args) == 1
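The int-or-SymInt check above matters under torch.export / dynamo tracing, where tensor sizes are SymInt rather than plain int, so a bare isinstance(a, int) would reject them. A minimal sketch of how SymInt shapes arise (plain torch, unrelated to sharktank's own view override):

    import torch
    from torch.export import Dim, export

    class Flatten(torch.nn.Module):
        def forward(self, x):
            # With a dynamic batch dim, x.shape[0] is a torch.SymInt during export.
            return x.view(x.shape[0], -1)

    example = torch.randn(4, 8, 2)
    prog = export(Flatten(), (example,), dynamic_shapes={"x": {0: Dim("batch")}})
    print(prog.graph_module.code)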
