From 623e3e7be00d3be5537d584a856eb4e9904f331e Mon Sep 17 00:00:00 2001 From: Kyle Herndon Date: Tue, 8 Oct 2024 15:54:28 -0500 Subject: [PATCH] Add special einsum cases that lower to batch matmul --- .../kernels/mmt_block_scaled_offset_q4.py | 40 +++++++---- .../mmt_block_scaled_offset_q4_unsigned.mlir | 30 ++++++-- sharktank/sharktank/ops/default_impls.py | 71 +++++++++++++++++-- sharktank/sharktank/types/tensors.py | 2 +- 4 files changed, 120 insertions(+), 23 deletions(-) diff --git a/sharktank/sharktank/kernels/mmt_block_scaled_offset_q4.py b/sharktank/sharktank/kernels/mmt_block_scaled_offset_q4.py index 2ed171115..0c8a61f32 100644 --- a/sharktank/sharktank/kernels/mmt_block_scaled_offset_q4.py +++ b/sharktank/sharktank/kernels/mmt_block_scaled_offset_q4.py @@ -37,28 +37,33 @@ def select(self, ksel: KernelSelection): m_desc = ksel.arg_tensor(3) # Shape [N, K // BLOCK_SIZE, 1] # a arg - *batch_dims, a_m, a_k = a_desc.t.shape + *a_batch_dims, a_m, a_k = a_desc.t.shape torch._check( a_desc.t.dtype.is_floating_point, lambda: f"mmt_block_scaled_offset_q4_unsigned arg 'a': Expected floating point (got {a_desc.t.dtype})", ) torch._check( - len(batch_dims) == 1, + len(a_batch_dims) == 1, lambda: f"mmt_block_scaled_offset_q4_unsigned arg 'a': Expected 3d tensor (got {a_desc.t.shape})", ) # qs arg - qs_n, qs_group0, qs_bs_div_2, *rest = qs_desc.t.shape + *qs_batch_dims, qs_n, qs_group0, qs_bs_div_2 = qs_desc.t.shape torch._check( - len(rest) == 0 and (qs_group0 * qs_bs_div_2 * 2) == a_k, + ( + len(qs_batch_dims) == 0 + or len(qs_batch_dims) == 1 + and qs_batch_dims == a_batch_dims + ) + and (qs_group0 * qs_bs_div_2 * 2) == a_k, lambda: f"mmt_block_scaled_offset_q4_unsigned arg 'qs': Incorrect shape (got {qs_desc.t.shape})", ) block_size = qs_bs_div_2 * 2 # d arg - d_n, d_group0, d_one, *rest = d_desc.t.shape + *d_batch_dims, d_n, d_group0, d_one = d_desc.t.shape torch._check( - len(rest) == 0 + d_batch_dims == qs_batch_dims and (d_group0 * block_size) == a_k and d_one == 1 and d_n == qs_n, @@ -66,9 +71,9 @@ def select(self, ksel: KernelSelection): ) # m arg - m_n, m_group0, m_one, *rest = m_desc.t.shape + *m_batch_dims, m_n, m_group0, m_one = m_desc.t.shape torch._check( - len(rest) == 0 + m_batch_dims == qs_batch_dims and (m_group0 * block_size) == a_k and m_one == 1 and m_n == qs_n, @@ -81,12 +86,17 @@ def select(self, ksel: KernelSelection): # Specialize on K, N, BS a_desc.specialize_dims(-1) - qs_desc.specialize_all_dims() - d_desc.specialize_all_dims() - m_desc.specialize_all_dims() + if len(qs_batch_dims) == 0: + qs_desc.specialize_all_dims() + d_desc.specialize_all_dims() + m_desc.specialize_all_dims() + else: + qs_desc.specialize_dims(1, 2, 3) + d_desc.specialize_dims(1, 2, 3) + m_desc.specialize_dims(1, 2, 3) # Shape batch..., m, n - c_desc = ksel.return_new_tensor(batch_dims + [a_m, d_n], dtype=a_desc.t.dtype) + c_desc = ksel.return_new_tensor(a_batch_dims + [a_m, d_n], dtype=a_desc.t.dtype) c_desc.specialize_dims(-1) def generate(self, ksel: KernelSelection, kb: KernelBuilder): @@ -99,13 +109,14 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder): rank = a_tensor_type.rank k = a_tensor_type.get_dim_size(rank - 1) - n, group0, bs_i8 = qs_tensor_type.shape + *qs_batch_dims, n, group0, bs_i8 = qs_tensor_type.shape + batched_rhs = len(qs_batch_dims) == 1 bs = bs_i8 * 2 # 2 nibbles per byte. a_type_str = str(a_tensor_type.element_type) scale_type_str = str(d_tensor_type.element_type) template_file = "mmt_block_scaled_offset_q4_unsigned.mlir" - target_function_name = f"sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{n}_{k}_{bs}_{a_type_str}" + target_function_name = f"sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{n}_{k}_{bs}_{a_type_str}_{batched_rhs}" target_function = inline_template_function( kb, @@ -118,5 +129,6 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder): group0=group0, a_type=a_type_str, scale_type=scale_type_str, + batched_rhs=batched_rhs, ) kb.yield_results(*call_function(target_function, *kb.arg_bindings)) diff --git a/sharktank/sharktank/kernels/templates/mmt_block_scaled_offset_q4_unsigned.mlir b/sharktank/sharktank/kernels/templates/mmt_block_scaled_offset_q4_unsigned.mlir index a7f3138cb..afe2928c0 100644 --- a/sharktank/sharktank/kernels/templates/mmt_block_scaled_offset_q4_unsigned.mlir +++ b/sharktank/sharktank/kernels/templates/mmt_block_scaled_offset_q4_unsigned.mlir @@ -12,17 +12,25 @@ !accum_type = {{accum_type}} !a_tensor_type = tensor !aexp_tensor_type = tensor +{% if batched_rhs %} +!qs_raw_tensor_type = tensor +!qs_tensor_type = tensor +!d_tensor_type = tensor +!m_tensor_type = tensor +!b_grouped_tensor_type = tensor +{% else %} !qs_raw_tensor_type = tensor<{{n}}x{{group0}}x{{bs_i8}}xi8> !qs_tensor_type = tensor<{{n}}x{{group0}}x{{bs}}x!lowp_type> !d_tensor_type = tensor<{{n}}x{{group0}}x1x!scale_type> !m_tensor_type = tensor<{{n}}x{{group0}}x1x!scale_type> +!b_grouped_tensor_type = tensor<{{n}}x{{group0}}x{{bs}}x!a_type> +{% endif %} !accum_tensor_type = tensor !c_tensor_type = tensor -!b_grouped_tensor_type = tensor<{{n}}x{{group0}}x{{bs}}x!a_type> module { -util.func private @sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{{n}}_{{k}}_{{bs}}_{{a_type}}( +util.func private @sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{{n}}_{{k}}_{{bs}}_{{a_type}}_{{batched_rhs}}( %a: !a_tensor_type, %d: !d_tensor_type, %qs_raw: !qs_raw_tensor_type, %m: !m_tensor_type) -> !c_tensor_type { %zero = arith.constant 0.0: !accum_type @@ -32,17 +40,31 @@ util.func private @sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{{n}}_{{k}}_ %m_dim = tensor.dim %a, %c1 : !a_tensor_type // Cast qs_raw from i8 to lowp type. +{% if batched_rhs %} + %qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type{ %batch0_dim } -> !qs_tensor_type{ %batch0_dim } + %b_grouped = tensor.empty(%batch0_dim) : !b_grouped_tensor_type +{% else %} %qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type -> !qs_tensor_type + %b_grouped = tensor.empty() : !b_grouped_tensor_type +{% endif %} // Dequantize. - %b_grouped = tensor.empty() : !b_grouped_tensor_type %b_grouped_dequant = linalg.generic { +{% if batched_rhs %} + indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel"] } +{% else %} indexing_maps = [ affine_map<(d0, d1, d2) -> (d0, d1, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"] } +{% endif %} ins(%d, %m, %qs : !d_tensor_type, !m_tensor_type, !qs_tensor_type) outs(%b_grouped : !b_grouped_tensor_type) { ^bb0(%d_element: !scale_type, %m_element: !scale_type, %q_element: !lowp_type, %out: !a_type): @@ -70,7 +92,7 @@ util.func private @sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{{n}}_{{k}}_ indexing_maps = [ // d0 = b, d1 = m, d2 = n, d3 = group0 (r), d4 = block (r) affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, - affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> ({% if batched_rhs %}d0,{% endif %} d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"] } ins(%aexp, %b_grouped_dequant : !aexp_tensor_type, !b_grouped_tensor_type) diff --git a/sharktank/sharktank/ops/default_impls.py b/sharktank/sharktank/ops/default_impls.py index fec30fca6..ef7144bca 100644 --- a/sharktank/sharktank/ops/default_impls.py +++ b/sharktank/sharktank/ops/default_impls.py @@ -69,9 +69,56 @@ def conv2d_default( conv2d.override(Tensor, Tensor, auto_dequant=True)(conv2d_default) # Einsum -@einsum_2args.override(AllOfType(Tensor, PrimitiveTensor)) -def einsum_2args(x, y, einsum_str): - return torch.einsum(einsum_str, unbox_tensor(x), unbox_tensor(y)) +def mk_menk_men(inputs, weights): + # batch dims: m, lhs pdims: none, lhs rdims: k, rhs pdims: en, rhs rdims: k + inputs = inputs.unsqueeze(1) + weights_shape = weights.shape + weights = weights.view( + weights_shape[0], weights_shape[1] * weights_shape[2], weights_shape[3] + ) + result = matmul(inputs, weights, transpose_rhs=True) + result = result.view(weights_shape[0], weights_shape[1], weights_shape[2]) + return result + + +def mek_menk_men(inputs, weights): + # batch dims: me, lhs pdims: none, lhs rdims: k, rhs pdims: n, rhs rdims: k + inputs_shape = inputs.shape + inputs = inputs.view(inputs_shape[0] * inputs_shape[1], 1, inputs_shape[2]) + weights_shape = weights.shape + weights = weights.view( + weights_shape[0] * weights_shape[1], weights_shape[2], weights_shape[3] + ) + result = matmul(inputs, weights, transpose_rhs=True) + result = result.view(weights_shape[0], weights_shape[1], weights_shape[2]) + return result + + +def me_men_men(inputs, weights): + # batch dims: me, lhs pdims: none, lhs rdims: none, rhs pdims: n, rhs rdims: none + inputs_shape = inputs.shape + inputs = inputs.view(inputs_shape[0] * inputs_shape[1], 1, 1) + weights_shape = weights.shape + weights = weights.view(weights_shape[0] * weights_shape[1], weights_shape[2], 1) + result = matmul(inputs, weights, transpose_rhs=True) + result = result.view(weights_shape[0], weights_shape[1], weights_shape[2]) + return result + + +@einsum_2args.override(AllOfType(Tensor, PrimitiveTensor, QuantizedTensor)) +def einsum_2args(input0, input1, einsum_str): + # Special optimized einsum kernels that lower to batch matmul + if einsum_str == "mk,menk->men": + return mk_menk_men(input0, input1) + elif einsum_str == "mek,menk->men": + return mek_menk_men(input0, input1) + elif einsum_str == "me,men->men": + return me_men_men(input0, input1) + # Default non-QuantizedTensor einsum + if not isinstance(input1, QuantizedTensor): + return torch.einsum(einsum_str, unbox_tensor(x), unbox_tensor(y)) + # Fallback to other kernels + return NotImplemented # Elementwise @@ -307,7 +354,7 @@ def matmul_default(lhs, rhs, *, transpose_rhs: bool) -> Tensor: lhs = unbox_tensor(lhs) rhs = unbox_tensor(rhs) if transpose_rhs: - rhs = rhs.T + rhs = rhs.mT return torch.matmul(lhs, rhs.to(lhs.dtype)) @@ -433,3 +480,19 @@ def unsqueeze_default(tensor: Union[Tensor, PrimitiveTensor], dim: int) -> Tenso @view.override(Tensor) def view_default(tensor: Union[Tensor, PrimitiveTensor], shape: List[int]) -> Tensor: return unbox_tensor(tensor).view(*shape) + + +@view.override(QuantizedTensor) +def view_QuantizedTensor(tensor: QuantizedTensor, shape): + unpacked = tensor.unpack() + if not isinstance(unpacked, BlockScaledI4Layout): + return NotImplemented + bs = 16 + shape = list(shape) + new_d = unpacked._d.view(shape[:-1] + [shape[-1] // 32, 1]) + qs_shape = shape[:-1] + [shape[-1] // 32, 16] + new_qs = unpacked._qs.view(qs_shape) + if unpacked.m is not None: + new_m = unpacked.m.view(shape[:-1] + [shape[-1] // 32, 1]) + layout = BlockScaledI4Layout(shape=shape, d=new_d, qs=new_qs, m=new_m) + return PlanarQuantizedTensor(shape=shape, layout=layout) diff --git a/sharktank/sharktank/types/tensors.py b/sharktank/sharktank/types/tensors.py index 324cc4331..70b0fbd01 100644 --- a/sharktank/sharktank/types/tensors.py +++ b/sharktank/sharktank/types/tensors.py @@ -378,7 +378,7 @@ def unsqueeze(self, dim: int) -> "AnyTensor": def view(self, *args: Union[List[List[int]], List[int]]) -> "AnyTensor": from ..ops import view - if all(isinstance(a, int) for a in args): + if all(isinstance(a, int) or isinstance(a, torch.SymInt) for a in args): shape = args else: assert len(args) == 1