From dc58f6a22cd69fdaf00fee74e8a8b484f63c580f Mon Sep 17 00:00:00 2001
From: Kevin Gleason
Date: Tue, 27 Aug 2024 16:10:07 +0000
Subject: [PATCH] Add best effort validation for known dot algorithms

---
 stablehlo/dialect/Base.cpp                    |  85 ++++++
 stablehlo/dialect/Base.h                      |  37 +++
 stablehlo/dialect/TypeInference.cpp           |  35 ++-
 stablehlo/tests/interpret/dot_general.mlir    |   2 +-
 .../tests/ops_dot_general_algorithms.mlir     | 262 ++++++++++++++++++
 stablehlo/tests/ops_stablehlo.mlir            |   6 +-
 stablehlo/tests/print_stablehlo.mlir          |   8 +-
 7 files changed, 409 insertions(+), 26 deletions(-)
 create mode 100644 stablehlo/tests/ops_dot_general_algorithms.mlir

diff --git a/stablehlo/dialect/Base.cpp b/stablehlo/dialect/Base.cpp
index 253d09817d3..557a101198c 100644
--- a/stablehlo/dialect/Base.cpp
+++ b/stablehlo/dialect/Base.cpp
@@ -21,12 +21,15 @@ limitations under the License.
 #include <algorithm>
 #include <cassert>
 #include <cstdint>
+#include <map>
 
 #include "llvm/ADT/APInt.h"
+#include "llvm/ADT/Hashing.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "mlir/Dialect/Quant/QuantTypes.h"
 #include "mlir/Dialect/Shape/IR/Shape.h"
@@ -42,10 +45,14 @@
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Support/TypeID.h"
 
 // Include order matters
 #include "stablehlo/dialect/BaseAttrInterfaces.cpp.inc"
 
+#define DEBUG_TYPE "stablehlo-base"
+
 namespace mlir {
 namespace hlo {
@@ -624,6 +631,84 @@ bool isSplatArray(ArrayRef<int64_t> arr, int64_t val) {
                      [val](int64_t x) { return x == val; });
 }
 
+namespace detail {
+template <typename LhsTy, typename RhsTy, typename AccumTy, int64_t N>
+bool match(Type lhsPrecisionType, Type rhsPrecisionType, Type accumulationType,
+           int64_t numPrimitiveOperations) {
+  return isa<LhsTy>(lhsPrecisionType) && isa<RhsTy>(rhsPrecisionType) &&
+         isa<AccumTy>(accumulationType) && numPrimitiveOperations == N;
+}
+
+FailureOr<KnownDotAlgorithm> getKnownDotAlgorithm(
+    Type lhsPrecisionType, Type rhsPrecisionType, Type accumulationType,
+    int64_t lhsComponentCount, int64_t rhsComponentCount,
+    int64_t numPrimitiveOperations, bool allowImpreciseAccumulation) {
+  // Only support single component for now.
+  if (lhsComponentCount != 1 || rhsComponentCount != 1) return failure();
+
+  auto isAnyF8 = [](Type t) {
+    return llvm::isa<Float8E4M3B11FNUZType, Float8E4M3FNType,
+                     Float8E4M3FNUZType, Float8E5M2Type, Float8E5M2FNUZType>(t);
+  };
+  if (isAnyF8(lhsPrecisionType) && isAnyF8(rhsPrecisionType) &&
+      accumulationType.isF32() && numPrimitiveOperations == 1) {
+    if (allowImpreciseAccumulation)
+      return KnownDotAlgorithm::ANY_F8_ANY_F8_F32_FAST_ACCUM;
+    return KnownDotAlgorithm::ANY_F8_ANY_F8_F32;
+  }
+  if (allowImpreciseAccumulation) return failure();
+
+  // TypeID doesn't define a `<` operator so it cannot be used as a map key.
+  // Use the type's name instead.
+  auto key = std::make_tuple(lhsPrecisionType.getAbstractType().getName(),
+                             rhsPrecisionType.getAbstractType().getName(),
+                             accumulationType.getAbstractType().getName(),
+                             numPrimitiveOperations);
+
+  StringRef bf16 = BFloat16Type::name;
+  StringRef f16 = Float16Type::name;
+  StringRef f32 = Float32Type::name;
+  StringRef f64 = Float64Type::name;
+  StringRef tf32 = FloatTF32Type::name;
+  std::map<std::tuple<StringRef, StringRef, StringRef, int64_t>,
+           KnownDotAlgorithm>
+      knownDotAlgorithms{
+          {{f16, f16, f16, 1}, KnownDotAlgorithm::F16_F16_F16},
+          {{f16, f16, f32, 1}, KnownDotAlgorithm::F16_F16_F32},
+          {{bf16, bf16, bf16, 1}, KnownDotAlgorithm::BF16_BF16_BF16},
+          {{bf16, bf16, f32, 1}, KnownDotAlgorithm::BF16_BF16_F32},
+          {{bf16, bf16, f32, 3}, KnownDotAlgorithm::BF16_BF16_F32_X3},
+          {{bf16, bf16, f32, 6}, KnownDotAlgorithm::BF16_BF16_F32_X6},
+          {{tf32, tf32, f32, 1}, KnownDotAlgorithm::TF32_TF32_F32},
+          {{tf32, tf32, f32, 3}, KnownDotAlgorithm::TF32_TF32_F32_X3},
+          {{f32, f32, f32, 1}, KnownDotAlgorithm::F32_F32_F32},
+          {{f64, f64, f64, 1}, KnownDotAlgorithm::F64_F64_F64},
+      };
+
+  auto algorithm = knownDotAlgorithms.find(key);
+  if (algorithm != knownDotAlgorithms.end()) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Found known dot algorithm: "
+               << static_cast<int64_t>(algorithm->second) << " "
+               << std::get<0>(key) << ", " << std::get<1>(key) << ", "
+               << std::get<2>(key) << ", " << std::get<3>(key) << "\n");
+    return algorithm->second;
+  }
+  return failure();
+}
+}  // namespace detail
+
+// Check if the combination of dot algorithm parameters is known.
+bool isKnownDotAlgorithm(Type lhsPrecisionType, Type rhsPrecisionType,
+                         Type accumulationType, int64_t lhsComponentCount,
+                         int64_t rhsComponentCount,
+                         int64_t numPrimitiveOperations,
+                         bool allowImpreciseAccumulation) {
+  return succeeded(detail::getKnownDotAlgorithm(
+      lhsPrecisionType, rhsPrecisionType, accumulationType, lhsComponentCount,
+      rhsComponentCount, numPrimitiveOperations, allowImpreciseAccumulation));
+}
+
 mlir::Speculation::Speculatability getShapedSpeculatability(
     Operation* op, int64_t shapeCount) {
   // If all inputs are static and the shape-related operands are constant
diff --git a/stablehlo/dialect/Base.h b/stablehlo/dialect/Base.h
index 1441cad941f..a65d54700f9 100644
--- a/stablehlo/dialect/Base.h
+++ b/stablehlo/dialect/Base.h
@@ -235,6 +235,43 @@ class HloDialectInterface : public DialectInterface::Base<HloDialectInterface> {
   virtual Attribute createTypeExtensions(ArrayRef<int64_t> bounds) const = 0;
 };
 
+namespace detail {
+
+// An enum which tracks known supported dot algorithm combinations.
+// Note this implementation is a detail for now and the APIs are likely to
+// change once HLO broadens support for LHS/RHS components and num primitive
+// operations.
+//
+// It is best not to rely on these values until the API solidifies.
+// Instead, use `isKnownDotAlgorithm`.
+enum class KnownDotAlgorithm {
+  ANY_F8_ANY_F8_F32 = 1,
+  ANY_F8_ANY_F8_F32_FAST_ACCUM = 2,
+  F16_F16_F16 = 3,
+  F16_F16_F32 = 4,
+  BF16_BF16_BF16 = 5,
+  BF16_BF16_F32 = 6,
+  BF16_BF16_F32_X3 = 7,
+  BF16_BF16_F32_X6 = 8,
+  TF32_TF32_F32 = 9,
+  TF32_TF32_F32_X3 = 10,
+  F32_F32_F32 = 11,
+  F64_F64_F64 = 12,
+};
+
+FailureOr<KnownDotAlgorithm> getKnownDotAlgorithm(
+    Type lhsPrecisionType, Type rhsPrecisionType, Type accumulationType,
+    int64_t lhsComponentCount, int64_t rhsComponentCount,
+    int64_t numPrimitiveOperations, bool allowImpreciseAccumulation);
+}  // namespace detail
+
+// Check if the combination of dot algorithm parameters is known.
+bool isKnownDotAlgorithm(Type lhsPrecisionType, Type rhsPrecisionType,
+                         Type accumulationType, int64_t lhsComponentCount,
+                         int64_t rhsComponentCount,
+                         int64_t numPrimitiveOperations,
+                         bool allowImpreciseAccumulation);
+
 namespace bytecode {
 // Helper methods for bytecode
 // Enum reader and writer. Many attrs have a single enum type to serialize.
diff --git a/stablehlo/dialect/TypeInference.cpp b/stablehlo/dialect/TypeInference.cpp
index 1ef21e2ae99..15671d73e2c 100644
--- a/stablehlo/dialect/TypeInference.cpp
+++ b/stablehlo/dialect/TypeInference.cpp
@@ -27,7 +27,9 @@ limitations under the License.
 #include <algorithm>
 #include <array>
 #include <cstdint>
+#include <functional>
 #include <iterator>
+#include <numeric>
 #include <optional>
 #include <set>
 #include <string>
@@ -4057,24 +4059,6 @@ LogicalResult verifyDotAlgorithmAttr(
     Type lhsPrecisionType, Type rhsPrecisionType, Type accumulationType,
     int64_t lhsComponentCount, int64_t rhsComponentCount,
     int64_t numPrimitiveOperations, bool allowImpreciseAccumulation) {
-  auto isValidType = [](Type t) {
-    // Only support float types for now
-    // This can be extended as needed, as the RFC was for general support, but
-    // only FP hardware support exists in the ecosystem today.
-    return llvm::isa<FloatType>(t);
-  };
-
-  // dot_general_i8
-  if (!isValidType(lhsPrecisionType))
-    return emitError() << "lhs precision type must be float";
-
-  // dot_general_i9
-  if (!isValidType(rhsPrecisionType))
-    return emitError() << "rhs precision type must be float";
-
-  // dot_general_i10
-  if (!isValidType(accumulationType))
-    return emitError() << "accumulation type must be float";
   // dot_general_c22
   if (lhsComponentCount < 1)
     return emitError() << "lhs component count must be positive";
@@ -4084,6 +4068,21 @@ LogicalResult verifyDotAlgorithmAttr(
   // dot_general_c24
   if (numPrimitiveOperations < 1)
     return emitError() << "num primitive operations must be positive";
+
+  // Best effort algorithm verification: accept algorithm combinations known
+  // to be supported on some hardware, not necessarily the target hardware.
+  // dot_general_i8, dot_general_i9, dot_general_i10
+  if (!isKnownDotAlgorithm(lhsPrecisionType, rhsPrecisionType,
+                           accumulationType, lhsComponentCount,
+                           rhsComponentCount, numPrimitiveOperations,
+                           allowImpreciseAccumulation))
+    return emitError()
+           << "dot algorithm not known to be supported on any hardware: "
+           << "{lhs:" << lhsPrecisionType << ", rhs:" << rhsPrecisionType
+           << ", accum:" << accumulationType
+           << ", lhs_components:" << lhsComponentCount
+           << ", rhs_components:" << rhsComponentCount
+           << ", primitive_ops:" << numPrimitiveOperations
+           << ", imprecise:" << allowImpreciseAccumulation << "}";
   return success();
 }
 
diff --git a/stablehlo/tests/interpret/dot_general.mlir b/stablehlo/tests/interpret/dot_general.mlir
index 3c12c6e7d76..74b02afceb1 100644
--- a/stablehlo/tests/interpret/dot_general.mlir
+++ b/stablehlo/tests/interpret/dot_general.mlir
@@ -29,7 +29,7 @@ func.func @dot_general_op_test_algorithm() {
     algorithm = <
       lhs_precision_type = tf32,
       rhs_precision_type = tf32,
-      accumulation_type = tf32,
+      accumulation_type = f32,
       lhs_component_count = 1,
       rhs_component_count = 1,
       num_primitive_operations = 1,
diff --git a/stablehlo/tests/ops_dot_general_algorithms.mlir b/stablehlo/tests/ops_dot_general_algorithms.mlir
new file mode 100644
index 00000000000..b76c751ef3f
--- /dev/null
+++ b/stablehlo/tests/ops_dot_general_algorithms.mlir
@@ -0,0 +1,262 @@
+// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+// CHECK-LABEL: func @dot_algorithm_f8_f8_f32
+func.func @dot_algorithm_f8_f8_f32(%arg0: tensor<2x2x2xf32>, %arg1: tensor<2x2x2xf32>) -> tensor<2x2x2xf32> {
+  %0 = "stablehlo.dot_general"(%arg0, %arg1) <{
+    dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>,
+    precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
+    algorithm = #stablehlo.dot_algorithm<
+      lhs_precision_type = f8E4M3FNUZ,
+      rhs_precision_type = f8E4M3FNUZ,
+      accumulation_type = f32,
+      lhs_component_count = 1,
+      rhs_component_count = 1,
+      num_primitive_operations = 1,
+      allow_imprecise_accumulation = false
+    >
+  }> : (tensor<2x2x2xf32>, tensor<2x2x2xf32>) -> tensor<2x2x2xf32>
+  return %0 : tensor<2x2x2xf32>
+}
+
+// CHECK-LABEL: func @dot_algorithm_f8_f8_f32_fast_accum
+func.func @dot_algorithm_f8_f8_f32_fast_accum(%arg0: tensor<2x2x2xf32>, %arg1: tensor<2x2x2xf32>) -> tensor<2x2x2xf32> {
+  %0 = "stablehlo.dot_general"(%arg0, %arg1) <{
+    dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>,
+    precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
+    algorithm = #stablehlo.dot_algorithm<
+      lhs_precision_type = f8E4M3FNUZ,
+      rhs_precision_type = f8E4M3FNUZ,
+      accumulation_type = f32,
+      lhs_component_count = 1,
+      rhs_component_count = 1,
+      num_primitive_operations = 1,
+      allow_imprecise_accumulation = true
+    >
+  }> : (tensor<2x2x2xf32>, tensor<2x2x2xf32>) -> tensor<2x2x2xf32>
+  return %0 : tensor<2x2x2xf32>
+}
+
+// CHECK-LABEL: func @dot_algorithm_f16_f16_f16
+func.func @dot_algorithm_f16_f16_f16(%arg0: tensor<2x2x2xf32>, %arg1: tensor<2x2x2xf32>) -> tensor<2x2x2xf32> {
+  %0 = "stablehlo.dot_general"(%arg0, %arg1) <{
+    dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>,
+    precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
+    algorithm = #stablehlo.dot_algorithm<
+      lhs_precision_type = f16,
+      rhs_precision_type = f16,
+      accumulation_type = f16,
+      lhs_component_count = 1,
+      rhs_component_count = 1,
+      num_primitive_operations = 1,
+      allow_imprecise_accumulation = false
+    >
+  }> : (tensor<2x2x2xf32>, tensor<2x2x2xf32>) -> tensor<2x2x2xf32>
+  return %0 : tensor<2x2x2xf32>
+}
+
+// CHECK-LABEL: func @dot_algorithm_f16_f16_f32
+func.func @dot_algorithm_f16_f16_f32(%arg0: tensor<2x2x2xf32>, %arg1: tensor<2x2x2xf32>) -> tensor<2x2x2xf32> {
+  %0 = "stablehlo.dot_general"(%arg0, %arg1) <{
+    dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>,
+    precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
+    algorithm = #stablehlo.dot_algorithm<
+      lhs_precision_type = f16,
+      rhs_precision_type = f16,
+      accumulation_type = f32,
+      lhs_component_count = 1,
+      rhs_component_count = 1,
+      num_primitive_operations = 1,
+      allow_imprecise_accumulation = false
+    >
+  }> : (tensor<2x2x2xf32>, tensor<2x2x2xf32>) -> tensor<2x2x2xf32>
+  return %0 : tensor<2x2x2xf32>
+}
+
+// CHECK-LABEL: func @dot_algorithm_bf16_bf16_bf16
+func.func @dot_algorithm_bf16_bf16_bf16(%arg0: tensor<2x2x2xbf16>, %arg1: tensor<2x2x2xbf16>) -> tensor<2x2x2xbf16> {
+  %0 = "stablehlo.dot_general"(%arg0, %arg1) <{
+    dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>,
+    precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
+    algorithm = #stablehlo.dot_algorithm<
+      lhs_precision_type = bf16,
+      rhs_precision_type = bf16,
+      accumulation_type = bf16,
+      lhs_component_count = 1,
+      rhs_component_count = 1,
+      num_primitive_operations = 1,
+      allow_imprecise_accumulation = false
+    >
+  }> : (tensor<2x2x2xbf16>, tensor<2x2x2xbf16>) -> tensor<2x2x2xbf16>
+  return %0 : tensor<2x2x2xbf16>
+}
+
+// CHECK-LABEL: func @dot_algorithm_bf16_bf16_f32
+func.func @dot_algorithm_bf16_bf16_f32(%arg0: tensor<2x2x2xbf16>, %arg1: tensor<2x2x2xbf16>) -> tensor<2x2x2xbf16> {
+  %0 = "stablehlo.dot_general"(%arg0, %arg1) <{
+    dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>,
+    precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
+    algorithm = #stablehlo.dot_algorithm<
+      lhs_precision_type = bf16,
+      rhs_precision_type = bf16,
+      accumulation_type = f32,
+      lhs_component_count = 1,
+      rhs_component_count = 1,
+      num_primitive_operations = 1,
+      allow_imprecise_accumulation = false
+    >
+  }> : (tensor<2x2x2xbf16>, tensor<2x2x2xbf16>) -> tensor<2x2x2xbf16>
+  return %0 : tensor<2x2x2xbf16>
+}
+
+// CHECK-LABEL: func @dot_algorithm_bf16_bf16_f32_x3
+func.func @dot_algorithm_bf16_bf16_f32_x3(%arg0: tensor<2x2x2xbf16>, %arg1: tensor<2x2x2xbf16>) -> tensor<2x2x2xbf16> {
+  %0 = "stablehlo.dot_general"(%arg0, %arg1) <{
+    dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>,
+    precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
+    algorithm = #stablehlo.dot_algorithm<
+      lhs_precision_type = bf16,
+      rhs_precision_type = bf16,
+      accumulation_type = f32,
+      lhs_component_count = 1,
+      rhs_component_count = 1,
+      num_primitive_operations = 3,
+      allow_imprecise_accumulation = false
+    >
+  }> : (tensor<2x2x2xbf16>, tensor<2x2x2xbf16>) -> tensor<2x2x2xbf16>
+  return %0 : tensor<2x2x2xbf16>
+}
+
+// CHECK-LABEL: func @dot_algorithm_bf16_bf16_f32_x6
+func.func @dot_algorithm_bf16_bf16_f32_x6(%arg0: tensor<2x2x2xbf16>, %arg1: tensor<2x2x2xbf16>) -> tensor<2x2x2xbf16> {
+  %0 = "stablehlo.dot_general"(%arg0, %arg1) <{
+    dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>,
+    precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
+    algorithm = #stablehlo.dot_algorithm<
+      lhs_precision_type = bf16,
+      rhs_precision_type = bf16,
+      accumulation_type = f32,
+      lhs_component_count = 1,
+      rhs_component_count = 1,
+      num_primitive_operations = 6,
+      allow_imprecise_accumulation = false
+    >
+  }> : (tensor<2x2x2xbf16>, tensor<2x2x2xbf16>) -> tensor<2x2x2xbf16>
+  return %0 : tensor<2x2x2xbf16>
+}
+
+// CHECK-LABEL: func @dot_algorithm_tf32_tf32_f32
+func.func @dot_algorithm_tf32_tf32_f32(%arg0: tensor<2x2x2xf32>, %arg1: tensor<2x2x2xf32>) -> tensor<2x2x2xf32> {
+  %0 = "stablehlo.dot_general"(%arg0, %arg1) <{
+    dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>,
+    precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
+    algorithm = #stablehlo.dot_algorithm<
+      lhs_precision_type = tf32,
+      rhs_precision_type = tf32,
+      accumulation_type = f32,
+      lhs_component_count = 1,
+      rhs_component_count = 1,
+      num_primitive_operations = 1,
+      allow_imprecise_accumulation = false
+    >
+  }> : (tensor<2x2x2xf32>, tensor<2x2x2xf32>) -> tensor<2x2x2xf32>
+  return %0 : tensor<2x2x2xf32>
+}
+
+// CHECK-LABEL: func @dot_algorithm_tf32_tf32_f32_x3
+func.func @dot_algorithm_tf32_tf32_f32_x3(%arg0: tensor<2x2x2xf32>, %arg1: tensor<2x2x2xf32>) -> tensor<2x2x2xf32> {
+  %0 = "stablehlo.dot_general"(%arg0, %arg1) <{
+    dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>,
+    precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
+    algorithm = #stablehlo.dot_algorithm<
+      lhs_precision_type = tf32,
+      rhs_precision_type = tf32,
+      accumulation_type = f32,
+      lhs_component_count = 1,
+      rhs_component_count = 1,
+      num_primitive_operations = 3,
+      allow_imprecise_accumulation = false
+    >
+  }> : (tensor<2x2x2xf32>, tensor<2x2x2xf32>) -> tensor<2x2x2xf32>
+  return %0 : tensor<2x2x2xf32>
+}
+
+// CHECK-LABEL: func @dot_algorithm_f32_f32_f32
+func.func @dot_algorithm_f32_f32_f32(%arg0: tensor<2x2x2xf32>, %arg1: tensor<2x2x2xf32>) -> tensor<2x2x2xf32> {
+  %0 = "stablehlo.dot_general"(%arg0, %arg1) <{
+    dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>,
+    precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
+    algorithm = #stablehlo.dot_algorithm<
+      lhs_precision_type = f32,
+      rhs_precision_type = f32,
+      accumulation_type = f32,
+      lhs_component_count = 1,
+      rhs_component_count = 1,
+      num_primitive_operations = 1,
+      allow_imprecise_accumulation = false
+    >
+  }> : (tensor<2x2x2xf32>, tensor<2x2x2xf32>) -> tensor<2x2x2xf32>
+  return %0 : tensor<2x2x2xf32>
+}
+
+// CHECK-LABEL: func @dot_algorithm_f64_f64_f64
+func.func @dot_algorithm_f64_f64_f64(%arg0: tensor<2x2x2xf64>, %arg1: tensor<2x2x2xf64>) -> tensor<2x2x2xf64> {
+  %0 = "stablehlo.dot_general"(%arg0, %arg1) <{
+    dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>,
+    precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
+    algorithm = #stablehlo.dot_algorithm<
+      lhs_precision_type = f64,
+      rhs_precision_type = f64,
+      accumulation_type = f64,
+      lhs_component_count = 1,
+      rhs_component_count = 1,
+      num_primitive_operations = 1,
+      allow_imprecise_accumulation = false
+    >
+  }> : (tensor<2x2x2xf64>, tensor<2x2x2xf64>) -> tensor<2x2x2xf64>
+  return %0 : tensor<2x2x2xf64>
+}
+
+// -----
+
+func.func @dot_algorithm_f32_f32_f32_l3(%arg0: tensor<2x2x2xf32>, %arg1: tensor<2x2x2xf32>) -> tensor<2x2x2xf32> {
+  // expected-error@+4 {{dot algorithm not known to be supported on any hardware: {lhs:'f32', rhs:'f32', accum:'f32', lhs_components:3, rhs_components:1, primitive_ops:1, imprecise:0}}}
+  %0 = "stablehlo.dot_general"(%arg0, %arg1) <{
+    dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>,
+    precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
+    algorithm = #stablehlo.dot_algorithm<
+      lhs_precision_type = f32,
+      rhs_precision_type = f32,
+      accumulation_type = f32,
+      lhs_component_count = 3,
+      rhs_component_count = 1,
+      num_primitive_operations = 1,
+      allow_imprecise_accumulation = false
+    >
+  }> : (tensor<2x2x2xf32>, tensor<2x2x2xf32>) -> tensor<2x2x2xf32>
+  return %0 : tensor<2x2x2xf32>
+}
+
+// -----
+
+func.func @dot_algorithm_f32_f32_f32_r3(%arg0: tensor<2x2x2xf32>, %arg1: tensor<2x2x2xf32>) -> tensor<2x2x2xf32> {
+  // expected-error@+4 {{dot algorithm not known to be supported on any hardware: {lhs:'f32', rhs:'f32', accum:'f32', lhs_components:1, rhs_components:3, primitive_ops:1, imprecise:0}}}
+  %0 = "stablehlo.dot_general"(%arg0, %arg1) <{
+    dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>,
+    precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
+    algorithm = #stablehlo.dot_algorithm<
+      lhs_precision_type = f32,
+      rhs_precision_type = f32,
+      accumulation_type = f32,
+      lhs_component_count = 1,
+      rhs_component_count = 3,
+      num_primitive_operations = 1,
+      allow_imprecise_accumulation = false
+    >
+  }> : (tensor<2x2x2xf32>, tensor<2x2x2xf32>) -> tensor<2x2x2xf32>
+  return %0 : tensor<2x2x2xf32>
+}
+
+// -----
+
+func.func @dot_algorithm_f32_f32_f32_imprecise(%arg0: tensor<2x2x2xf32>, %arg1: tensor<2x2x2xf32>) -> tensor<2x2x2xf32> {
+  // expected-error@+4 {{dot algorithm not known to be supported on any hardware: {lhs:'f32', rhs:'f32', accum:'f32', lhs_components:1, rhs_components:1, primitive_ops:1, imprecise:1}}}
+  %0 = "stablehlo.dot_general"(%arg0, %arg1) <{
+    dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>,
+    precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
+    algorithm = #stablehlo.dot_algorithm<
+      lhs_precision_type = f32,
+      rhs_precision_type = f32,
+      accumulation_type = f32,
+      lhs_component_count = 1,
+      rhs_component_count = 1,
+      num_primitive_operations = 1,
+      allow_imprecise_accumulation = true
+    >
+  }> : (tensor<2x2x2xf32>, tensor<2x2x2xf32>) -> tensor<2x2x2xf32>
+  return %0 : tensor<2x2x2xf32>
+}
diff --git a/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/tests/ops_stablehlo.mlir
index 1cf97858c94..ab49e174622 100644
--- a/stablehlo/tests/ops_stablehlo.mlir
+++ b/stablehlo/tests/ops_stablehlo.mlir
@@ -3334,7 +3334,7 @@ func.func @dot_general_c24(%arg0: tensor<2x2x2xi64>, %arg1: tensor<2x2x2xi64>) -
 // -----
 
 func.func @dot_general_i8(%arg0: tensor<2x2x2xi64>, %arg1: tensor<2x2x2xi64>) -> tensor<2x2x2xi64> {
-  // expected-error @+3 {{lhs precision type must be float}}
+  // expected-error @+3 {{dot algorithm not known to be supported on any hardware}}
   %0 = "stablehlo.dot_general"(%arg0, %arg1) <{
     dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>,
     algorithm = #stablehlo.dot_algorithm<
@@ -3344,7 +3344,7 @@ func.func @dot_general_i8(%arg0: tensor<2x2x2xi64>, %arg1: tensor<2x2x2xi64>) -
 // -----
 
 func.func @dot_general_i9(%arg0: tensor<2x2x2xi64>, %arg1: tensor<2x2x2xi64>) -> tensor<2x2x2xi64> {
-  // expected-error @+3 {{rhs precision type must be float}}
+  // expected-error @+3 {{dot algorithm not known to be supported on any hardware}}
   %0 = "stablehlo.dot_general"(%arg0, %arg1) <{
     dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>,
     algorithm = #stablehlo.dot_algorithm<
@@ -3354,7 +3354,7 @@ func.func @dot_general_i9(%arg0: tensor<2x2x2xi64>, %arg1: tensor<2x2x2xi64>) -
 // -----
 
 func.func @dot_general_i10(%arg0: tensor<2x2x2xi64>, %arg1: tensor<2x2x2xi64>) -> tensor<2x2x2xi64> {
-  // expected-error @+3 {{accumulation type must be float}}
+  // expected-error @+3 {{dot algorithm not known to be supported on any hardware}}
   %0 = "stablehlo.dot_general"(%arg0, %arg1) <{
     dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>,
     algorithm = #stablehlo.dot_algorithm<
diff --git a/stablehlo/tests/print_stablehlo.mlir b/stablehlo/tests/print_stablehlo.mlir
index c20bc1a1584..bb4b711c851 100644
--- a/stablehlo/tests/print_stablehlo.mlir
+++ b/stablehlo/tests/print_stablehlo.mlir
@@ -315,8 +315,8 @@ func.func @dot_general(%arg0: tensor<2x2x2xi8>, %arg1: tensor<2x2x3xi8>, %arg2:
   // CHECK-NEXT: {{%.*}} = stablehlo.dot_general %arg0, %arg1, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<2x2x2xi8>, tensor<2x2x3xi8>) -> tensor<2x2x3xi32>
   // CHECK-NEXT: {{%.*}} = stablehlo.dot_general %arg2, %arg3, contracting_dims = [1] x [0] : (tensor<2x2xi8>, tensor<2x3xi8>) -> tensor<2x3xi32>
   // CHECK-NEXT: {{%.*}} = stablehlo.dot_general %arg2, %arg3, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<2x2xi8>, tensor<2x3xi8>) -> tensor<2x3xi32>
-  // CHECK-NEXT: {{%.*}} = stablehlo.dot_general %arg2, %arg3, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT], algorithm = <lhs_precision_type = tf32, rhs_precision_type = tf32, accumulation_type = tf32, lhs_component_count = 1, rhs_component_count = 1, num_primitive_operations = 1, allow_imprecise_accumulation = false> : (tensor<2x2xi8>, tensor<2x3xi8>) -> tensor<2x3xi32>
-  // CHECK-NEXT: {{%.*}} = stablehlo.dot_general %arg2, %arg3, contracting_dims = [1] x [0], algorithm = <lhs_precision_type = tf32, rhs_precision_type = tf32, accumulation_type = tf32, lhs_component_count = 1, rhs_component_count = 1, num_primitive_operations = 1, allow_imprecise_accumulation = false> : (tensor<2x2xi8>, tensor<2x3xi8>) -> tensor<2x3xi32>
+  // CHECK-NEXT: {{%.*}} = stablehlo.dot_general %arg2, %arg3, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT], algorithm = <lhs_precision_type = tf32, rhs_precision_type = tf32, accumulation_type = f32, lhs_component_count = 1, rhs_component_count = 1, num_primitive_operations = 1, allow_imprecise_accumulation = false> : (tensor<2x2xi8>, tensor<2x3xi8>) -> tensor<2x3xi32>
+  // CHECK-NEXT: {{%.*}} = stablehlo.dot_general %arg2, %arg3, contracting_dims = [1] x [0], algorithm = <lhs_precision_type = tf32, rhs_precision_type = tf32, accumulation_type = f32, lhs_component_count = 1, rhs_component_count = 1, num_primitive_operations = 1, allow_imprecise_accumulation = false> : (tensor<2x2xi8>, tensor<2x3xi8>) -> tensor<2x3xi32>
   %0 = "stablehlo.dot_general"(%arg0, %arg1) {
     dot_dimension_numbers = #stablehlo.dot<
       lhs_batching_dimensions = [0],
@@ -362,7 +362,7 @@ func.func @dot_general(%arg0: tensor<2x2x2xi8>, %arg1: tensor<2x2x3xi8>, %arg2:
     algorithm = #stablehlo.dot_algorithm<
       lhs_precision_type = tf32,
       rhs_precision_type = tf32,
-      accumulation_type = tf32,
+      accumulation_type = f32,
       lhs_component_count = 1,
       rhs_component_count = 1,
       num_primitive_operations = 1,
@@ -379,7 +379,7 @@ func.func @dot_general(%arg0: tensor<2x2x2xi8>, %arg1: tensor<2x2x3xi8>, %arg2:
     algorithm = #stablehlo.dot_algorithm<
       lhs_precision_type = tf32,
      rhs_precision_type = tf32,
-      accumulation_type = tf32,
+      accumulation_type = f32,
       lhs_component_count = 1,
       rhs_component_count = 1,
       num_primitive_operations = 1,