Skip to content

Commit

Permalink
Interpreter support for quantized type (#2388)
Browse files Browse the repository at this point in the history
fixes #2373

The PR is rebased on top on
#2383 and cherry-pick changes
from #2384.

### Direction to reviewer

Please review the commit
4d7dc1a
**excluding** the following files
  - docs/generated/stablehlo_passes.md
  - stablehlo/transforms/Passes.td
  - stablehlo/transforms/ShapeLegalizeToStablehlo.cpp
  • Loading branch information
sdasgup3 authored Jun 20, 2024
1 parent 91ed649 commit 6802016
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 2 deletions.
67 changes: 65 additions & 2 deletions stablehlo/reference/Api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,19 +170,82 @@ LogicalResult removeDynamism(ModuleOp module, func::FuncOp func,
return success();
}

bool isAnyQuantizedTypes(TypeRange types) {
return llvm::any_of(types, [](Type type) {
return isa<quant::QuantizedType>(getElementTypeOrSelf(type));
});
}

// Recursively checks if an operation or any of its nested operations use
// quantized types.
//
// Args:
// op: The operation to check for quantized type usage.
//
// Returns:
// True if the operation or any nested operation uses quantized types,
// false otherwise.
bool funcUsesQuantType(func::FuncOp func_op) {
bool usesQuantizedType = false;

func_op.walk([&](Operation *op) {
if (isAnyQuantizedTypes(op->getOperandTypes()) ||
isAnyQuantizedTypes(op->getResultTypes())) {
usesQuantizedType = true;
return WalkResult::interrupt();
}
return WalkResult::advance();
});

return usesQuantizedType;
}

// Lowers quantization-related operations and types within a function to
// primitive math operations.
//
// This function checks if a function uses quantized types in its inputs,
// outputs, or internal operations. If so, it creates and runs a StableHLO
// quantization lowering pipeline to transform those quantized constructs into
// primitive math operations. If the lowering process fails, an error is
// emitted.
//
// Args:
// module: The module containing the function `func`.
// func The function to lower quantized types/operations in.
//
// Returns:
// A `LogicalResult` indicating success or failure of the lowering process.
LogicalResult lowerQuantization(ModuleOp module, func::FuncOp func) {
if (!(isAnyQuantizedTypes(func.getFunctionType().getInputs()) ||
funcUsesQuantType(func) ||
isAnyQuantizedTypes(func.getFunctionType().getResults()))) {
return success();
}

PassManager pm(func.getContext());
stablehlo::createStablehloLowerQuantPipeline(pm);
if (failed(pm.run(module))) {
return func.emitError("Failed to lower quantized types/ops in function: ")
<< func.getName();
}
module.dump();
return success();
}

} // namespace

FailureOr<SmallVector<InterpreterValue>> evalModule(
ModuleOp module, ArrayRef<InterpreterValue> inputs,
const InterpreterConfiguration &config) {
// Additional error checking at main function boundary.
// This is most likely user error, where future errors during interpreting are
// more likely invalid IR or interpreter bugs.
// This is most likely user error, where future errors during interpreting
// are more likely invalid IR or interpreter bugs.
if (module.getOps<func::FuncOp>().empty())
return SmallVector<InterpreterValue>();

auto mainFunc = getMainFunction(module, config.mainFunction);
if (failed(mainFunc) || failed(removeDynamism(module, *mainFunc, inputs)) ||
failed(lowerQuantization(module, *mainFunc)) ||
failed(validateEntrySignature(*mainFunc, inputs))) {
return failure();
}
Expand Down
54 changes: 54 additions & 0 deletions stablehlo/tests/interpret/quantized_ops.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// RUN: stablehlo-translate --interpret -split-input-file %s

func.func @uniform_quantize() {
%operand = stablehlo.constant dense<[4.0, 15.0]> : tensor<2xf32>
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
%bitcast_result = "stablehlo.bitcast_convert"(%result) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xi8>
check.expect_eq_const %bitcast_result, dense<[10, 10]> : tensor<2xi8>
func.return
}

// -----

func.func @uniform_quantize() {
%operand = stablehlo.constant dense<[10, 10]> : tensor<2xi8>
%bitcast_operand = "stablehlo.bitcast_convert"(%operand) : (tensor<2xi8>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
%result = "stablehlo.uniform_quantize"(%bitcast_operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
%bitcast_result = "stablehlo.bitcast_convert"(%result) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>) -> tensor<2xi8>
check.expect_eq_const %bitcast_result, dense<[20, 45]> : tensor<2xi8>
func.return
}

// -----

func.func @uniform_dequantize() {
%operand = stablehlo.constant dense<[10, 10]> : tensor<2xi8>
%bitcast_operand = "stablehlo.bitcast_convert"(%operand) : (tensor<2xi8>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
%result = "stablehlo.uniform_dequantize"(%bitcast_operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
check.expect_almost_eq_const %result, dense<[4.0, 15.0]> : tensor<2xf32>
func.return
}


// -----

func.func @uniform_qdq() {
%operand = stablehlo.constant dense<[4.0, 15.0]> : tensor<2xf32>
%quantize = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
%result = "stablehlo.uniform_dequantize"(%quantize) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
check.expect_almost_eq_const %result, dense<[4.0, 15.0]> : tensor<2xf32>
func.return
}

// -----

func.func @quantized_add() {
%operand1 = stablehlo.constant dense<[1.0, 2.0]> : tensor<2xf32>
%operand2 = stablehlo.constant dense<[3.0, 4.0]> : tensor<2xf32>
%q_operand1 = "stablehlo.uniform_quantize"(%operand1) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32, 0.1:-30>>
%q_operand2 = "stablehlo.uniform_quantize"(%operand2) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32, 0.5:-20>>
%result = "stablehlo.add"(%q_operand1, %q_operand2) : (tensor<2x!quant.uniform<i8:f32, 0.1:-30>>, tensor<2x!quant.uniform<i8:f32, 0.5:-20>>) -> tensor<2x!quant.uniform<i8:f32, 0.5:-20>>
%bitcast_result = "stablehlo.bitcast_convert"(%result) : (tensor<2x!quant.uniform<i8:f32, 0.5:-20>>) -> tensor<2xi8>
check.expect_eq_const %bitcast_result, dense<[-12, -8]> : tensor<2xi8>
func.return
}
1 change: 1 addition & 0 deletions stablehlo/tools/StablehloTranslateMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ TranslateFromMLIRRegistration interpretRegistration(
},
[](DialectRegistry &registry) {
registry.insert<func::FuncDialect>();
registry.insert<quant::QuantizationDialect>();
registry.insert<stablehlo::check::CheckDialect>();
registry.insert<stablehlo::interpreter::InterpreterDialect>();
registry.insert<stablehlo::StablehloDialect>();
Expand Down
16 changes: 16 additions & 0 deletions stablehlo/transforms/PassPipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ limitations under the License.
==============================================================================*/

#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "stablehlo/dialect/Version.h"
#include "stablehlo/transforms/Passes.h"

Expand All @@ -36,6 +37,21 @@ void createStablehloRemoveDynamismPipeline(OpPassManager &pm,
stablehlo::createStablehloCanonicalizeDynamismPass());
}

void createStablehloLowerQuantPipeline(OpPassManager &pm) {
pm.addNestedPass<mlir::func::FuncOp>(
stablehlo::createStablehloLegalizeQuantizedOpToQDQPass());
pm.addNestedPass<mlir::func::FuncOp>(
stablehlo::createStablehloLegalizeQuantToIntPass());
pm.addNestedPass<mlir::func::FuncOp>(
stablehlo::createChloLegalizeToStablehloPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<mlir::func::FuncOp>(
stablehlo::createShapeLegalizeToStablehloPass());
pm.addNestedPass<mlir::func::FuncOp>(
stablehlo::createStablehloCanonicalizeDynamismPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
}

void registerPassPipelines() {
PassPipelineRegistration<>("stablehlo-deserialize",
"Run an example pipeline.",
Expand Down
5 changes: 5 additions & 0 deletions stablehlo/transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,11 @@ void createStablehloDeserializePipeline(OpPassManager &pm);
void createStablehloRemoveDynamismPipeline(OpPassManager &pm,
TypeRange refinedTypes);

// Decomposes quantized operations within a StableHLO module by
// applying a series of MLIR passes essentially breaking down the quantized
// operations into a primitive math operations.
void createStablehloLowerQuantPipeline(OpPassManager &pm);

// Adds `stablehlo-deserialize` pipeline as a registered pass pipeline
// for opt tools.
void registerPassPipelines();
Expand Down

0 comments on commit 6802016

Please sign in to comment.