Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add best effort validation for known dot algorithms #2511

Merged
merged 1 commit into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions stablehlo/dialect/Base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,15 @@ limitations under the License.
#include <cstdint>
#include <functional>
#include <optional>
#include <tuple>
#include <utility>

#include "llvm/ADT/APInt.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
Expand All @@ -42,10 +45,14 @@ limitations under the License.
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/TypeID.h"

// Include order matters
#include "stablehlo/dialect/BaseAttrInterfaces.cpp.inc"

#define DEBUG_TYPE "stablehlo-base"

namespace mlir {
namespace hlo {

Expand Down Expand Up @@ -624,6 +631,84 @@ bool isSplatArray(ArrayRef<int64_t> arr, int64_t val) {
[val](int64_t x) { return x == val; });
}

namespace detail {
// Returns true when `lhsPrecisionType`, `rhsPrecisionType` and
// `accumulationType` are instances of `LHS`, `RHS` and `Accum` respectively
// and `numPrimitiveOperations` equals the compile-time constant `N`.
// The checks run in the listed order and stop at the first mismatch.
template <typename LHS, typename RHS, typename Accum, int64_t N>
bool match(Type lhsPrecisionType, Type rhsPrecisionType, Type accumulationType,
           int64_t numPrimitiveOperations) {
  if (!isa<LHS>(lhsPrecisionType)) return false;
  if (!isa<RHS>(rhsPrecisionType)) return false;
  if (!isa<Accum>(accumulationType)) return false;
  return numPrimitiveOperations == N;
}

// Maps a dot algorithm configuration onto a `KnownDotAlgorithm`.
//
// Returns `failure()` when:
//   * either component count is not exactly 1 (only single-component
//     algorithms are supported for now),
//   * `allowImpreciseAccumulation` is set for anything other than the
//     f8 x f8 -> f32 single-primitive-operation configuration, or
//   * the (lhs, rhs, accumulation, numPrimitiveOperations) combination is not
//     listed in the table below.
FailureOr<KnownDotAlgorithm> getKnownDotAlgorithm(
    Type lhsPrecisionType, Type rhsPrecisionType, Type accumulationType,
    int64_t lhsComponentCount, int64_t rhsComponentCount,
    int64_t numPrimitiveOperations, bool allowImpreciseAccumulation) {
  // Only support single component for now.
  if (lhsComponentCount != 1 || rhsComponentCount != 1) return failure();

  // Any mix of the f8 flavours is accepted on either operand; this is the
  // only family that has a fast (imprecise) accumulation variant.
  auto isAnyF8 = [](Type t) {
    return llvm::isa<Float8E4M3FNType, Float8E5M2Type, Float8E4M3FNUZType,
                     Float8E4M3B11FNUZType, Float8E5M2FNUZType>(t);
  };
  if (isAnyF8(lhsPrecisionType) && isAnyF8(rhsPrecisionType) &&
      accumulationType.isF32() && numPrimitiveOperations == 1) {
    if (allowImpreciseAccumulation)
      return KnownDotAlgorithm::ANY_F8_ANY_F8_F32_FAST_ACCUM;
    return KnownDotAlgorithm::ANY_F8_ANY_F8_F32;
  }
  if (allowImpreciseAccumulation) return failure();

  // One row per known non-f8 algorithm. Types are identified by their
  // registered names. The table is a small fixed-size array scanned linearly,
  // which avoids building a heap-allocated associative container on every
  // call and needs no ordering on the key.
  struct KnownEntry {
    StringRef lhs;
    StringRef rhs;
    StringRef accum;
    int64_t numPrimitiveOperations;
    KnownDotAlgorithm algorithm;
  };

  const StringRef bf16 = BFloat16Type::name;
  const StringRef f16 = Float16Type::name;
  const StringRef f32 = Float32Type::name;
  const StringRef f64 = Float64Type::name;
  const StringRef tf32 = FloatTF32Type::name;
  const KnownEntry knownDotAlgorithms[] = {
      {f16, f16, f16, 1, KnownDotAlgorithm::F16_F16_F16},
      {f16, f16, f32, 1, KnownDotAlgorithm::F16_F16_F32},
      {bf16, bf16, bf16, 1, KnownDotAlgorithm::BF16_BF16_BF16},
      {bf16, bf16, f32, 1, KnownDotAlgorithm::BF16_BF16_F32},
      {bf16, bf16, f32, 3, KnownDotAlgorithm::BF16_BF16_F32_X3},
      {bf16, bf16, f32, 6, KnownDotAlgorithm::BF16_BF16_F32_X6},
      {tf32, tf32, f32, 1, KnownDotAlgorithm::TF32_TF32_F32},
      {tf32, tf32, f32, 3, KnownDotAlgorithm::TF32_TF32_F32_X3},
      {f32, f32, f32, 1, KnownDotAlgorithm::F32_F32_F32},
      {f64, f64, f64, 1, KnownDotAlgorithm::F64_F64_F64},
  };

  const StringRef lhsName = lhsPrecisionType.getAbstractType().getName();
  const StringRef rhsName = rhsPrecisionType.getAbstractType().getName();
  const StringRef accumName = accumulationType.getAbstractType().getName();

  for (const KnownEntry &entry : knownDotAlgorithms) {
    if (entry.lhs != lhsName || entry.rhs != rhsName ||
        entry.accum != accumName ||
        entry.numPrimitiveOperations != numPrimitiveOperations)
      continue;
    LLVM_DEBUG(llvm::dbgs()
               << "Found known dot algorithm: "
               << static_cast<int64_t>(entry.algorithm) << " " << lhsName
               << ", " << rhsName << ", " << accumName << ", "
               << numPrimitiveOperations << "\n");
    return entry.algorithm;
  }
  return failure();
}
} // namespace detail

// Reports whether the given dot algorithm configuration resolves to one of the
// known algorithms, i.e. whether `detail::getKnownDotAlgorithm` succeeds.
bool isKnownDotAlgorithm(Type lhsPrecisionType, Type rhsPrecisionType,
                         Type accumulationType, int64_t lhsComponentCount,
                         int64_t rhsComponentCount,
                         int64_t numPrimitiveOperations,
                         bool allowImpreciseAccumulation) {
  const auto knownAlgorithm = detail::getKnownDotAlgorithm(
      lhsPrecisionType, rhsPrecisionType, accumulationType, lhsComponentCount,
      rhsComponentCount, numPrimitiveOperations, allowImpreciseAccumulation);
  return succeeded(knownAlgorithm);
}

mlir::Speculation::Speculatability getShapedSpeculatability(
Operation* op, int64_t shapeCount) {
// If all inputs are static and the shape-related operands are constant
Expand Down
37 changes: 37 additions & 0 deletions stablehlo/dialect/Base.h
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,43 @@ class HloDialectInterface : public DialectInterface::Base<HloDialectInterface> {
virtual Attribute createTypeExtensions(ArrayRef<int64_t> bounds) const = 0;
};

namespace detail {

// An enum which tracks known supported dot algorithm pairs.
// Note this implementation is a detail for now and the APIs are likely to
// change once HLO broadens support for LHS/RHS components and num primitive
// operations.
//
// It is best to not rely on these values until the API solidifies.
// Instead use `isKnownDotAlgorithm`.
//
// Enumerator names follow `<LHS>_<RHS>_<ACCUM>[_X<n>][_FAST_ACCUM]`:
//   * LHS / RHS / ACCUM are the lhs precision, rhs precision and accumulation
//     types of the algorithm.
//   * `_X<n>` marks an algorithm with `n` primitive operations; enumerators
//     without it use a single primitive operation.
//   * `_FAST_ACCUM` marks the variant that allows imprecise accumulation.
enum class KnownDotAlgorithm {
// Both operands are any of the f8 types (they need not match), f32
// accumulation, one primitive operation.
ANY_F8_ANY_F8_F32 = 1,
// Same as above, with imprecise accumulation allowed.
ANY_F8_ANY_F8_F32_FAST_ACCUM = 2,
// f16 operands.
F16_F16_F16 = 3,
F16_F16_F32 = 4,
// bf16 operands; the X3 / X6 variants use 3 / 6 primitive operations.
BF16_BF16_BF16 = 5,
BF16_BF16_F32 = 6,
BF16_BF16_F32_X3 = 7,
BF16_BF16_F32_X6 = 8,
// tf32 operands; the X3 variant uses 3 primitive operations.
TF32_TF32_F32 = 9,
TF32_TF32_F32_X3 = 10,
// f32 and f64 operands accumulate in their own type.
F32_F32_F32 = 11,
F64_F64_F64 = 12,
};

// Resolves a dot algorithm configuration to its `KnownDotAlgorithm` value.
//
// Returns failure when the configuration is not a known algorithm. This
// includes any configuration whose lhs or rhs component count is not 1, and
// any configuration that sets `allowImpreciseAccumulation` other than the
// f8 x f8 -> f32 single-primitive-operation one.
FailureOr<KnownDotAlgorithm> getKnownDotAlgorithm(
Type lhsPrecisionType, Type rhsPrecisionType, Type accumulationType,
int64_t lhsComponentCount, int64_t rhsComponentCount,
int64_t numPrimitiveOperations, bool allowImpreciseAccumulation);
} // namespace detail

// Returns true if the given combination of dot algorithm fields (precision
// types, accumulation type, component counts, number of primitive operations
// and the imprecise-accumulation flag) corresponds to a known dot algorithm.
// Prefer this over inspecting `detail::KnownDotAlgorithm` values directly.
bool isKnownDotAlgorithm(Type lhsPrecisionType, Type rhsPrecisionType,
Type accumulationType, int64_t lhsComponentCount,
int64_t rhsComponentCount,
int64_t numPrimitiveOperations,
bool allowImpreciseAccumulation);

namespace bytecode {
// Helper methods for bytecode
// Enum reader and writer. Many attrs have a single enum type to serialize.
Expand Down
35 changes: 17 additions & 18 deletions stablehlo/dialect/TypeInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ limitations under the License.
#include <iterator>
#include <numeric>
#include <optional>
#include <set>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
Expand Down Expand Up @@ -4057,24 +4059,6 @@ LogicalResult verifyDotAlgorithmAttr(
Type lhsPrecisionType, Type rhsPrecisionType, Type accumulationType,
int64_t lhsComponentCount, int64_t rhsComponentCount,
int64_t numPrimitiveOperations, bool allowImpreciseAccumulation) {
auto isValidType = [](Type t) {
// Only support float types for now
// This can be extended as needed, as the RFC was for general support, but
// only FP hardware support exists in the ecosystem today.
return llvm::isa<FloatTF32Type, Float8E4M3FNType, Float8E5M2Type,
Float8E4M3FNUZType, Float8E4M3B11FNUZType,
Float8E5M2FNUZType, BFloat16Type, Float16Type, Float32Type,
Float64Type>(t);
};
// dot_general_i8
if (!isValidType(lhsPrecisionType))
return emitError() << "lhs precision type must be float";
// dot_general_i9
if (!isValidType(rhsPrecisionType))
return emitError() << "rhs precision type must be float";
// dot_general_i10
if (!isValidType(accumulationType))
return emitError() << "accumulation type must be float";
// dot_general_c22
if (lhsComponentCount < 1)
return emitError() << "lhs component count must be positive";
Expand All @@ -4084,6 +4068,21 @@ LogicalResult verifyDotAlgorithmAttr(
// dot_general_c24
if (numPrimitiveOperations < 1)
return emitError() << "num primitive operations must be positive";

// Best effort algorithm verification, support algorithm combinations
// known to be supported on some hardware, not necessarily the target hardware
// dot_general_i8, dot_general_i9, dot_general_i10
if (!isKnownDotAlgorithm(lhsPrecisionType, rhsPrecisionType, accumulationType,
lhsComponentCount, rhsComponentCount,
numPrimitiveOperations, allowImpreciseAccumulation))
return emitError()
<< "dot algorithm not known to be supported on any hardware: "
<< "{lhs:" << lhsPrecisionType << ", rhs:" << rhsPrecisionType
<< ", accum:" << accumulationType
<< ", lhs_components:" << lhsComponentCount
<< ", rhs_components:" << rhsComponentCount
<< ", primitive_ops:" << numPrimitiveOperations
<< ", imprecise:" << allowImpreciseAccumulation << "}";
return success();
}

Expand Down
2 changes: 1 addition & 1 deletion stablehlo/tests/interpret/dot_general.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func.func @dot_general_op_test_algorithm() {
algorithm = <
lhs_precision_type = tf32,
GleasonK marked this conversation as resolved.
Show resolved Hide resolved
rhs_precision_type = tf32,
accumulation_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
Expand Down
Loading
Loading