Add fp8 quantization for conv and linear layers
nithinsubbiah committed Oct 14, 2024
1 parent b55065a commit df87b06
Showing 4 changed files with 32 additions and 9 deletions.
3 changes: 3 additions & 0 deletions sharktank/sharktank/kernels/conv_2d_nchw_fchw.py
@@ -23,6 +23,8 @@
(torch.int16, torch.int16, "torch.int16"): torch.int16,
(torch.int16, torch.int16, "torch.int32"): torch.int32,
# Legal fp types.
(torch.float8_e4m3fnuz, torch.float8_e4m3fnuz, "torch.float16"): torch.float16,
(torch.float8_e4m3fnuz, torch.float8_e4m3fnuz, "torch.float32"): torch.float32,
(torch.float16, torch.float16, "torch.float16"): torch.float16,
(torch.float16, torch.float16, "torch.float32"): torch.float32,
(torch.float32, torch.float32, "torch.float32"): torch.float32,
@@ -33,6 +35,7 @@
torch.int8: "i8",
torch.int16: "i16",
torch.int32: "i32",
torch.float8_e4m3fnuz: "f8E4M3FNUZ",
torch.float16: "f16",
torch.float32: "f32",
}
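The two tables above extend the kernel's type legalization: the first maps an (input dtype, weight dtype, requested accumulator) triple to the resulting accumulator dtype, and the second maps torch dtypes to MLIR element-type mnemonics used when specializing the generated kernel. A minimal standalone sketch of how such a lookup could be consumed follows; the names _ACCUM_TABLE, _MLIR_TYPE, and resolve are illustrative only (not the module's identifiers), and it assumes a torch build that exposes torch.float8_e4m3fnuz.

import torch

# Illustrative mirror of the fp8 rows added above; not the module's own dicts.
_ACCUM_TABLE = {
    (torch.float8_e4m3fnuz, torch.float8_e4m3fnuz, "torch.float16"): torch.float16,
    (torch.float8_e4m3fnuz, torch.float8_e4m3fnuz, "torch.float32"): torch.float32,
    (torch.float16, torch.float16, "torch.float32"): torch.float32,
}
_MLIR_TYPE = {
    torch.float8_e4m3fnuz: "f8E4M3FNUZ",
    torch.float16: "f16",
    torch.float32: "f32",
}

def resolve(input_dtype, weight_dtype, requested_accum: str):
    # Look up the accumulator dtype, then the MLIR spellings for input and accumulator.
    accum = _ACCUM_TABLE[(input_dtype, weight_dtype, requested_accum)]
    return accum, _MLIR_TYPE[input_dtype], _MLIR_TYPE[accum]

# fp8 inputs and weights accumulating in f32:
print(resolve(torch.float8_e4m3fnuz, torch.float8_e4m3fnuz, "torch.float32"))
# (torch.float32, 'f8E4M3FNUZ', 'f32')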
@@ -208,7 +208,8 @@ def quantize_bias(
bias_scale = 1.0 / (input_scale * weight_scale)
bias_quantizer = StaticScaledQuantizer(
scale=bias_scale,
dtype=torch.int32,
# dtype=torch.int32,
dtype=torch.float16, # TODO: Nithin
disable_saturate=True,
)
bias_quant = bias_quantizer.quantize(bias, name=bias_name)
@@ -287,7 +288,8 @@ def quantize_bias(
else:
# Unfused.
quantize_weight(weight.name, weight, weight_scale, weight_zp)
if bias is not None:
# if bias is not None and quantization_dtype == torch.int8:
if bias is not None: # TODO: Nithin
quantize_bias(bias.name, bias, input_scale, weight_scale)

# Input scaling.
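On the fp8 path the bias quantizer's target dtype is switched from int32 to float16 (still flagged as a TODO), and the guard around bias quantization is relaxed. Below is a rough emulation of the bias arithmetic, assuming StaticScaledQuantizer multiplies by its scale and casts to the target dtype; the real class may round or saturate differently, so this only illustrates the effect of the dtype switch.

import torch

def quantize_bias_sketch(bias, input_scale, weight_scale, dtype=torch.float16):
    # Bias scale combines the input and weight scales, as in the diff above.
    bias_scale = 1.0 / (input_scale * weight_scale)
    scaled = bias * bias_scale
    if dtype.is_floating_point:
        return scaled.to(dtype)            # fp16 path: plain cast, no rounding
    return torch.round(scaled).to(dtype)   # int32 path: round-to-nearest

bias = torch.randn(64)
print(quantize_bias_sketch(bias, input_scale=0.02, weight_scale=0.01).dtype)  # torch.float16
print(quantize_bias_sketch(bias, 0.02, 0.01, dtype=torch.int32).dtype)        # torch.int32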
11 changes: 7 additions & 4 deletions sharktank/sharktank/ops/qconv_impls.py
@@ -59,10 +59,10 @@ def qconv2d_tensor_scaled_integer(
input_layout: TensorScaledLayout = input.unpack()
weight_layout: TensorScaledLayout = weight.unpack()

# Only handle integer quantizations.
# # Handle integer and fp8 quantizations.
if (
input_layout.qs.dtype.is_floating_point
or weight_layout.qs.dtype.is_floating_point
input_layout.qs.dtype != torch.float8_e4m3fnuz
and weight_layout.qs.dtype != torch.float8_e4m3fnuz
):
return NotImplemented

@@ -85,7 +85,10 @@

# Alias components (d=scale, qs=quantized samples, m=offset).
if accum_dtype is None:
accum_dtype = torch.int32
if weight_layout.qs.dtype.is_floating_point:
accum_dtype = torch.float32
else:
accum_dtype = torch.int32
input_d = input_layout.d
input_dtype = input_layout.dtype
input_qs = input_layout.qs
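The qconv dispatch now admits float8_e4m3fnuz layouts, and when no accumulator dtype is requested the default becomes float32 for floating-point quantized samples instead of the int32 used for integer quantization. A minimal sketch of that defaulting rule as a standalone helper (the name default_accum_dtype is illustrative; the op reads the dtype directly from weight_layout.qs):

import torch

def default_accum_dtype(weight_qs_dtype: torch.dtype, accum_dtype=None) -> torch.dtype:
    if accum_dtype is not None:
        return accum_dtype
    # fp8 (or any floating-point) quantized samples accumulate in f32; int8/int16 in i32.
    return torch.float32 if weight_qs_dtype.is_floating_point else torch.int32

assert default_accum_dtype(torch.float8_e4m3fnuz) is torch.float32  # requires a torch build with fp8 dtypes
assert default_accum_dtype(torch.int8) is torch.int32
assert default_accum_dtype(torch.int8, accum_dtype=torch.int32) is torch.int32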
21 changes: 18 additions & 3 deletions sharktank/sharktank/ops/qlinear_impls.py
@@ -48,8 +48,15 @@ def qlinear_tensor_scaled_integer(
x_layout: TensorScaledLayout = x.unpack()
weight_layout: TensorScaledLayout = weight.unpack()

# Only handle integer quantizations.
if x_layout.qs.dtype.is_floating_point or weight_layout.qs.dtype.is_floating_point:
# # Only handle integer quantizations.
# if x_layout.qs.dtype.is_floating_point or weight_layout.qs.dtype.is_floating_point:
# return NotImplemented

# Handle integer and fp8 quantizations.
if (
x_layout.qs.dtype != torch.float8_e4m3fnuz
and weight_layout.qs.dtype != torch.float8_e4m3fnuz
):
return NotImplemented

# Bias.
@@ -64,7 +71,10 @@

# Alias components (d=scale, qs=quantized samples, m=offset)
if accum_dtype is None:
accum_dtype = torch.int32
if weight_layout.qs.dtype.is_floating_point:
accum_dtype = torch.float32
else:
accum_dtype = torch.int32
x_d = x_layout.d
x_dtype = x_layout.dtype
x_qs = x_layout.qs
@@ -172,10 +182,15 @@ def _invoke_int32_mmt(lhs, rhs, *, accum_dtype):
# rank. Broadcast the rhs to match.
lhs_rank = len(lhs.shape)
rhs_rank = len(rhs.shape)
if lhs_rank == 2:
lhs_size = [1] + list(lhs.shape)
lhs = lhs.unsqueeze(0).expand(lhs_size)
lhs_rank = len(lhs.shape)
if rhs_rank < lhs_rank:
assert (rhs_rank + 1) == lhs_rank
rhs_size = [lhs.shape[0]] + list(rhs.shape)
rhs = rhs.unsqueeze(0).expand(rhs_size)
rhs_rank = len(rhs.shape)
y_qs = kernels.batch_matmul_transpose_b(
lhs.to(accum_dtype), rhs.to(accum_dtype)
)
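The last hunk pads a rank-2 lhs with a unit batch dimension and broadcasts a lower-rank rhs across that batch so both operands are rank 3 before the transposed batch matmul. A hedged standalone sketch of the same rank alignment, using torch.matmul against a transposed rhs as a stand-in for kernels.batch_matmul_transpose_b:

import torch

def mmt_broadcast_sketch(lhs, rhs, accum_dtype=torch.float32):
    if lhs.dim() == 2:
        # Give a rank-2 lhs a unit batch dimension, as in the diff above.
        lhs = lhs.unsqueeze(0).expand([1] + list(lhs.shape))
    if rhs.dim() < lhs.dim():
        assert rhs.dim() + 1 == lhs.dim()
        # Broadcast the (typically rank-2) rhs across the lhs batch dimension.
        rhs = rhs.unsqueeze(0).expand([lhs.shape[0]] + list(rhs.shape))
    # Stand-in for the transpose-B batch matmul kernel.
    return torch.matmul(lhs.to(accum_dtype), rhs.transpose(-1, -2).to(accum_dtype))

x = torch.randn(4, 8)    # rank-2 activation [rows, k]
w = torch.randn(16, 8)   # weight [n, k], multiplied as x @ w.T
print(mmt_broadcast_sketch(x, w).shape)  # torch.Size([1, 4, 16])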
