Revert "Add fp8 quantization for conv and linear layers (#277)"
This reverts commit ec5672e.
nithinsubbiah authored Oct 17, 2024
1 parent ec5672e commit 0022804
Showing 7 changed files with 30 additions and 56 deletions.
3 changes: 1 addition & 2 deletions sharktank/sharktank/kernels/batch_matmul_transpose_b.py
@@ -80,7 +80,7 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
spec_sig = f"L{a_ident}_R{b_ident}"
template_file = "batch_matmul_transpose_b.mlir"
target_function_name = f"sharktank_batch_matmul_transpose_b_{spec_sig}"
cst_zero = "0." if "f" in str(accum_type) else "0"

# Template params.
c_asm_type = f"tensor<{'x'.join('?' if d is None else str(d) for d in result_desc.spec_dims)}x{accum_type}>"

@@ -93,6 +93,5 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
b_asm_type=b_asm_type,
c_asm_type=c_asm_type,
dtype=str(accum_type),
cst_zero=cst_zero,
)
kb.yield_results(*call_function(target_function, *kb.arg_bindings))
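
For context on what this hunk reverts: the fp8 change made the kernel template's zero constant dtype-aware, emitting "0." for floating-point accumulators and "0" for integer ones, and passed it through as the cst_zero template parameter. A minimal sketch of that removed selection logic, using only the names visible in the hunk (the surrounding generator code is assumed):

import torch

def select_cst_zero(accum_type: torch.dtype) -> str:
    # Mirrors the removed line: float accumulators need a float literal,
    # integer accumulators need a plain "0".
    return "0." if "f" in str(accum_type) else "0"

assert select_cst_zero(torch.float32) == "0."
assert select_cst_zero(torch.int32) == "0"

After the revert the template hard-codes an integer zero again, which is why this parameter disappears from the substitution below.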
3 changes: 0 additions & 3 deletions sharktank/sharktank/kernels/conv_2d_nchw_fchw.py
@@ -23,8 +23,6 @@
(torch.int16, torch.int16, "torch.int16"): torch.int16,
(torch.int16, torch.int16, "torch.int32"): torch.int32,
# Legal fp types.
(torch.float8_e4m3fnuz, torch.float8_e4m3fnuz, "torch.float16"): torch.float16,
(torch.float8_e4m3fnuz, torch.float8_e4m3fnuz, "torch.float32"): torch.float32,
(torch.float16, torch.float16, "torch.float16"): torch.float16,
(torch.float16, torch.float16, "torch.float32"): torch.float32,
(torch.float32, torch.float32, "torch.float32"): torch.float32,
@@ -35,7 +33,6 @@
torch.int8: "i8",
torch.int16: "i16",
torch.int32: "i32",
torch.float8_e4m3fnuz: "f8E4M3FNUZ",
torch.float16: "f16",
torch.float32: "f32",
}
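
These hunks shrink two lookup tables in conv_2d_nchw_fchw.py: the table of legal (input dtype, weight dtype, requested accumulator) combinations loses its float8_e4m3fnuz rows, and the dtype-to-MLIR-asm map loses f8E4M3FNUZ. A small illustrative sketch of how such a lookup behaves after the revert, with entries copied from the context lines above (the function name is hypothetical, not the kernel's actual helper):

import torch

# Post-revert subset of the tables shown above.
_LEGAL_ACCUM = {
    (torch.int16, torch.int16, "torch.int32"): torch.int32,
    (torch.float16, torch.float16, "torch.float32"): torch.float32,
    (torch.float32, torch.float32, "torch.float32"): torch.float32,
}
_ASM_DTYPE = {torch.int16: "i16", torch.int32: "i32", torch.float16: "f16", torch.float32: "f32"}

def resolve_conv_types(input_dtype, weight_dtype, accum_str):
    # fp8 keys no longer exist, so float8_e4m3fnuz inputs are simply rejected here.
    accum_dtype = _LEGAL_ACCUM[(input_dtype, weight_dtype, accum_str)]
    return accum_dtype, _ASM_DTYPE[accum_dtype]

assert resolve_conv_types(torch.float16, torch.float16, "torch.float32") == (torch.float32, "f32")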
@@ -4,7 +4,6 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

// !cst_zero = {{cst_zero}}
!dtype = {{dtype}}
!a_tensor_type = {{a_asm_type}}
!b_tensor_type = {{b_asm_type}}
@@ -16,8 +15,7 @@ module {
util.func private @sharktank_batch_matmul_transpose_b_{{spec_sig}}(
%a: !a_tensor_type, %b: !b_tensor_type)
-> !c_tensor_type {
// %zero = arith.constant !cst_zero: !dtype
%zero = arith.constant {{cst_zero}}: !dtype
%zero = arith.constant 0: !dtype
%c0 = arith.constant 0: index
%c1 = arith.constant 1: index
%batch_dim = tensor.dim %a, %c0 : !a_tensor_type // b, m, k
@@ -208,7 +208,7 @@ def quantize_bias(
bias_scale = 1.0 / (input_scale * weight_scale)
bias_quantizer = StaticScaledQuantizer(
scale=bias_scale,
dtype=torch.int32 if quantization_dtype == torch.int8 else torch.float16,
dtype=torch.int32,
disable_saturate=True,
)
bias_quant = bias_quantizer.quantize(bias, name=bias_name)
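
The quantize_bias hunk restores an int32-only bias path: the bias scale is derived from the input and weight scales, and the bias is always quantized to int32 rather than branching to float16 for fp8 models. A hedged sketch of that flow, assuming the import location of StaticScaledQuantizer and that the scale arguments are tensors matching what the original caller passes:

import torch
from sharktank.types import StaticScaledQuantizer  # assumed import location

def quantize_bias_int32(bias, bias_name, input_scale, weight_scale):
    # Bias scale is derived from the input and weight scales, as in the hunk above.
    bias_scale = 1.0 / (input_scale * weight_scale)
    bias_quantizer = StaticScaledQuantizer(
        scale=bias_scale,
        dtype=torch.int32,  # post-revert: always int32, no float16 branch for fp8
        disable_saturate=True,
    )
    return bias_quantizer.quantize(bias, name=bias_name)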
29 changes: 12 additions & 17 deletions sharktank/sharktank/ops/qconv_impls.py
@@ -31,7 +31,7 @@
)


def qconv2d_tensor_scaled(
def qconv2d_tensor_scaled_integer(
input: QuantizedTensor,
weight: QuantizedTensor,
bias: Optional[AnyTensor] = None,
@@ -59,16 +59,12 @@ def qconv2d_tensor_scaled(
input_layout: TensorScaledLayout = input.unpack()
weight_layout: TensorScaledLayout = weight.unpack()

# # Handle integer and fp8 quantizations.
# Only handle integer quantizations.
if (
input_layout.qs.dtype.is_floating_point
or weight_layout.qs.dtype.is_floating_point
):
if (
input_layout.qs.dtype != torch.float8_e4m3fnuz
or weight_layout.qs.dtype != torch.float8_e4m3fnuz
):
return NotImplemented
return NotImplemented

# Bias is both optional and may either be quantized or fp.
bias_qs = None
@@ -89,10 +85,7 @@

# Alias components (d=scale, qs=quantized samples, m=offset).
if accum_dtype is None:
if weight_layout.qs.dtype.is_floating_point:
accum_dtype = torch.float32
else:
accum_dtype = torch.int32
accum_dtype = torch.int32
input_d = input_layout.d
input_dtype = input_layout.dtype
input_qs = input_layout.qs
@@ -121,7 +114,7 @@ def qconv2d_tensor_scaled(
dilation = _expand_int_to_2_tuple(dilation)
extended_padding_list = [item for item in padding for _ in range(2)]
padded_input = _pad_last_2d(input_qs, extended_padding_list)
y_qs = _invoke_conv2d_kernel(
y_qs = _invoke_int32_conv2d(
padded_input,
weight_qs,
bias_qs.to(accum_dtype) if bias_qs is not None else None,
@@ -152,7 +145,7 @@ def qconv2d_tensor_scaled(
weight_offset_fix = torch.sum(
padded_input, dim=1, keepdim=True, dtype=accum_dtype
)
weight_offset_fix = _invoke_pooling_sum_kernel(
weight_offset_fix = _invoke_int32_pooling_sum(
weight_offset_fix,
[weight_qs.shape[2], weight_qs.shape[3]],
stride,
@@ -195,11 +188,13 @@ def qconv2d_tensor_scaled(
return y


conv2d.override(QuantizedTensor, QuantizedTensor)(qconv2d_tensor_scaled)
conv2d.override(QuantizedTensor, QuantizedTensor, AnyTensor)(qconv2d_tensor_scaled)
conv2d.override(QuantizedTensor, QuantizedTensor)(qconv2d_tensor_scaled_integer)
conv2d.override(QuantizedTensor, QuantizedTensor, AnyTensor)(
qconv2d_tensor_scaled_integer
)


def _invoke_conv2d_kernel(input, weight, bias, stride, dilation, *, accum_dtype):
def _invoke_int32_conv2d(input, weight, bias, stride, dilation, *, accum_dtype):
"""Does a low level invocation of a conv2d integer kernel on an explicitly padded input.
This presumes that the stride/padding/dilation have already been normalized
@@ -238,7 +233,7 @@ def _invoke_conv2d_kernel(input, weight, bias, stride, dilation, *, accum_dtype)
return y_qs


def _invoke_pooling_sum_kernel(input, kernel_size, stride, dilation, *, accum_dtype):
def _invoke_int32_pooling_sum(input, kernel_size, stride, dilation, *, accum_dtype):
"""Invokes either a custom integer pooling sum or the built-in fp avg_pool2d
kernel on an explicitly padded input.
"""
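
The net effect of these qconv changes is that the tensor-scaled conv override is once again integer-only: floating-point quantized samples (including the fp8 layouts the reverted commit special-cased) make it return NotImplemented so dispatch can fall through to another implementation, and the accumulator defaults straight to int32. A minimal sketch of that restored guard, with plain tensors standing in for the TensorScaledLayout.qs fields used above:

import torch

def integer_only_guard(input_qs: torch.Tensor, weight_qs: torch.Tensor, accum_dtype=None):
    # Post-revert entry checks of qconv2d_tensor_scaled_integer:
    # floating-point quantized samples are no longer handled here.
    if input_qs.dtype.is_floating_point or weight_qs.dtype.is_floating_point:
        return NotImplemented  # lets the dispatcher fall through to another override
    if accum_dtype is None:
        accum_dtype = torch.int32  # the float32 accumulator branch is gone
    return accum_dtype

assert integer_only_guard(torch.zeros(1, dtype=torch.float16), torch.zeros(1, dtype=torch.int8)) is NotImplemented
assert integer_only_guard(torch.zeros(1, dtype=torch.int8), torch.zeros(1, dtype=torch.int8)) is torch.int32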
35 changes: 10 additions & 25 deletions sharktank/sharktank/ops/qlinear_impls.py
@@ -28,7 +28,7 @@
from sharktank import kernels


def qlinear_tensor_scaled(
def qlinear_tensor_scaled_integer(
x: QuantizedTensor,
weight: QuantizedTensor,
bias: Optional[AnyTensor],
@@ -48,11 +48,8 @@ def qlinear_tensor_scaled(
x_layout: TensorScaledLayout = x.unpack()
weight_layout: TensorScaledLayout = weight.unpack()

# Handle integer and fp8 quantizations.
if (
x_layout.qs.dtype != torch.float8_e4m3fnuz
and weight_layout.qs.dtype != torch.float8_e4m3fnuz
):
# Only handle integer quantizations.
if x_layout.qs.dtype.is_floating_point or weight_layout.qs.dtype.is_floating_point:
return NotImplemented

# Bias.
@@ -67,10 +64,7 @@

# Alias components (d=scale, qs=quantized samples, m=offset)
if accum_dtype is None:
if weight_layout.qs.dtype.is_floating_point:
accum_dtype = torch.float32
else:
accum_dtype = torch.int32
accum_dtype = torch.int32
x_d = x_layout.d
x_dtype = x_layout.dtype
x_qs = x_layout.qs
@@ -92,7 +86,7 @@ def qlinear_tensor_scaled(
# TODO: Handle permutation that we have a kernel for.

# Fall back to automatic fusion based on integer, high precision matmul.
y_qs = _invoke_mmt_kernel(x_qs, weight_qs, accum_dtype=accum_dtype)
y_qs = _invoke_int32_mmt(x_qs, weight_qs, accum_dtype=accum_dtype)

# Offset correction. By applying the offset correction in post, it is
# set up to fuse with its consumer, which is already doing additional
@@ -149,8 +143,10 @@ def qlinear_tensor_scaled(


# Overload for both bias and no bias.
linear.override(QuantizedTensor, QuantizedTensor)(qlinear_tensor_scaled)
linear.override(QuantizedTensor, QuantizedTensor, AnyTensor)(qlinear_tensor_scaled)
linear.override(QuantizedTensor, QuantizedTensor)(qlinear_tensor_scaled_integer)
linear.override(QuantizedTensor, QuantizedTensor, AnyTensor)(
qlinear_tensor_scaled_integer
)


def linear_quantized_weight(
@@ -170,30 +166,19 @@ def linear_quantized_weight(
linear.override(Tensor, QuantizedTensor, AnyTensor)(linear_quantized_weight)


def _invoke_mmt_kernel(lhs, rhs, *, accum_dtype):
def _invoke_int32_mmt(lhs, rhs, *, accum_dtype):
if debugging.flags.use_custom_iree_kernels:
# The custom kernel requires that the lhs and rhs be the same
# rank. Broadcast the rhs to match.
lhs_rank = len(lhs.shape)
rhs_rank = len(rhs.shape)
# If input to the kernel is 2D, expand the tensor to add the batch
# dimension.
if lhs_rank == 2:
lhs_size = [1] + list(lhs.shape)
lhs = lhs.unsqueeze(0).expand(lhs_size)
lhs_rank = len(lhs.shape)
if rhs_rank < lhs_rank:
assert (rhs_rank + 1) == lhs_rank
rhs_size = [lhs.shape[0]] + list(rhs.shape)
rhs = rhs.unsqueeze(0).expand(rhs_size)
rhs_rank = len(rhs.shape)
y_qs = kernels.batch_matmul_transpose_b(
lhs.to(accum_dtype), rhs.to(accum_dtype)
)
# Squeeze the batch dimension to maintain shape parity with other
# layers.
if len(y_qs.shape) > 2:
y_qs = y_qs.squeeze(0)
else:
# FP emulation.
y_qs = torch.matmul(
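
After the revert, _invoke_int32_mmt no longer expands 2-D inputs to a batched shape (and no longer squeezes the result back); it only broadcasts a lower-rank rhs up to the lhs rank before calling the batched kernel. A sketch of that rank-matching step in isolation, with kernels.batch_matmul_transpose_b left out and assumed (per the comment above) to require equal-rank operands:

import torch

def broadcast_rhs_to_lhs_rank(lhs: torch.Tensor, rhs: torch.Tensor) -> torch.Tensor:
    # The custom kernel requires lhs and rhs to have the same rank, so a 2-D rhs
    # is expanded along a new leading batch dimension to match a 3-D lhs.
    lhs_rank, rhs_rank = len(lhs.shape), len(rhs.shape)
    if rhs_rank < lhs_rank:
        assert rhs_rank + 1 == lhs_rank
        rhs = rhs.unsqueeze(0).expand([lhs.shape[0]] + list(rhs.shape))
    return rhs

lhs = torch.zeros(4, 8, 16, dtype=torch.int32)  # (batch, m, k)
rhs = torch.zeros(32, 16, dtype=torch.int32)    # (n, k) -> broadcast to (4, n, k)
assert broadcast_rhs_to_lhs_rank(lhs, rhs).shape == (4, 32, 16)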
10 changes: 5 additions & 5 deletions sharktank/tests/ops/qconv_test.py
@@ -71,7 +71,7 @@ def testInputSymPerTensor_WeightAsymPerChannel_NoBias(self):
)
self.assertIs(
ops._registry._test_get_last_op_dispatch(),
ops.qconv_impls.qconv2d_tensor_scaled,
ops.qconv_impls.qconv2d_tensor_scaled_integer,
)
y_ref = torch.nn.functional.conv2d(
input_q.unpack().dequant(),
@@ -105,7 +105,7 @@ def testInputSymPerTensor_WeightAsymPerChannel_FloatBias(self):
y_actual = ops.conv2d(input_q, weight_q, bias, stride=1, padding=(1, 1))
self.assertIs(
ops._registry._test_get_last_op_dispatch(),
ops.qconv_impls.qconv2d_tensor_scaled,
ops.qconv_impls.qconv2d_tensor_scaled_integer,
)
y_ref = torch.nn.functional.conv2d(
input_q.unpack().dequant(),
@@ -147,7 +147,7 @@ def testInputSymPerTensor_WeightAsymPerChannel_QuantizedBias(self):
)
self.assertIs(
ops._registry._test_get_last_op_dispatch(),
ops.qconv_impls.qconv2d_tensor_scaled,
ops.qconv_impls.qconv2d_tensor_scaled_integer,
)
y_ref = torch.nn.functional.conv2d(
input_q.unpack().dequant(),
@@ -184,7 +184,7 @@ def testInputSymPerTensor_WeightSymPerTensor_NoBias(self):
)
self.assertIs(
ops._registry._test_get_last_op_dispatch(),
ops.qconv_impls.qconv2d_tensor_scaled,
ops.qconv_impls.qconv2d_tensor_scaled_integer,
)
y_ref = torch.nn.functional.conv2d(
input_q.unpack().dequant(),
@@ -224,7 +224,7 @@ def testInputAsymPerChannel_WeightAsymPerChannel_NoBias(self):
)
self.assertIs(
ops._registry._test_get_last_op_dispatch(),
ops.qconv_impls.qconv2d_tensor_scaled,
ops.qconv_impls.qconv2d_tensor_scaled_integer,
)
y_ref = torch.nn.functional.conv2d(
input_q.unpack().dequant(),
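
The test updates only rename the expected dispatch target back to qconv2d_tensor_scaled_integer; each test keeps the same pattern: run ops.conv2d on quantized operands, assert the registry dispatched to the integer implementation, then compare against torch.nn.functional.conv2d on the dequantized tensors. A compressed sketch of that pattern, where input_q and weight_q are placeholder QuantizedTensors supplied by the caller and the final comparison is left to the real suite:

import torch
from sharktank import ops

def check_integer_conv_dispatch(input_q, weight_q):
    y_actual = ops.conv2d(input_q, weight_q, stride=1, padding=(1, 1))
    # The revert renames the expected override back to the *_integer symbol.
    assert (
        ops._registry._test_get_last_op_dispatch()
        is ops.qconv_impls.qconv2d_tensor_scaled_integer
    )
    # Reference result: ordinary fp conv on the dequantized operands.
    y_ref = torch.nn.functional.conv2d(
        input_q.unpack().dequant(),
        weight_q.unpack().dequant(),
        stride=1,
        padding=(1, 1),
    )
    return y_actual, y_ref  # compared within the suite's tolerance in the actual tests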
