diff --git a/sharktank/sharktank/kernels/batch_matmul_transpose_b.py b/sharktank/sharktank/kernels/batch_matmul_transpose_b.py
index 9fdb7cc68..11a6b5fc2 100644
--- a/sharktank/sharktank/kernels/batch_matmul_transpose_b.py
+++ b/sharktank/sharktank/kernels/batch_matmul_transpose_b.py
@@ -80,7 +80,7 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
         spec_sig = f"L{a_ident}_R{b_ident}"
         template_file = "batch_matmul_transpose_b.mlir"
         target_function_name = f"sharktank_batch_matmul_transpose_b_{spec_sig}"
-        cst_zero = "0." if "f" in str(accum_type) else "0"
+
         # Template params.
         c_asm_type = f"tensor<{'x'.join('?' if d is None else str(d) for d in result_desc.spec_dims)}x{accum_type}>"
@@ -93,6 +93,5 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
             b_asm_type=b_asm_type,
             c_asm_type=c_asm_type,
             dtype=str(accum_type),
-            cst_zero=cst_zero,
         )
         kb.yield_results(*call_function(target_function, *kb.arg_bindings))
diff --git a/sharktank/sharktank/kernels/conv_2d_nchw_fchw.py b/sharktank/sharktank/kernels/conv_2d_nchw_fchw.py
index 529511e02..9ada3b099 100644
--- a/sharktank/sharktank/kernels/conv_2d_nchw_fchw.py
+++ b/sharktank/sharktank/kernels/conv_2d_nchw_fchw.py
@@ -23,8 +23,6 @@
     (torch.int16, torch.int16, "torch.int16"): torch.int16,
     (torch.int16, torch.int16, "torch.int32"): torch.int32,
     # Legal fp types.
-    (torch.float8_e4m3fnuz, torch.float8_e4m3fnuz, "torch.float16"): torch.float16,
-    (torch.float8_e4m3fnuz, torch.float8_e4m3fnuz, "torch.float32"): torch.float32,
     (torch.float16, torch.float16, "torch.float16"): torch.float16,
     (torch.float16, torch.float16, "torch.float32"): torch.float32,
     (torch.float32, torch.float32, "torch.float32"): torch.float32,
@@ -35,7 +33,6 @@
     torch.int8: "i8",
     torch.int16: "i16",
     torch.int32: "i32",
-    torch.float8_e4m3fnuz: "f8E4M3FNUZ",
     torch.float16: "f16",
     torch.float32: "f32",
 }
diff --git a/sharktank/sharktank/kernels/templates/batch_matmul_transpose_b.mlir b/sharktank/sharktank/kernels/templates/batch_matmul_transpose_b.mlir
index ccac9072f..908ca1c7f 100644
--- a/sharktank/sharktank/kernels/templates/batch_matmul_transpose_b.mlir
+++ b/sharktank/sharktank/kernels/templates/batch_matmul_transpose_b.mlir
@@ -4,7 +4,6 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

-// !cst_zero = {{cst_zero}}
 !dtype = {{dtype}}
 !a_tensor_type = {{a_asm_type}}
 !b_tensor_type = {{b_asm_type}}
@@ -16,8 +15,7 @@ module {

 util.func private @sharktank_batch_matmul_transpose_b_{{spec_sig}}(
     %a: !a_tensor_type, %b: !b_tensor_type) -> !c_tensor_type {
-  // %zero = arith.constant !cst_zero: !dtype
-  %zero = arith.constant {{cst_zero}}: !dtype
+  %zero = arith.constant 0: !dtype
   %c0 = arith.constant 0: index
   %c1 = arith.constant 1: index
   %batch_dim = tensor.dim %a, %c0 : !a_tensor_type // b, m, k
diff --git a/sharktank/sharktank/models/punet/tools/import_brevitas_dataset.py b/sharktank/sharktank/models/punet/tools/import_brevitas_dataset.py
index 3cbf63dc8..825301797 100644
--- a/sharktank/sharktank/models/punet/tools/import_brevitas_dataset.py
+++ b/sharktank/sharktank/models/punet/tools/import_brevitas_dataset.py
@@ -208,7 +208,7 @@ def quantize_bias(
     bias_scale = 1.0 / (input_scale * weight_scale)
     bias_quantizer = StaticScaledQuantizer(
         scale=bias_scale,
-        dtype=torch.int32 if quantization_dtype == torch.int8 else torch.float16,
+        dtype=torch.int32,
         disable_saturate=True,
     )
     bias_quant = bias_quantizer.quantize(bias, name=bias_name)
diff --git a/sharktank/sharktank/ops/qconv_impls.py b/sharktank/sharktank/ops/qconv_impls.py
index 6adf0c99a..af1199976 100644
--- a/sharktank/sharktank/ops/qconv_impls.py
+++ b/sharktank/sharktank/ops/qconv_impls.py
@@ -31,7 +31,7 @@
 )


-def qconv2d_tensor_scaled(
+def qconv2d_tensor_scaled_integer(
     input: QuantizedTensor,
     weight: QuantizedTensor,
     bias: Optional[AnyTensor] = None,
@@ -59,16 +59,12 @@ def qconv2d_tensor_scaled(
     input_layout: TensorScaledLayout = input.unpack()
     weight_layout: TensorScaledLayout = weight.unpack()

-    # # Handle integer and fp8 quantizations.
+    # Only handle integer quantizations.
     if (
         input_layout.qs.dtype.is_floating_point
         or weight_layout.qs.dtype.is_floating_point
     ):
-        if (
-            input_layout.qs.dtype != torch.float8_e4m3fnuz
-            or weight_layout.qs.dtype != torch.float8_e4m3fnuz
-        ):
-            return NotImplemented
+        return NotImplemented

     # Bias is both optional and may either be quantized or fp.
     bias_qs = None
@@ -89,10 +85,7 @@ def qconv2d_tensor_scaled(

     # Alias components (d=scale, qs=quantized samples, m=offset).
     if accum_dtype is None:
-        if weight_layout.qs.dtype.is_floating_point:
-            accum_dtype = torch.float32
-        else:
-            accum_dtype = torch.int32
+        accum_dtype = torch.int32
     input_d = input_layout.d
     input_dtype = input_layout.dtype
     input_qs = input_layout.qs
@@ -121,7 +114,7 @@ def qconv2d_tensor_scaled(
     dilation = _expand_int_to_2_tuple(dilation)
     extended_padding_list = [item for item in padding for _ in range(2)]
     padded_input = _pad_last_2d(input_qs, extended_padding_list)
-    y_qs = _invoke_conv2d_kernel(
+    y_qs = _invoke_int32_conv2d(
         padded_input,
         weight_qs,
         bias_qs.to(accum_dtype) if bias_qs is not None else None,
@@ -152,7 +145,7 @@ def qconv2d_tensor_scaled(
         weight_offset_fix = torch.sum(
             padded_input, dim=1, keepdim=True, dtype=accum_dtype
         )
-        weight_offset_fix = _invoke_pooling_sum_kernel(
+        weight_offset_fix = _invoke_int32_pooling_sum(
             weight_offset_fix,
             [weight_qs.shape[2], weight_qs.shape[3]],
             stride,
@@ -195,11 +188,13 @@ def qconv2d_tensor_scaled(
     return y


-conv2d.override(QuantizedTensor, QuantizedTensor)(qconv2d_tensor_scaled)
-conv2d.override(QuantizedTensor, QuantizedTensor, AnyTensor)(qconv2d_tensor_scaled)
+conv2d.override(QuantizedTensor, QuantizedTensor)(qconv2d_tensor_scaled_integer)
+conv2d.override(QuantizedTensor, QuantizedTensor, AnyTensor)(
+    qconv2d_tensor_scaled_integer
+)


-def _invoke_conv2d_kernel(input, weight, bias, stride, dilation, *, accum_dtype):
+def _invoke_int32_conv2d(input, weight, bias, stride, dilation, *, accum_dtype):
     """Does a low level invocation of a conv2d integer kernel on an explicitly
     padded input.

     This presumes that the stride/padding/dilation have already been normalized
@@ -238,7 +233,7 @@ def _invoke_conv2d_kernel(input, weight, bias, stride, dilation, *, accum_dtype)
     return y_qs


-def _invoke_pooling_sum_kernel(input, kernel_size, stride, dilation, *, accum_dtype):
+def _invoke_int32_pooling_sum(input, kernel_size, stride, dilation, *, accum_dtype):
     """Invokes either a custom integer pooling sum or the built-in fp
     avg_pool2d kernel on an explicitly padded input.
     """
diff --git a/sharktank/sharktank/ops/qlinear_impls.py b/sharktank/sharktank/ops/qlinear_impls.py
index d63261dab..0a381d613 100644
--- a/sharktank/sharktank/ops/qlinear_impls.py
+++ b/sharktank/sharktank/ops/qlinear_impls.py
@@ -28,7 +28,7 @@
 from sharktank import kernels


-def qlinear_tensor_scaled(
+def qlinear_tensor_scaled_integer(
     x: QuantizedTensor,
     weight: QuantizedTensor,
     bias: Optional[AnyTensor],
@@ -48,11 +48,8 @@ def qlinear_tensor_scaled(
     x_layout: TensorScaledLayout = x.unpack()
     weight_layout: TensorScaledLayout = weight.unpack()

-    # Handle integer and fp8 quantizations.
-    if (
-        x_layout.qs.dtype != torch.float8_e4m3fnuz
-        and weight_layout.qs.dtype != torch.float8_e4m3fnuz
-    ):
+    # Only handle integer quantizations.
+    if x_layout.qs.dtype.is_floating_point or weight_layout.qs.dtype.is_floating_point:
         return NotImplemented

     # Bias.
@@ -67,10 +64,7 @@ def qlinear_tensor_scaled(

     # Alias components (d=scale, qs=quantized samples, m=offset)
     if accum_dtype is None:
-        if weight_layout.qs.dtype.is_floating_point:
-            accum_dtype = torch.float32
-        else:
-            accum_dtype = torch.int32
+        accum_dtype = torch.int32
     x_d = x_layout.d
     x_dtype = x_layout.dtype
     x_qs = x_layout.qs
@@ -92,7 +86,7 @@
     # TODO: Handle permutation that we have a kernel for.

     # Fall back to automatic fusion based on integer, high precision matmul.
-    y_qs = _invoke_mmt_kernel(x_qs, weight_qs, accum_dtype=accum_dtype)
+    y_qs = _invoke_int32_mmt(x_qs, weight_qs, accum_dtype=accum_dtype)

     # Offset correction. By applying the offset correction in post, it is
     # set up to fuse with its consumer, which is already doing additional
@@ -149,8 +143,10 @@


 # Overrload for both bias and no bias.
-linear.override(QuantizedTensor, QuantizedTensor)(qlinear_tensor_scaled)
-linear.override(QuantizedTensor, QuantizedTensor, AnyTensor)(qlinear_tensor_scaled)
+linear.override(QuantizedTensor, QuantizedTensor)(qlinear_tensor_scaled_integer)
+linear.override(QuantizedTensor, QuantizedTensor, AnyTensor)(
+    qlinear_tensor_scaled_integer
+)


 def linear_quantized_weight(
@@ -170,30 +166,19 @@ def linear_quantized_weight(
 linear.override(Tensor, QuantizedTensor, AnyTensor)(linear_quantized_weight)


-def _invoke_mmt_kernel(lhs, rhs, *, accum_dtype):
+def _invoke_int32_mmt(lhs, rhs, *, accum_dtype):
     if debugging.flags.use_custom_iree_kernels:
         # The custom kernel requires that the lhs and rhs be the same
         # rank. Broadcast the rhs to match.
         lhs_rank = len(lhs.shape)
         rhs_rank = len(rhs.shape)
-        # If input to the kernel is 2D, expand the tensor to add the batch
-        # dimension.
-        if lhs_rank == 2:
-            lhs_size = [1] + list(lhs.shape)
-            lhs = lhs.unsqueeze(0).expand(lhs_size)
-            lhs_rank = len(lhs.shape)
         if rhs_rank < lhs_rank:
             assert (rhs_rank + 1) == lhs_rank
             rhs_size = [lhs.shape[0]] + list(rhs.shape)
             rhs = rhs.unsqueeze(0).expand(rhs_size)
-            rhs_rank = len(rhs.shape)
         y_qs = kernels.batch_matmul_transpose_b(
             lhs.to(accum_dtype), rhs.to(accum_dtype)
         )
-        # Squeeze the batch dimension to maintain shape parity with other
-        # layers.
-        if len(y_qs.shape) > 2:
-            y_qs = y_qs.squeeze(0)
     else:
         # FP emulation.
         y_qs = torch.matmul(
diff --git a/sharktank/tests/ops/qconv_test.py b/sharktank/tests/ops/qconv_test.py
index 97b0efd66..4440202eb 100644
--- a/sharktank/tests/ops/qconv_test.py
+++ b/sharktank/tests/ops/qconv_test.py
@@ -71,7 +71,7 @@ def testInputSymPerTensor_WeightAsymPerChannel_NoBias(self):
         )
         self.assertIs(
             ops._registry._test_get_last_op_dispatch(),
-            ops.qconv_impls.qconv2d_tensor_scaled,
+            ops.qconv_impls.qconv2d_tensor_scaled_integer,
         )
         y_ref = torch.nn.functional.conv2d(
             input_q.unpack().dequant(),
@@ -105,7 +105,7 @@ def testInputSymPerTensor_WeightAsymPerChannel_FloatBias(self):
         y_actual = ops.conv2d(input_q, weight_q, bias, stride=1, padding=(1, 1))
         self.assertIs(
             ops._registry._test_get_last_op_dispatch(),
-            ops.qconv_impls.qconv2d_tensor_scaled,
+            ops.qconv_impls.qconv2d_tensor_scaled_integer,
         )
         y_ref = torch.nn.functional.conv2d(
             input_q.unpack().dequant(),
@@ -147,7 +147,7 @@ def testInputSymPerTensor_WeightAsymPerChannel_QuantizedBias(self):
         )
         self.assertIs(
             ops._registry._test_get_last_op_dispatch(),
-            ops.qconv_impls.qconv2d_tensor_scaled,
+            ops.qconv_impls.qconv2d_tensor_scaled_integer,
         )
         y_ref = torch.nn.functional.conv2d(
             input_q.unpack().dequant(),
@@ -184,7 +184,7 @@ def testInputSymPerTensor_WeightSymPerTensor_NoBias(self):
         )
         self.assertIs(
             ops._registry._test_get_last_op_dispatch(),
-            ops.qconv_impls.qconv2d_tensor_scaled,
+            ops.qconv_impls.qconv2d_tensor_scaled_integer,
         )
         y_ref = torch.nn.functional.conv2d(
             input_q.unpack().dequant(),
@@ -224,7 +224,7 @@ def testInputAsymPerChannel_WeightAsymPerChannel_NoBias(self):
         )
         self.assertIs(
             ops._registry._test_get_last_op_dispatch(),
-            ops.qconv_impls.qconv2d_tensor_scaled,
+            ops.qconv_impls.qconv2d_tensor_scaled_integer,
         )
         y_ref = torch.nn.functional.conv2d(
             input_q.unpack().dequant(),