diff --git a/sharktank/sharktank/kernels/__init__.py b/sharktank/sharktank/kernels/__init__.py index 308e20ef4..beb7e90a2 100644 --- a/sharktank/sharktank/kernels/__init__.py +++ b/sharktank/sharktank/kernels/__init__.py @@ -5,6 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from .attention import * +from .einsum_2args_q4 import * from .mmtfp import * from .mmt_block_scaled_offset_q4 import * from .mmt_block_scaled_q8 import * diff --git a/sharktank/sharktank/kernels/einsum_2args_q4.py b/sharktank/sharktank/kernels/einsum_2args_q4.py new file mode 100644 index 000000000..76d8ad61c --- /dev/null +++ b/sharktank/sharktank/kernels/einsum_2args_q4.py @@ -0,0 +1,259 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .base import * + +import torch + +__all__ = [ + "einsum_2args_q4", +] + + +def einsum_util(einsum_str): + es_in, es_out = einsum_str.split("->") + es_in0, es_in1 = es_in.split(",") + es_set = set(es_out) + es_set = es_set.union(es_in0) + es_set = es_set.union(es_in1) + size = len(es_set) + imap = dict() + lmap = dict() + for i in range(len(es_out)): + imap[i] = es_out[i] + lmap[es_out[i]] = i + count = len(es_out) + for c in es_set: + if c not in lmap: + imap[count] = c + lmap[c] = count + count += 1 + + assert count == len(es_set) + + in0_idx = [lmap[i] for i in es_in0] + in1_idx = [lmap[i] for i in es_in1] + out_idx = [lmap[i] for i in es_out] + + input_idx_str = ", ".join(["d" + str(i) for i in range(size)]) + in0_idx_str = ", ".join(["d" + str(i) for i in in0_idx]) + in1_idx_str = ", ".join(["d" + str(i) for i in in1_idx]) + out_idx_str = ", ".join(["d" + str(i) for i in out_idx]) + + iterators = ", ".join( + ['"parallel"' if i in out_idx else '"reduction"' for i in range(size)] + ) + + affine_map_in0 = f"affine_map<({input_idx_str}) -> ({in0_idx_str})>" + affine_map_in1 = f"affine_map<({input_idx_str}) -> ({in1_idx_str})>" + affine_map_out = f"affine_map<({input_idx_str}) -> ({out_idx_str})>" + + indexing_maps = f"""{affine_map_in0}, + {affine_map_in1}, + {affine_map_out} +""" + + out_dyn_dim_size_str = "" + for c in es_out: + if c in es_in0: + out_dyn_dim_size_str += "%a" + str(es_in0.find(c)) + "," + elif c in es_in1: + if es_in1.find(c) == len(es_in1) - 1: + out_dyn_dim_size_str += "%b_unblocked_dim," + else: + out_dyn_dim_size_str += "%b" + str(es_in1.find(c)) + "," + else: + raise Exception("Invalid einsum string") + out_dyn_dim_size_str = out_dyn_dim_size_str[:-1] + return ( + (in0_idx, in1_idx, out_idx), + iterators, + indexing_maps, + out_dyn_dim_size_str, + ) + + +@CustomOp.register(library=LIBRARY) +class einsum_2args_q4(CustomOp): + """Einsum that takes two tensor inputs and returns one tensor. + + The first input is expected to be a normal tensor. 
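+    For example (illustrative only): with `es = "mek,menk->men"`, the first input is a
+    floating point tensor of shape `[M, E, K]`.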
+ + The second input corresponds to the BlockScaledLayout and operates on planar `d` + and `qs` tensors as specified there: + + * `d`: `[..., K // BLOCK_SIZE, 1]` + * `qs`: `[..., K // BLOCK_SIZE, BLOCK_SIZE // 2]` (of uint8) + * `m`: `[..., K // BLOCK_SIZE, 1]` + """ + + signature = ( + "einsum_2args_q4(Tensor a, Tensor d, Tensor qs, Tensor m, str es) -> (Tensor)" + ) + + def select(self, ksel: KernelSelection): + a_desc = ksel.arg_tensor(0) # Shape [b, ] m, k + d_desc = ksel.arg_tensor(1) # Shape [N, K // BLOCK_SIZE, 1] + qs_desc = ksel.arg_tensor(2) # Shape [N, K // BLOCK_SIZE, BLOCK_SIZE // 2] + m_desc = ksel.arg_tensor(3) # Shape [N, K // BLOCK_SIZE, 1] + einsum_str = ksel.attr_str(4).v + + # a arg + a_dims = a_desc.t.shape + torch._check( + a_desc.t.dtype.is_floating_point, + lambda: f"einsum_2args_q4 arg 'a': Expected floating point (got {a_desc.t.dtype})", + ) + + # qs arg + *qs_dims, qs_group0, qs_bs_div_2 = qs_desc.t.shape + block_size = qs_bs_div_2 * 2 + + # d arg + *d_dims, d_group0, d_one = d_desc.t.shape + torch._check( + d_group0 == qs_group0 and d_one == 1 and len(d_dims) == len(qs_dims), + lambda: f"einsum_2args_q4 arg 'd': Incorrect shape (got {d_desc.t.shape})", + ) + + # m arg + *m_dims, m_group0, m_one = m_desc.t.shape + torch._check( + m_desc.t.dtype == d_desc.t.dtype and len(m_dims) == len(qs_dims), + lambda: f"einsum_2args_q4 arg 'm': Incorrect dtype (got {m_desc.t.dtype})", + ) + # einsum_str + torch._check( + einsum_str.count(",") == 1 and einsum_str.count("->") == 1, + lambda: f"einsum_2args_q4 arg 'einsum_str': Expected format '{{}},{{}}->{{}}' (got '{einsum_str}')", + ) + + es_in, es_out = einsum_str.split("->") + es_in0, es_in1 = es_in.split(",") + es_set = set(es_out) + + shp = qs_desc.t.shape + b_dims = list(shp[:-2]) + [shp[-2] * block_size] + torch._check( + len(es_in0) == len(a_desc.t.shape) + and len(es_in1) + == len(qs_desc.t.shape) + - 1, # The quantized shape is larger until the blocks are collapsed + lambda: f"einsum_2args_q4 arg 'einsum_str': Einsum str dimensions do not match input dimensions (got '{einsum_str}' with inputs: {a_desc.t.shape} and {b_dims})", + ) + torch._check( + len(es_in0) == len(set(es_in0)) + and len(es_in1) == len(set(es_in1)) + and len(es_in0) != 0 + and len(es_in1) != 0, + lambda: f"einsum_2args_q4 arg 'einsum_str': Unsupported einsum str (got '{einsum_str}')", + ) + + # Check corresponding dimensions match + for i in range(len(es_in0)): + a_dim = a_dims[i] + c = es_in0[i] + pos = es_in1.find(c) + if pos >= 0: + b_dim = b_dims[pos] + torch._check( + a_dim == b_dim, + lambda: f"einsum_2args_q4 arg 'einsum_str': Einsum str dimensions do not match input dim for idx {c} (got '{einsum_str}' with inputs: {a_desc.t.shape} and {b_dims})", + ) + + # Determine the output shape by referencing corresponding input shapes + out_dims = [] + for c in es_out: + pos0 = es_in0.find(c) + pos1 = es_in1.find(c) + a_dim = a_dims[pos0] + b_dim = b_dims[pos1] + if pos0 >= 0: + out_dims.append(a_dim) + elif pos1 >= 0: + out_dims.append(b_dim) + else: + torch._check( + False, + lambda: f"einsum_2args_q4 arg 'einsum_str': output indices must be in input indices (got '{einsum_str}')", + ) + + # Specialize on BS + qs_desc.specialize_dims(-1) + d_desc.specialize_dims(-1) + m_desc.specialize_dims(-1) + + # Shape batch..., m, n + c_desc = ksel.return_new_tensor(out_dims, dtype=a_desc.t.dtype) + + def generate(self, ksel: KernelSelection, kb: KernelBuilder): + a = kb.arg_value(0) + a_tensor_type = RankedTensorType(a.type) + d = kb.arg_value(1) + d_tensor_type 
= RankedTensorType(d.type) + qs = kb.arg_value(2) + qs_tensor_type = RankedTensorType(qs.type) + einsum_str = ksel.arg_descs[4].v + # einsum_str = "mek,menk->men" + + es_in, es_out = einsum_str.split("->") + es_in0, es_in1 = es_in.split(",") + + es_name = "_".join([es_in0, es_in1, es_out]) + + ( + (es_0, es_1, es_2), + einsum_iterators, + einsum_indexing_maps, + oddss, + ) = einsum_util(einsum_str) + + rank1 = len(es_1) + dequant_iterators = ", ".join( + ['"parallel"' for i in range(rank1 + 1)] + ) # rank + 1 because of the group dimensions + input_idx_str = ", ".join(["d" + str(i) for i in range(rank1 + 1)]) + broadcast_idx_str = ", ".join( + ["d" + str(i) if i != rank1 else "0" for i in range(rank1 + 1)] + ) + affine_map_parallel = f"affine_map<({input_idx_str}) -> ({input_idx_str})>" + affine_map_broadcast = f"affine_map<({input_idx_str}) -> ({broadcast_idx_str})>" + dequant_indexing_maps = f"""{affine_map_broadcast}, + {affine_map_broadcast}, + {affine_map_parallel}, + {affine_map_parallel}""" + + size_str = "x".join("?" for i in range(rank1 - 2)) + + rank = a_tensor_type.rank + *n_dims, group0, bs_i8 = qs_tensor_type.shape + bs = bs_i8 * 2 # 2 nibbles per byte. + group = group0 * bs + a_type_str = str(a_tensor_type.element_type) + scale_type_str = str(d_tensor_type.element_type) + + template_file = "einsum_2args_q4.mlir" + target_function_name = f"sharktank_einsum_2args_q4_{es_name}_{bs}_{a_type_str}" + + target_function = inline_template_function( + kb, + template_file, + target_function_name, + bs=bs, + bs_i8=bs_i8, + a_type=a_type_str, + scale_type=scale_type_str, + dequant_indexing_maps=dequant_indexing_maps, + dequant_iterator_types=dequant_iterators, + einsum_indexing_maps=einsum_indexing_maps, + einsum_iterator_types=einsum_iterators, + es_name=es_name, + a_size=len(es_in0), + b_size=len(es_in1), + c_size=len(es_out), + out_dyn_dim_size_str=oddss, + ) + kb.yield_results(*call_function(target_function, *kb.arg_bindings)) diff --git a/sharktank/sharktank/kernels/templates/einsum_2args_q4.mlir b/sharktank/sharktank/kernels/templates/einsum_2args_q4.mlir new file mode 100644 index 000000000..47ca6b331 --- /dev/null +++ b/sharktank/sharktank/kernels/templates/einsum_2args_q4.mlir @@ -0,0 +1,104 @@ +// Copyright 2024 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +{% set accum_type = "f32" %} + +!lowp_type = i4 +!a_type = {{a_type}} +!scale_type = {{scale_type}} +!accum_type = {{accum_type}} +!a_tensor_type = tensor<{% for i in range(a_size) %}?x{% endfor %}!a_type> +!qs_raw_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}{{bs_i8}}xi8> +!qs_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}{{bs}}x!lowp_type> +!d_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}1x!scale_type> +!m_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}1x!scale_type> +!accum_tensor_type = tensor<{% for i in range(c_size) %}?x{% endfor %}!accum_type> +!c_tensor_type = tensor<{% for i in range(c_size) %}?x{% endfor %}!a_type> +!b_grouped_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}{{bs}}x!a_type> +!b_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}!a_type> + +module { + +util.func private @sharktank_einsum_2args_q4_{{es_name}}_{{bs}}_{{a_type}}( + %a: !a_tensor_type, %d: !d_tensor_type, %qs_raw: !qs_raw_tensor_type, %m: !m_tensor_type) + -> !c_tensor_type { + %debug = tensor.empty() : tensor<1xf32> + %zero = arith.constant 0.0: !accum_type + {% for i in range(a_size) %} + %k{{i}} = arith.constant {{i}} : index + {% endfor %} + {% for i in range(a_size, b_size) %} + %k{{i}} = arith.constant {{i}} : index + {% endfor %} + {% for i in range(a_size) %} + %a{{i}} = tensor.dim %a, %k{{i}}: !a_tensor_type + {% endfor %} + {% for i in range(b_size) %} + %b{{i}} = tensor.dim %qs_raw, %k{{i}}: !qs_raw_tensor_type + {% endfor %} + %bs = arith.constant {{bs}} : index + %b_unblocked_dim = arith.muli %b{{b_size-1}}, %bs : index + + //%qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type -> !qs_tensor_type + %qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type{{"{"}}{% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b{{b_size-1}}{{"}"}} -> !qs_tensor_type{{"{"}}{% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b{{b_size-1}}{{"}"}} + + // Dequantize. + %b_grouped = tensor.empty({% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b{{b_size-1}}) : !b_grouped_tensor_type + %b_grouped_dequant = linalg.generic { + indexing_maps = [ + {{dequant_indexing_maps}}], + iterator_types = [{{dequant_iterator_types}}] } + ins(%d, %m, %qs : !d_tensor_type, !m_tensor_type, !qs_tensor_type) + outs(%b_grouped : !b_grouped_tensor_type) { + ^bb0(%d_element: !scale_type, %m_element: !scale_type, %q_element: !lowp_type, %out: !a_type): + %q_element_ext = arith.extui %q_element : !lowp_type to i32 + %q_element_fp = arith.uitofp %q_element_ext : i32 to !a_type + {% if scale_type == a_type %} + %q_element_scaled = arith.mulf %q_element_fp, %d_element : !a_type + %q_element_offset = arith.addf %q_element_scaled, %m_element : !a_type + {% else %} + %d_element_ext = arith.extf %d_element : !scale_type to !a_type + %m_element_ext = arith.extf %m_element : !scale_type to !a_type + %q_element_scaled = arith.mulf %q_element_fp, %d_element_ext : !a_type + %q_element_offset = arith.addf %q_element_scaled, %m_element_ext : !a_type + {% endif %} + linalg.yield %q_element_offset : !a_type + } -> !b_grouped_tensor_type + + // Collapse %b to the same unblocked structure. 
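+  // The reassociation below folds the trailing [K // BLOCK_SIZE, BLOCK_SIZE] pair
+  // back into a single unblocked K dimension, e.g. (illustrative, b_size == 3,
+  // bs == 32): tensor<?x?x?x32xf32> into tensor<?x?x?xf32>.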
+ %b_unblocked = tensor.collapse_shape %b_grouped_dequant [{% for i in range(b_size-1) %}[{{i}}], {% endfor %}[{{b_size-1}}, {{b_size}}]] : !b_grouped_tensor_type into !b_tensor_type + + // Einsum + %result_empty = tensor.empty({{out_dyn_dim_size_str}}) : !accum_tensor_type + %result_fill = linalg.fill ins(%zero: !accum_type) outs(%result_empty: !accum_tensor_type) -> !accum_tensor_type + %result = linalg.generic { + indexing_maps = [ + {{einsum_indexing_maps}}], + iterator_types = [{{einsum_iterator_types}}] } + ins(%a, %b_unblocked : !a_tensor_type, !b_tensor_type) + outs(%result_fill : !accum_tensor_type) { + ^bb0(%a_element: !a_type, %b_element: !a_type, %out: !accum_type): + %bmm_mul = arith.mulf %a_element, %b_element : !a_type + {% if accum_type == a_type %} + %bmm_accum = arith.addf %bmm_mul, %out : !a_type + {% else %} + %bmm_mul_ext = arith.extf %bmm_mul : !a_type to !accum_type + %bmm_accum = arith.addf %bmm_mul_ext, %out : !accum_type + {% endif %} + linalg.yield %bmm_accum : !accum_type + } -> !accum_tensor_type + + // Cast. + %result_cast_empty = tensor.empty({{out_dyn_dim_size_str}}) : !c_tensor_type + %result_cast = linalg.copy + ins(%result : !accum_tensor_type) + outs(%result_cast_empty : !c_tensor_type) -> !c_tensor_type + + //iree_input.tensor.trace "foobar" = [%a : !a_tensor_type, %d : !d_tensor_type, %qs_raw: !qs_raw_tensor_type, %m: !m_tensor_type, %b_grouped_dequant: !b_grouped_tensor_type] + util.return %result_cast : !c_tensor_type +} + +} diff --git a/sharktank/sharktank/layers/ffn_moe_block.py b/sharktank/sharktank/layers/ffn_moe_block.py index 266537c89..0536302cf 100644 --- a/sharktank/sharktank/layers/ffn_moe_block.py +++ b/sharktank/sharktank/layers/ffn_moe_block.py @@ -12,6 +12,7 @@ from .base import ThetaLayer from .linear import LinearLayer from ..types import Theta, DefaultPrimitiveTensor +from ..ops import einsum_2args __all__ = [ "FFNMOE", @@ -36,24 +37,24 @@ def __init__( def pre_matmul_gather(self, inputs, weights, experts, einstring="mk,menk->men"): inputs = inputs[:, :] weights = weights[experts, :, :] - matmul = torch.einsum(einstring, inputs, weights.float()) + matmul = einsum_2args(inputs, weights, einstring) return matmul def bigger_mmg(self, inputs, weights, experts): inputs = inputs[:, :] weights = weights[experts, :, :] - matmul = torch.einsum("mek,menk->men", inputs, weights.float()) + matmul = einsum_2args(inputs, weights, "mek,menk->men") return matmul def one_hot_matmul(self, inputs, weights, experts): - matmul = torch.einsum("mk,bnk->bmn", inputs, weights) + matmul = einsum_2args(inputs, weights, "mk,bnk->bmn") # Post mix the experts oh = ( torch.nn.functional.one_hot(experts.reshape(-1), num_classes=8) .transpose(0, 1) .to(torch.float32) ) - output = torch.einsum("bm,bmn->mn", oh, matmul) + output = einsum_2args(oh, matmul, "bm,bmn->mn") return output def forward( @@ -63,19 +64,15 @@ def forward( expert_gate: torch.Tensor, ): if self.use_grok: - ffn_gate = F.gelu( - self.pre_matmul_gather(h, self.ffn_gate.as_torch(), experts) - ) + ffn_gate = F.gelu(self.pre_matmul_gather(h, self.ffn_gate, experts)) else: - ffn_gate = F.silu( - self.pre_matmul_gather(h, self.ffn_gate.as_torch(), experts) - ) + ffn_gate = F.silu(self.pre_matmul_gather(h, self.ffn_gate, experts)) ffn_up = self.pre_matmul_gather(h, self.ffn_up, experts) ffn_down = self.pre_matmul_gather( ffn_gate * ffn_up, self.ffn_down, experts, einstring="mek,menk->men" ) - ffn_down = torch.einsum("me,men->men", expert_gate, ffn_down) + ffn_down = einsum_2args(expert_gate, ffn_down, 
"me,men->men") return torch.sum(ffn_down, dim=1) diff --git a/sharktank/sharktank/ops/custom_impls.py b/sharktank/sharktank/ops/custom_impls.py index fe6ae27b1..d1df15cbc 100644 --- a/sharktank/sharktank/ops/custom_impls.py +++ b/sharktank/sharktank/ops/custom_impls.py @@ -9,6 +9,7 @@ import torch.nn.functional as F from ..kernels import ( + einsum_2args_q4, mmt_block_scaled_offset_q4_unsigned, mmt_block_scaled_q8, mmtfp, @@ -44,6 +45,18 @@ # return mmtfp(lhs, rhs) +# Einsum + + +@einsum_2args.override(Tensor, QuantizedTensor) +def einsum_2args_QuantizedTensor(input0, input1, einsum_str): + unpacked = input1.unpack() + layout = input1.layout_type + if not isinstance(unpacked, BlockScaledI4Layout): + return NotImplemented + return einsum_2args_q4(input0, unpacked.d, unpacked._qs, unpacked.m, einsum_str) + + # Quantized Matmul diff --git a/sharktank/sharktank/ops/default_impls.py b/sharktank/sharktank/ops/default_impls.py index 050b1384b..a2fcd2813 100644 --- a/sharktank/sharktank/ops/default_impls.py +++ b/sharktank/sharktank/ops/default_impls.py @@ -14,7 +14,13 @@ import torch.nn.functional as F from numbers import Number -from ..types import PrimitiveTensor, QuantizedTensor, InferenceTensor +from ..types import ( + PrimitiveTensor, + QuantizedTensor, + InferenceTensor, + PlanarQuantizedTensor, + BlockScaledI4Layout, +) from ..types.tensors import unbox_tensor, AnyTensor from ._registry import AllOfType, AllOfExprs, AllOfExprsVariadic, IsOfType from .signatures import * @@ -62,6 +68,12 @@ def conv2d_default( conv2d.override(Tensor, Tensor, Tensor, auto_dequant=True)(conv2d_default) conv2d.override(Tensor, Tensor, auto_dequant=True)(conv2d_default) +# Einsum +@einsum_2args.override(AllOfType(Tensor, PrimitiveTensor)) +def einsum_2args(x, y, einsum_str): + return torch.einsum(einsum_str, unbox_tensor(x), unbox_tensor(y)) + + # Elementwise @elementwise.override(Tensor) def elementwise_unary(operator, x, *args, **kwargs): @@ -145,6 +157,28 @@ def flatten_default( return torch.flatten(unbox_tensor(input), start_dim, end_dim) +@get_index.override(AllOfType(Tensor, PrimitiveTensor)) +def get_index_default(tensor, key): + return unbox_tensor(tensor).__get_item__(key) + + +@get_index.override(QuantizedTensor) +def get_index_QuantizedTensor(tensor: QuantizedTensor, key: slice): + unpacked = tensor.unpack() + if isinstance(unpacked, BlockScaledI4Layout): + mul = 2 + else: + return NotImplemented + new_d = unpacked._d[key] + new_qs = unpacked._qs[key] + if unpacked.m is not None: + new_m = unpacked.m[key] + dims = new_qs.shape + dims = dims[:-2] + (dims[-2] * dims[-1] * mul,) + layout = BlockScaledI4Layout(shape=dims, d=new_d, qs=new_qs, m=new_m) + return PlanarQuantizedTensor(shape=dims, layout=layout) + + @gemm.override(AllOfType(Tensor, InferenceTensor)) def gemm( a: AnyTensor, diff --git a/sharktank/sharktank/ops/signatures.py b/sharktank/sharktank/ops/signatures.py index cb7ea7bc4..59f7672c7 100644 --- a/sharktank/sharktank/ops/signatures.py +++ b/sharktank/sharktank/ops/signatures.py @@ -21,11 +21,13 @@ "all_reduce", "cat", "conv2d", + "einsum_2args", "elementwise", "embedding_lookup", "equal", "expand", "flatten", + "get_index", "gemm", "group_norm_affine", "layer_norm", @@ -165,6 +167,37 @@ def _conv2d_trampoline( d.fail(tensors) +@overridable +def einsum_2args( + input0: AnyTensor, + input1: AnyTensor, + einsum_str: str, + *, + accum_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + """Executes a given Einstein summation notation string on the provided tensors. 
+ + Equivalent to: + ``` + y = torch.einsum(einsum_str, input0, input1) + ``` + """ + raise NotImplementedError + + +@einsum_2args.trampoline +def _einsum_trampoline( + d: SignatureDispatcher, input0: AnyTensor, input1: AnyTensor, einsum_str: str +): + tensors = (input0, input1) + for override in d.find_overrides(tensors): + result = override(input0, input1, einsum_str) + if result is not NotImplemented: + return override, result + else: + d.fail(tensors) + + @overridable def elementwise(operator, *args, **kwargs) -> AnyTensor: """Applies an elementwise operator against arguments.""" @@ -270,6 +303,32 @@ def _expand_trampoline( d.fail(tensors) +@overridable +def get_index( + tensor: AnyTensor, + key: slice, +) -> torch.Tensor: + """Indexes the tensor using the key. + + Equivalent to: + ``` + out = tensor[key] + ``` + """ + raise NotImplementedError + + +@get_index.trampoline +def _get_index_trampoline(d: SignatureDispatcher, tensor: AnyTensor, key: slice): + tensors = (tensor,) + for override in d.find_overrides(tensors): + result = override(tensor, key) + if result is not NotImplemented: + return override, result + else: + d.fail(tensors) + + @overridable def flatten(input: AnyTensor, start_dim: int = 0, end_dim: int = -1) -> AnyTensor: """See torch.flatten""" diff --git a/sharktank/sharktank/types/tensors.py b/sharktank/sharktank/types/tensors.py index a41c9ee34..d2b7b38b1 100644 --- a/sharktank/sharktank/types/tensors.py +++ b/sharktank/sharktank/types/tensors.py @@ -413,6 +413,11 @@ def __floordiv__(self, rhs): return elementwise(torch.floor_divide, self, rhs) + def __getitem__(self, key): + from ..ops import get_index + + return get_index(self, key) + REGISTERED_INFERENCE_TENSOR_CLASSES: dict[str, Type[InferenceTensor]] = {} diff --git a/sharktank/tests/kernels/einsum_q4_test.py b/sharktank/tests/kernels/einsum_q4_test.py new file mode 100644 index 000000000..d94ec5851 --- /dev/null +++ b/sharktank/tests/kernels/einsum_q4_test.py @@ -0,0 +1,141 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging + +logging.basicConfig(level=logging.DEBUG) + +import unittest +from parameterized import parameterized + +import torch + +from shark_turbine import aot +from sharktank import kernels +from sharktank.types import layout_utils + + +class einsum_2args_q4_test(unittest.TestCase): + def setUp(self): + torch.manual_seed(42) + + @parameterized.expand( + [ + (torch.float32, torch.float32, torch.float32, 1e-2, 1e-3), + (torch.float32, torch.float16, torch.float32, 1e-2, 1e-3), + (torch.float16, torch.float16, torch.float32, 1e-2, 1e-3), + ] + ) + def test_basic_mk_menk_men(self, a_dtype, d_dtype, ref_dtype, atol, rtol): + a = torch.rand([2, 320], dtype=a_dtype) / 256.0 + d = torch.rand([2, 4, 8, 10, 1], dtype=d_dtype) / 256.0 + qs = (torch.rand([2, 4, 8, 10, 16], dtype=ref_dtype) * 255.0).to(torch.uint8) + m = torch.rand([2, 4, 8, 10, 1], dtype=d_dtype) + 16.0 + einsum_string = "mk,menk->men" + result = kernels.einsum_2args_q4(a, d, qs, m, einsum_string) + + # Dequantize and test with normal matmul. + # Tolerances are empirical and results are not expected to match exactly. 
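+        # The reference below assumes layout_utils.promote_linear_i4_block_to_i8
+        # widens the packed i4 nibbles to int8; dequantization is then d * q + m per
+        # block, with flatten() collapsing the block dims before the plain einsum.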
+ qs_i8 = layout_utils.promote_linear_i4_block_to_i8(qs) + b = (d.to(ref_dtype) * qs_i8.to(ref_dtype) + m.to(ref_dtype)).flatten(3) + ref = torch.einsum(einsum_string, a.to(ref_dtype), b.to(ref_dtype)) + torch.testing.assert_close(result.to(ref_dtype), ref, atol=atol, rtol=rtol) + + @parameterized.expand( + [ + (torch.float32, torch.float32, torch.float32, 1e-2, 1e-3), + (torch.float32, torch.float16, torch.float32, 1e-2, 1e-3), + (torch.float16, torch.float16, torch.float32, 1e-2, 1e-3), + ] + ) + def test_basic_mek_menk_men(self, a_dtype, d_dtype, ref_dtype, atol, rtol): + a = torch.rand([2, 4, 320], dtype=a_dtype) / 256.0 + d = torch.rand([2, 4, 8, 10, 1], dtype=d_dtype) / 256.0 + qs = (torch.rand([2, 4, 8, 10, 16], dtype=ref_dtype) * 255.0).to(torch.uint8) + m = torch.rand([2, 4, 8, 10, 1], dtype=d_dtype) + 16.0 + einsum_string = "mek,menk->men" + result = kernels.einsum_2args_q4(a, d, qs, m, einsum_string) + + # Dequantize and test with normal matmul. + # Tolerances are empirical and results are not expected to match exactly. + qs_i8 = layout_utils.promote_linear_i4_block_to_i8(qs) + b = (d.to(ref_dtype) * qs_i8.to(ref_dtype) + m.to(ref_dtype)).flatten(3) + ref = torch.einsum(einsum_string, a.to(ref_dtype), b.to(ref_dtype)) + torch.testing.assert_close(result.to(ref_dtype), ref, atol=atol, rtol=rtol) + + @parameterized.expand( + [ + (torch.float32, torch.float32, torch.float32, 1e-2, 1e-3), + (torch.float32, torch.float16, torch.float32, 1e-2, 1e-3), + (torch.float16, torch.float16, torch.float32, 1e-2, 1e-3), + ] + ) + def test_basic_me_men_men(self, a_dtype, d_dtype, ref_dtype, atol, rtol): + a = torch.rand([2, 4], dtype=a_dtype) / 256.0 + d = torch.rand([2, 4, 10, 1], dtype=d_dtype) / 256.0 + qs = (torch.rand([2, 4, 10, 16], dtype=ref_dtype) * 255.0).to(torch.uint8) + m = torch.rand([2, 4, 10, 1], dtype=d_dtype) + 16.0 + einsum_string = "me,men->men" + result = kernels.einsum_2args_q4(a, d, qs, m, einsum_string) + + # Dequantize and test with normal matmul. + # Tolerances are empirical and results are not expected to match exactly. 
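+        # Note: "me,men->men" has no reduction index (every index appears in the
+        # output), so this case exercises the all-parallel, broadcast-multiply path.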
+ qs_i8 = layout_utils.promote_linear_i4_block_to_i8(qs) + b = (d.to(ref_dtype) * qs_i8.to(ref_dtype) + m.to(ref_dtype)).flatten(2) + ref = torch.einsum(einsum_string, a.to(ref_dtype), b.to(ref_dtype)) + torch.testing.assert_close(result.to(ref_dtype), ref, atol=atol, rtol=rtol) + + def testExportDynamicDims(self): + class MyModule(torch.nn.Module): + def forward(self, a, d, qs, m): + return kernels.einsum_2args_q4(a, d, qs, m, "ij,jk->ik") + + mod = MyModule() + ep = torch.export.export( + mod, + args=( + torch.rand([16, 320], dtype=torch.float32), + torch.rand([320, 2, 1], dtype=torch.float16), + (torch.rand([320, 2, 16], dtype=torch.float32) * 32).to(torch.uint8), + torch.rand([320, 2, 1], dtype=torch.float16), + ), + dynamic_shapes={ + "a": {}, + "d": {}, + "qs": {}, + "m": {}, + }, + ) + output = aot.export(ep) + output.verify() + asm = str(output.mlir_module) + self.assertIn("@sharktank_einsum_2args_q4_ij_jk_ik_32_f32", asm) + + def testExportStaticDims(self): + class MyModule(torch.nn.Module): + def forward(self, a, d, qs, m): + return kernels.einsum_2args_q4(a, d, qs, m, "mek,menk->men") + + mod = MyModule() + ep = torch.export.export( + mod, + args=( + torch.rand([4, 16, 320], dtype=torch.float32), + torch.rand([4, 16, 2, 10, 1], dtype=torch.float16), + (torch.rand([4, 16, 2, 10, 16], dtype=torch.float32) * 32).to( + torch.uint8 + ), + torch.rand([4, 16, 2, 10, 1], dtype=torch.float16), + ), + ) + output = aot.export(ep) + output.verify() + asm = str(output.mlir_module) + self.assertIn("@sharktank_einsum_2args_q4_mek_menk_men_32_f32", asm) + + +if __name__ == "__main__": + unittest.main()