From 478e415ac575fe58f7204210815acfc16054078c Mon Sep 17 00:00:00 2001
From: Kyle Herndon
Date: Tue, 24 Sep 2024 19:21:12 -0500
Subject: [PATCH] Hook the einsum_2args_q4 kernel and the __getitem__ operation
 up to their op-level implementations

---
 .../sharktank/kernels/einsum_2args_q4.py    |  6 +-
 .../kernels/templates/einsum_2args_q4.mlir  |  2 +-
 sharktank/sharktank/ops/custom_impls.py     | 12 ++++
 sharktank/sharktank/ops/default_impls.py    | 39 +++++++++++-
 sharktank/sharktank/ops/signatures.py       | 59 +++++++++++++++++++
 sharktank/sharktank/types/tensors.py        |  5 ++
 6 files changed, 117 insertions(+), 6 deletions(-)

diff --git a/sharktank/sharktank/kernels/einsum_2args_q4.py b/sharktank/sharktank/kernels/einsum_2args_q4.py
index 75d76b8cd..9ee90b4c3 100644
--- a/sharktank/sharktank/kernels/einsum_2args_q4.py
+++ b/sharktank/sharktank/kernels/einsum_2args_q4.py
@@ -72,6 +72,7 @@ def einsum_util(einsum_str):
     return (
         (in0_idx, in1_idx, out_idx),
         iterators,
+        (affine_map_in0, affine_map_in1, affine_map_out),
         indexing_maps,
         out_dyn_dim_size_str,
     )
@@ -127,7 +128,6 @@ def select(self, ksel: KernelSelection):
             m_desc.t.dtype == d_desc.t.dtype and len(m_dims) == len(qs_dims),
             lambda: f"einsum_2args_q4 arg 'm': Incorrect dtype (got {m_desc.t.dtype})",
         )
-
         # einsum_str
         torch._check(
             einsum_str.count(",") == 1 and einsum_str.count("->") == 1,
@@ -139,9 +139,7 @@ def select(self, ksel: KernelSelection):
         es_set = set(es_out)

         shp = qs_desc.t.shape
-        print(shp)
         b_dims = list(shp[:-2]) + [shp[-2] * block_size]
-        print(b_dims)
         torch._check(
             len(es_in0) == len(a_desc.t.shape)
             and len(es_in1)
@@ -213,6 +211,7 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
         (
             (es_0, es_1, es_2),
             einsum_iterators,
+            _,
             einsum_indexing_maps,
             oddss,
         ) = einsum_util(einsum_str)
@@ -262,5 +261,4 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
             c_size=len(es_out),
             out_dyn_dim_size_str=oddss,
         )
-        print(target_function)
         kb.yield_results(*call_function(target_function, *kb.arg_bindings))
diff --git a/sharktank/sharktank/kernels/templates/einsum_2args_q4.mlir b/sharktank/sharktank/kernels/templates/einsum_2args_q4.mlir
index 9c5a49fa8..cc4fac190 100644
--- a/sharktank/sharktank/kernels/templates/einsum_2args_q4.mlir
+++ b/sharktank/sharktank/kernels/templates/einsum_2args_q4.mlir
@@ -43,7 +43,7 @@ util.func private @sharktank_einsum_2args_q4_{{es_name}}_{{bs}}_{{a_type}}(
   %b_unblocked_dim = arith.muli %b{{b_size-1}}, %bs : index

   //%qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type -> !qs_tensor_type
-  %qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type{{"{"}}{% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b{{b_size-1}}{{"}"}} -> !qs_tensor_type{{"{"}}{% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b_unblocked_dim{{"}"}}
+  %qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type{{"{"}}{% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b{{b_size-1}}{{"}"}} -> !qs_tensor_type{{"{"}}{% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b{{b_size-1}}{{"}"}}

   // Dequantize.
   // todo: loop
diff --git a/sharktank/sharktank/ops/custom_impls.py b/sharktank/sharktank/ops/custom_impls.py
index fe6ae27b1..d1df15cbc 100644
--- a/sharktank/sharktank/ops/custom_impls.py
+++ b/sharktank/sharktank/ops/custom_impls.py
@@ -9,6 +9,7 @@
 import torch.nn.functional as F

 from ..kernels import (
+    einsum_2args_q4,
     mmt_block_scaled_offset_q4_unsigned,
     mmt_block_scaled_q8,
     mmtfp,
@@ -44,6 +45,17 @@
 #     return mmtfp(lhs, rhs)


+# Einsum
+
+
+@einsum_2args.override(Tensor, QuantizedTensor)
+def einsum_2args_QuantizedTensor(input0, input1, einsum_str):
+    unpacked = input1.unpack()
+    if not isinstance(unpacked, BlockScaledI4Layout):
+        return NotImplemented
+    return einsum_2args_q4(input0, unpacked.d, unpacked._qs, unpacked.m, einsum_str)
+
+
 # Quantized Matmul


diff --git a/sharktank/sharktank/ops/default_impls.py b/sharktank/sharktank/ops/default_impls.py
index 0ab6053c2..ab1cc9283 100644
--- a/sharktank/sharktank/ops/default_impls.py
+++ b/sharktank/sharktank/ops/default_impls.py
@@ -14,7 +14,13 @@
 import torch.nn.functional as F
 from numbers import Number

-from ..types import PrimitiveTensor, QuantizedTensor, InferenceTensor
+from ..types import (
+    PrimitiveTensor,
+    QuantizedTensor,
+    InferenceTensor,
+    PlanarQuantizedTensor,
+    BlockScaledI4Layout,
+)
 from ..types.tensors import unbox_tensor, AnyTensor
 from ._registry import AllOfType, AllOfExprs, AllOfExprsVariadic, IsOfType
 from .signatures import *
@@ -62,6 +68,12 @@ def conv2d_default(
 conv2d.override(Tensor, Tensor, Tensor, auto_dequant=True)(conv2d_default)
 conv2d.override(Tensor, Tensor, auto_dequant=True)(conv2d_default)

+# Einsum
+@einsum_2args.override(AllOfType(Tensor, PrimitiveTensor))
+def einsum_2args_default(x, y, einsum_str):
+    return torch.einsum(einsum_str, unbox_tensor(x), unbox_tensor(y))
+
+
 # Elementwise
 @elementwise.override(Tensor)
 def elementwise_unary(operator, x):
@@ -133,6 +145,31 @@ def equal_default(a, b) -> bool:
     return torch.equal(unbox_tensor(a), unbox_tensor(b))


+@get_index.override(AllOfType(Tensor, PrimitiveTensor))
+def get_index_default(tensor, key: slice):
+    return unbox_tensor(tensor)[key]
+
+
+@get_index.override(QuantizedTensor)
+def get_index_QuantizedTensor(tensor: QuantizedTensor, key: slice):
+    unpacked = tensor.unpack()
+    if isinstance(unpacked, BlockScaledI4Layout):
+        mul = 2
+    else:
+        return NotImplemented
+    new_d = unpacked._d[key]
+    new_qs = unpacked._qs[key]
+    # m is optional; keep it None so the sliced layout stays offset-free.
+    new_m = unpacked.m[key] if unpacked.m is not None else None
+    dims = new_qs.shape
+    dims = dims[:-2] + (dims[-2] * dims[-1] * mul,)
+    layout = BlockScaledI4Layout(shape=dims, d=new_d, qs=new_qs, m=new_m)
+    return PlanarQuantizedTensor(shape=dims, layout=layout)
+
+
+# get_index.override(PlanarQuantizedTensor, slice)(get_index_QuantizedTensor)
+
+
 @gemm.override(AllOfType(Tensor, InferenceTensor))
 def gemm(
     a: AnyTensor,
diff --git a/sharktank/sharktank/ops/signatures.py b/sharktank/sharktank/ops/signatures.py
index d39aba71c..d129b4830 100644
--- a/sharktank/sharktank/ops/signatures.py
+++ b/sharktank/sharktank/ops/signatures.py
@@ -21,9 +21,11 @@
     "all_reduce",
     "cat",
     "conv2d",
+    "einsum_2args",
     "elementwise",
     "embedding_lookup",
     "equal",
     "gemm",
+    "get_index",
     "group_norm_affine",
     "layer_norm",
@@ -151,6 +153,37 @@ def _conv2d_trampoline(
     d.fail(tensors)


+@overridable
+def einsum_2args(
+    input0: AnyTensor,
+    input1: AnyTensor,
+    einsum_str: str,
+    *,
+    accum_dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+    """Executes a given Einstein summation notation string on the provided tensors.
+
+    Equivalent to:
+    ```
+    y = torch.einsum(einsum_str, input0, input1)
+    ```
+    """
+    raise NotImplementedError
+
+
+@einsum_2args.trampoline
+def _einsum_trampoline(
+    d: SignatureDispatcher, input0: AnyTensor, input1: AnyTensor, einsum_str: str
+):
+    tensors = (input0, input1)
+    for override in d.find_overrides(tensors):
+        result = override(input0, input1, einsum_str)
+        if result is not NotImplemented:
+            return override, result
+    else:
+        d.fail(tensors)
+
+
 @overridable
 def elementwise(operator, *args: AnyTensor) -> AnyTensor:
     """Applies an elementwise operator against arguments."""
@@ -232,6 +265,32 @@ def _equal_trampoline(d: SignatureDispatcher, a: AnyTensor, b: AnyTensor):
     d.fail(tensors)


+@overridable
+def get_index(
+    tensor: AnyTensor,
+    key: slice,
+) -> torch.Tensor:
+    """Indexes the tensor using the key.
+
+    Equivalent to:
+    ```
+    out = tensor[key]
+    ```
+    """
+    raise NotImplementedError
+
+
+@get_index.trampoline
+def _get_index_trampoline(d: SignatureDispatcher, tensor: AnyTensor, key: slice):
+    tensors = (tensor,)
+    for override in d.find_overrides(tensors):
+        result = override(tensor, key)
+        if result is not NotImplemented:
+            return override, result
+    else:
+        d.fail(tensors)
+
+
 @overridable
 def gemm(
     a: AnyTensor,
diff --git a/sharktank/sharktank/types/tensors.py b/sharktank/sharktank/types/tensors.py
index 3545608a9..7abc39c3e 100644
--- a/sharktank/sharktank/types/tensors.py
+++ b/sharktank/sharktank/types/tensors.py
@@ -298,6 +298,11 @@ def __rmul__(self, lhs):
         # numbers on the lhs.
         return self.__mul__(lhs)

+    def __getitem__(self, key):
+        from ..ops import get_index
+
+        return get_index(self, key)
+

 REGISTERED_INFERENCE_TENSOR_CLASSES: dict[str, Type[InferenceTensor]] = {}
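For reference, a minimal usage sketch of the new op-level dispatch (not part of the patch). It assumes `sharktank.ops` re-exports the new `einsum_2args` signature, and uses two plain float tensors so dispatch lands on the default override:

```python
import torch
from sharktank import ops

a = torch.rand(4, 16, 32, dtype=torch.float32)
b = torch.rand(4, 32, 16, dtype=torch.float32)

# Neither operand is quantized, so the trampoline selects the default
# override in default_impls.py, which reduces to torch.einsum.
y = ops.einsum_2args(a, b, "bik,bkj->bij")
assert torch.allclose(y, torch.einsum("bik,bkj->bij", a, b))
```

With a `(Tensor, QuantizedTensor)` pair whose layout is `BlockScaledI4Layout`, the same call would instead route to `einsum_2args_QuantizedTensor` in custom_impls.py and hence to the `einsum_2args_q4` kernel.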
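And a sketch of the new slicing path. The `BlockScaledI4Layout`/`PlanarQuantizedTensor` keyword arguments mirror the calls inside this patch; the concrete shapes and dtypes below are illustrative assumptions, not an API guarantee:

```python
import torch
from sharktank.types import BlockScaledI4Layout, PlanarQuantizedTensor

# Hand-built Q4 tensor: 8 rows, 2 blocks of 32 i4 values per row,
# packed two values per i8 byte (hence the trailing dim of 16).
d = torch.rand(8, 2, 1, dtype=torch.float16)                 # per-block scales
qs = torch.randint(-128, 127, (8, 2, 16), dtype=torch.int8)  # packed i4 data
layout = BlockScaledI4Layout(shape=[8, 64], d=d, qs=qs, m=None)
qt = PlanarQuantizedTensor(shape=[8, 64], layout=layout)

# InferenceTensor.__getitem__ forwards to ops.get_index, which slices d/qs
# (and m when present) and rebuilds a layout with the unblocked shape.
row = qt[0:1]
print(row.shape)  # [1, 64]
```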