Add special einsum cases that lower to batch matmul #262

Merged: 2 commits, Oct 14, 2024
40 changes: 26 additions & 14 deletions sharktank/sharktank/kernels/mmt_block_scaled_offset_q4.py
@@ -37,38 +37,43 @@ def select(self, ksel: KernelSelection):
         m_desc = ksel.arg_tensor(3)  # Shape [N, K // BLOCK_SIZE, 1]

         # a arg
-        *batch_dims, a_m, a_k = a_desc.t.shape
+        *a_batch_dims, a_m, a_k = a_desc.t.shape
         torch._check(
             a_desc.t.dtype.is_floating_point,
             lambda: f"mmt_block_scaled_offset_q4_unsigned arg 'a': Expected floating point (got {a_desc.t.dtype})",
         )
         torch._check(
-            len(batch_dims) == 1,
+            len(a_batch_dims) == 1,
             lambda: f"mmt_block_scaled_offset_q4_unsigned arg 'a': Expected 3d tensor (got {a_desc.t.shape})",
         )

         # qs arg
-        qs_n, qs_group0, qs_bs_div_2, *rest = qs_desc.t.shape
+        *qs_batch_dims, qs_n, qs_group0, qs_bs_div_2 = qs_desc.t.shape
         torch._check(
-            len(rest) == 0 and (qs_group0 * qs_bs_div_2 * 2) == a_k,
+            (
+                len(qs_batch_dims) == 0
+                or len(qs_batch_dims) == 1
+                and qs_batch_dims == a_batch_dims
+            )
+            and (qs_group0 * qs_bs_div_2 * 2) == a_k,
             lambda: f"mmt_block_scaled_offset_q4_unsigned arg 'qs': Incorrect shape (got {qs_desc.t.shape})",
         )
         block_size = qs_bs_div_2 * 2

         # d arg
-        d_n, d_group0, d_one, *rest = d_desc.t.shape
+        *d_batch_dims, d_n, d_group0, d_one = d_desc.t.shape
         torch._check(
-            len(rest) == 0
+            d_batch_dims == qs_batch_dims
             and (d_group0 * block_size) == a_k
             and d_one == 1
             and d_n == qs_n,
             lambda: f"mmt_block_scaled_offset_q4_unsigned arg 'd': Incorrect shape (got {d_desc.t.shape})",
         )

         # m arg
-        m_n, m_group0, m_one, *rest = m_desc.t.shape
+        *m_batch_dims, m_n, m_group0, m_one = m_desc.t.shape
         torch._check(
-            len(rest) == 0
+            m_batch_dims == qs_batch_dims
             and (m_group0 * block_size) == a_k
             and m_one == 1
             and m_n == qs_n,
@@ -81,12 +86,17 @@ def select(self, ksel: KernelSelection):

         # Specialize on K, N, BS
         a_desc.specialize_dims(-1)
-        qs_desc.specialize_all_dims()
-        d_desc.specialize_all_dims()
-        m_desc.specialize_all_dims()
+        if len(qs_batch_dims) == 0:
+            qs_desc.specialize_all_dims()
+            d_desc.specialize_all_dims()
+            m_desc.specialize_all_dims()
+        else:
+            qs_desc.specialize_dims(1, 2, 3)
+            d_desc.specialize_dims(1, 2, 3)
+            m_desc.specialize_dims(1, 2, 3)

         # Shape batch..., m, n
-        c_desc = ksel.return_new_tensor(batch_dims + [a_m, d_n], dtype=a_desc.t.dtype)
+        c_desc = ksel.return_new_tensor(a_batch_dims + [a_m, d_n], dtype=a_desc.t.dtype)
         c_desc.specialize_dims(-1)

     def generate(self, ksel: KernelSelection, kb: KernelBuilder):
@@ -99,13 +109,14 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder):

         rank = a_tensor_type.rank
         k = a_tensor_type.get_dim_size(rank - 1)
-        n, group0, bs_i8 = qs_tensor_type.shape
+        *qs_batch_dims, n, group0, bs_i8 = qs_tensor_type.shape
+        batched_rhs = len(qs_batch_dims) == 1
         bs = bs_i8 * 2  # 2 nibbles per byte.
         a_type_str = str(a_tensor_type.element_type)
         scale_type_str = str(d_tensor_type.element_type)

         template_file = "mmt_block_scaled_offset_q4_unsigned.mlir"
-        target_function_name = f"sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{n}_{k}_{bs}_{a_type_str}"
+        target_function_name = f"sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{n}_{k}_{bs}_{a_type_str}_{batched_rhs}"

         target_function = inline_template_function(
             kb,
@@ -118,5 +129,6 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
             group0=group0,
             a_type=a_type_str,
             scale_type=scale_type_str,
+            batched_rhs=batched_rhs,
         )
         kb.yield_results(*call_function(target_function, *kb.arg_bindings))
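Note: a minimal sketch, with illustrative shapes only (nothing from this PR), of the contract select() now enforces: qs, d, and m may carry at most one leading batch dim, and when present it must match a's.

    import torch

    # Illustrative sizes only (assumed).
    B, M, K, N, BS = 2, 8, 64, 32, 32
    a = torch.randn(B, M, K)                                    # activations [B, M, K]
    qs = torch.zeros(B, N, K // BS, BS // 2, dtype=torch.int8)  # packed i4, optionally batched

    *a_batch_dims, a_m, a_k = a.shape
    *qs_batch_dims, qs_n, qs_group0, qs_bs_div_2 = qs.shape
    assert len(qs_batch_dims) in (0, 1)
    if qs_batch_dims:
        assert qs_batch_dims == a_batch_dims     # batch dims must match
    assert qs_group0 * qs_bs_div_2 * 2 == a_k    # groups * block size == K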
sharktank/sharktank/kernels/templates/mmt_block_scaled_offset_q4_unsigned.mlir
@@ -12,17 +12,25 @@
 !accum_type = {{accum_type}}
 !a_tensor_type = tensor<?x?x{{k}}x!a_type>
 !aexp_tensor_type = tensor<?x?x{{group0}}x{{bs}}x!a_type>
+{% if batched_rhs %}
+!qs_raw_tensor_type = tensor<?x{{n}}x{{group0}}x{{bs_i8}}xi8>
+!qs_tensor_type = tensor<?x{{n}}x{{group0}}x{{bs}}x!lowp_type>
+!d_tensor_type = tensor<?x{{n}}x{{group0}}x1x!scale_type>
+!m_tensor_type = tensor<?x{{n}}x{{group0}}x1x!scale_type>
+!b_grouped_tensor_type = tensor<?x{{n}}x{{group0}}x{{bs}}x!a_type>
+{% else %}
 !qs_raw_tensor_type = tensor<{{n}}x{{group0}}x{{bs_i8}}xi8>
 !qs_tensor_type = tensor<{{n}}x{{group0}}x{{bs}}x!lowp_type>
 !d_tensor_type = tensor<{{n}}x{{group0}}x1x!scale_type>
 !m_tensor_type = tensor<{{n}}x{{group0}}x1x!scale_type>
+!b_grouped_tensor_type = tensor<{{n}}x{{group0}}x{{bs}}x!a_type>
+{% endif %}
 !accum_tensor_type = tensor<?x?x{{n}}x!accum_type>
 !c_tensor_type = tensor<?x?x{{n}}x!a_type>
-!b_grouped_tensor_type = tensor<{{n}}x{{group0}}x{{bs}}x!a_type>

module {

-util.func private @sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{{n}}_{{k}}_{{bs}}_{{a_type}}(
+util.func private @sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{{n}}_{{k}}_{{bs}}_{{a_type}}_{{batched_rhs}}(
     %a: !a_tensor_type, %d: !d_tensor_type, %qs_raw: !qs_raw_tensor_type, %m: !m_tensor_type)
     -> !c_tensor_type {
   %zero = arith.constant 0.0: !accum_type
@@ -32,17 +40,31 @@ util.func private @sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{{n}}_{{k}}_
   %m_dim = tensor.dim %a, %c1 : !a_tensor_type

   // Cast qs_raw from i8 to lowp type.
+{% if batched_rhs %}
+  %qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type{ %batch0_dim } -> !qs_tensor_type{ %batch0_dim }
+  %b_grouped = tensor.empty(%batch0_dim) : !b_grouped_tensor_type
+{% else %}
   %qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type -> !qs_tensor_type
+  %b_grouped = tensor.empty() : !b_grouped_tensor_type
+{% endif %}

   // Dequantize.
-  %b_grouped = tensor.empty() : !b_grouped_tensor_type
   %b_grouped_dequant = linalg.generic {
+{% if batched_rhs %}
+      indexing_maps = [
+          affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>,
+          affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>,
+          affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+          affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel"] }
+{% else %}
       indexing_maps = [
           affine_map<(d0, d1, d2) -> (d0, d1, 0)>,
           affine_map<(d0, d1, d2) -> (d0, d1, 0)>,
           affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
           affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
       iterator_types = ["parallel", "parallel", "parallel"] }
+{% endif %}
       ins(%d, %m, %qs : !d_tensor_type, !m_tensor_type, !qs_tensor_type)
       outs(%b_grouped : !b_grouped_tensor_type) {
     ^bb0(%d_element: !scale_type, %m_element: !scale_type, %q_element: !lowp_type, %out: !a_type):
@@ -70,7 +92,7 @@ util.func private @sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{{n}}_{{k}}_
       indexing_maps = [
           // d0 = b, d1 = m, d2 = n, d3 = group0 (r), d4 = block (r)
           affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>,
-          affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>,
+          affine_map<(d0, d1, d2, d3, d4) -> ({% if batched_rhs %}d0,{% endif %} d2, d3, d4)>,
           affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>],
       iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"] }
       ins(%aexp, %b_grouped_dequant : !aexp_tensor_type, !b_grouped_tensor_type)
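Note: the indexing-map change above is the MLIR analogue of adding the batch dim d0 to the RHS operand of the contraction. A torch.einsum sketch with assumed sizes:

    import torch

    b, m, n, g, blk = 2, 3, 4, 2, 32
    aexp = torch.randn(b, m, g, blk)
    rhs_batched = torch.randn(b, n, g, blk)  # batched_rhs=True: RHS indexed by d0
    rhs_shared = torch.randn(n, g, blk)      # batched_rhs=False: RHS shared across b

    out_batched = torch.einsum("bmgk,bngk->bmn", aexp, rhs_batched)
    out_shared = torch.einsum("bmgk,ngk->bmn", aexp, rhs_shared)
    assert out_batched.shape == out_shared.shape == (b, m, n)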
71 changes: 67 additions & 4 deletions sharktank/sharktank/ops/default_impls.py
@@ -69,9 +69,56 @@ def conv2d_default(
 conv2d.override(Tensor, Tensor, auto_dequant=True)(conv2d_default)

 # Einsum
-@einsum_2args.override(AllOfType(Tensor, PrimitiveTensor))
-def einsum_2args(x, y, einsum_str):
-    return torch.einsum(einsum_str, unbox_tensor(x), unbox_tensor(y))
+def mk_menk_men(inputs, weights):
+    # batch dims: m, lhs pdims: none, lhs rdims: k, rhs pdims: en, rhs rdims: k
+    inputs = inputs.unsqueeze(1)
+    weights_shape = weights.shape
+    weights = weights.view(
+        weights_shape[0], weights_shape[1] * weights_shape[2], weights_shape[3]
+    )
+    result = matmul(inputs, weights, transpose_rhs=True)
+    result = result.view(weights_shape[0], weights_shape[1], weights_shape[2])
+    return result
+
+
+def mek_menk_men(inputs, weights):
+    # batch dims: me, lhs pdims: none, lhs rdims: k, rhs pdims: n, rhs rdims: k
+    inputs_shape = inputs.shape
+    inputs = inputs.view(inputs_shape[0] * inputs_shape[1], 1, inputs_shape[2])
+    weights_shape = weights.shape
+    weights = weights.view(
+        weights_shape[0] * weights_shape[1], weights_shape[2], weights_shape[3]
+    )
+    result = matmul(inputs, weights, transpose_rhs=True)
+    result = result.view(weights_shape[0], weights_shape[1], weights_shape[2])
+    return result
+
+
+def me_men_men(inputs, weights):
+    # batch dims: me, lhs pdims: none, lhs rdims: none, rhs pdims: n, rhs rdims: none
+    inputs_shape = inputs.shape
+    inputs = inputs.view(inputs_shape[0] * inputs_shape[1], 1, 1)
+    weights_shape = weights.shape
+    weights = weights.view(weights_shape[0] * weights_shape[1], weights_shape[2], 1)
+    result = matmul(inputs, weights, transpose_rhs=True)
+    result = result.view(weights_shape[0], weights_shape[1], weights_shape[2])
+    return result
+
+
+@einsum_2args.override(AllOfType(Tensor, PrimitiveTensor, QuantizedTensor))
+def einsum_2args(input0, input1, einsum_str):
+    # Special optimized einsum cases that lower to batch matmul.
+    if einsum_str == "mk,menk->men":
+        return mk_menk_men(input0, input1)
+    elif einsum_str == "mek,menk->men":
+        return mek_menk_men(input0, input1)
+    elif einsum_str == "me,men->men":
+        return me_men_men(input0, input1)
+    # Default non-QuantizedTensor einsum.
+    if not isinstance(input1, QuantizedTensor):
+        return torch.einsum(einsum_str, unbox_tensor(input0), unbox_tensor(input1))
+    # Fall back to other kernels.
+    return NotImplemented
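Note: a quick numerical check, using plain torch in place of sharktank's matmul op and illustrative sizes, that the "mk,menk->men" rewrite above matches torch.einsum:

    import torch

    m, e, n, k = 3, 4, 5, 6
    x = torch.randn(m, k)
    w = torch.randn(m, e, n, k)

    ref = torch.einsum("mk,menk->men", x, w)
    # Unsqueeze to [m, 1, k], fold e*n into one dim, batch-matmul, restore [m, e, n].
    bmm = torch.matmul(x.unsqueeze(1), w.view(m, e * n, k).mT).view(m, e, n)
    torch.testing.assert_close(ref, bmm)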


# Elementwise
Expand Down Expand Up @@ -307,7 +354,7 @@ def matmul_default(lhs, rhs, *, transpose_rhs: bool) -> Tensor:
     lhs = unbox_tensor(lhs)
     rhs = unbox_tensor(rhs)
     if transpose_rhs:
-        rhs = rhs.T
+        rhs = rhs.mT
     return torch.matmul(lhs, rhs.to(lhs.dtype))
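Note: on tensors with more than two dims, Tensor.T reverses every dim (and is deprecated there), while Tensor.mT transposes only the last two, which is what a batched matmul needs:

    import torch

    rhs = torch.randn(2, 3, 4)
    assert rhs.mT.shape == (2, 4, 3)  # swaps only the trailing matrix dims
    # rhs.T would give (4, 3, 2), scrambling the batch dimension.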


@@ -433,3 +480,19 @@ def unsqueeze_default(tensor: Union[Tensor, PrimitiveTensor], dim: int) -> Tensor:
 @view.override(Tensor)
 def view_default(tensor: Union[Tensor, PrimitiveTensor], shape: List[int]) -> Tensor:
     return unbox_tensor(tensor).view(*shape)
+
+
+@view.override(QuantizedTensor)
+def view_QuantizedTensor(tensor: QuantizedTensor, shape):
+    unpacked = tensor.unpack()
+    if not isinstance(unpacked, BlockScaledI4Layout):
+        return NotImplemented
+    bs = 32  # Block size: 32 i4 values per block, stored as 16 packed i8 bytes.
+    shape = list(shape)
+    new_d = unpacked._d.view(shape[:-1] + [shape[-1] // bs, 1])
+    qs_shape = shape[:-1] + [shape[-1] // bs, bs // 2]
+    new_qs = unpacked._qs.view(qs_shape)
+    new_m = None
+    if unpacked.m is not None:
+        new_m = unpacked.m.view(shape[:-1] + [shape[-1] // bs, 1])
+    layout = BlockScaledI4Layout(shape=shape, d=new_d, qs=new_qs, m=new_m)
+    return PlanarQuantizedTensor(shape=shape, layout=layout)
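Note: the shape arithmetic behind view_QuantizedTensor, with assumed sizes (a [32, 64] block-scaled i4 tensor viewed as [4, 8, 64]); only the leading dims change, and the trailing dim keeps its block structure:

    import torch

    K, bs = 64, 32
    d = torch.randn(32, K // bs, 1)                           # per-block scales
    qs = torch.zeros(32, K // bs, bs // 2, dtype=torch.int8)  # packed nibbles

    new_shape = [4, 8, K]
    new_d = d.view(new_shape[:-1] + [K // bs, 1])             # -> [4, 8, 2, 1]
    new_qs = qs.view(new_shape[:-1] + [K // bs, bs // 2])     # -> [4, 8, 2, 16]
    assert list(new_d.shape) == [4, 8, 2, 1]
    assert list(new_qs.shape) == [4, 8, 2, 16]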
2 changes: 1 addition & 1 deletion sharktank/sharktank/types/tensors.py
@@ -378,7 +378,7 @@ def unsqueeze(self, dim: int) -> "AnyTensor":
     def view(self, *args: Union[List[List[int]], List[int]]) -> "AnyTensor":
         from ..ops import view

-        if all(isinstance(a, int) for a in args):
+        if all(isinstance(a, int) or isinstance(a, torch.SymInt) for a in args):
             shape = args
         else:
             assert len(args) == 1
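Note: the SymInt case arises under torch.export and torch.compile with dynamic shapes, where sizes are symbolic rather than plain ints. A tidier spelling of the same check, as a sketch:

    import torch

    def is_size_like(a) -> bool:
        # Accept both concrete and symbolic integer sizes.
        return isinstance(a, (int, torch.SymInt))

    assert is_size_like(4)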