Skip to content

Commit

Permalink
Update test signature
Browse files Browse the repository at this point in the history
  • Loading branch information
nithinsubbiah committed Oct 17, 2024
1 parent 488ad7d commit 6b49a8e
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 7 deletions.
3 changes: 2 additions & 1 deletion sharktank/sharktank/kernels/batch_matmul_transpose_b.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
spec_sig = f"L{a_ident}_R{b_ident}"
template_file = "batch_matmul_transpose_b.mlir"
target_function_name = f"sharktank_batch_matmul_transpose_b_{spec_sig}"

cst_zero = "0." if "f" in str(accum_type) else "0"
# Template params.
c_asm_type = f"tensor<{'x'.join('?' if d is None else str(d) for d in result_desc.spec_dims)}x{accum_type}>"

Expand All @@ -93,5 +93,6 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
b_asm_type=b_asm_type,
c_asm_type=c_asm_type,
dtype=str(accum_type),
cst_zero=cst_zero,
)
kb.yield_results(*call_function(target_function, *kb.arg_bindings))
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

// !cst_zero = {{cst_zero}}
!dtype = {{dtype}}
!a_tensor_type = {{a_asm_type}}
!b_tensor_type = {{b_asm_type}}
Expand All @@ -15,7 +16,8 @@ module {
util.func private @sharktank_batch_matmul_transpose_b_{{spec_sig}}(
%a: !a_tensor_type, %b: !b_tensor_type)
-> !c_tensor_type {
%zero = arith.constant 0.: !dtype
// %zero = arith.constant !cst_zero: !dtype
%zero = arith.constant {{cst_zero}}: !dtype
%c0 = arith.constant 0: index
%c1 = arith.constant 1: index
%batch_dim = tensor.dim %a, %c0 : !a_tensor_type // b, m, k
Expand Down
10 changes: 5 additions & 5 deletions sharktank/tests/ops/qconv_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def testInputSymPerTensor_WeightAsymPerChannel_NoBias(self):
)
self.assertIs(
ops._registry._test_get_last_op_dispatch(),
ops.qconv_impls.qconv2d_tensor_scaled_integer,
ops.qconv_impls.qconv2d_tensor_scaled,
)
y_ref = torch.nn.functional.conv2d(
input_q.unpack().dequant(),
Expand Down Expand Up @@ -105,7 +105,7 @@ def testInputSymPerTensor_WeightAsymPerChannel_FloatBias(self):
y_actual = ops.conv2d(input_q, weight_q, bias, stride=1, padding=(1, 1))
self.assertIs(
ops._registry._test_get_last_op_dispatch(),
ops.qconv_impls.qconv2d_tensor_scaled_integer,
ops.qconv_impls.qconv2d_tensor_scaled,
)
y_ref = torch.nn.functional.conv2d(
input_q.unpack().dequant(),
Expand Down Expand Up @@ -147,7 +147,7 @@ def testInputSymPerTensor_WeightAsymPerChannel_QuantizedBias(self):
)
self.assertIs(
ops._registry._test_get_last_op_dispatch(),
ops.qconv_impls.qconv2d_tensor_scaled_integer,
ops.qconv_impls.qconv2d_tensor_scaled,
)
y_ref = torch.nn.functional.conv2d(
input_q.unpack().dequant(),
Expand Down Expand Up @@ -184,7 +184,7 @@ def testInputSymPerTensor_WeightSymPerTensor_NoBias(self):
)
self.assertIs(
ops._registry._test_get_last_op_dispatch(),
ops.qconv_impls.qconv2d_tensor_scaled_integer,
ops.qconv_impls.qconv2d_tensor_scaled,
)
y_ref = torch.nn.functional.conv2d(
input_q.unpack().dequant(),
Expand Down Expand Up @@ -224,7 +224,7 @@ def testInputAsymPerChannel_WeightAsymPerChannel_NoBias(self):
)
self.assertIs(
ops._registry._test_get_last_op_dispatch(),
ops.qconv_impls.qconv2d_tensor_scaled_integer,
ops.qconv_impls.qconv2d_tensor_scaled,
)
y_ref = torch.nn.functional.conv2d(
input_q.unpack().dequant(),
Expand Down

0 comments on commit 6b49a8e

Please sign in to comment.