diff --git a/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td b/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td index 712826d02f91..6bcca9ec0d5b 100644 --- a/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td +++ b/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td @@ -76,4 +76,24 @@ def TTC_PtrToMemRefOp : TTC_Op<"ptr_to_memref", [NoMemoryEffect]> { let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; } +def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">; + +def TTC_PrintOp : TTC_Op<"print", [MemoryEffects<[MemWrite]>]> { + let summary = "Print at most a single scalar or vector (converted from tensor) on each line"; + + let description = [{ + For converting tensor types to vector types. + It only takes a single scalar or vector (tensor) element. + }]; + + let arguments = (ins StrAttr:$prefix, BoolAttr:$hex, + Variadic>:$val); + + let assemblyFormat = [{ + $prefix attr-dict (`:` $val^ `:` type($val))? + }]; + + let hasVerifier = 1; +} + #endif diff --git a/include/triton/Dialect/TritonCPU/IR/TritonCPUTypes.td b/include/triton/Dialect/TritonCPU/IR/TritonCPUTypes.td index ea31f877dab3..4bd64213db4b 100644 --- a/include/triton/Dialect/TritonCPU/IR/TritonCPUTypes.td +++ b/include/triton/Dialect/TritonCPU/IR/TritonCPUTypes.td @@ -1,6 +1,7 @@ #ifndef TRITONCPU_TYPES #define TRITONCPU_TYPES +include "triton/Dialect/Triton/IR/TritonTypes.td" include "triton/Dialect/TritonCPU/IR/TritonCPUDialect.td" include "mlir/IR/AttrTypeBase.td" @@ -23,4 +24,6 @@ def TTC_TokenType : TTC_TypeDef<"Token", "token"> { let skipDefaultBuilders = 1; } +def TTC_Vector : VectorOf<[TT_Float, TT_Int]>; + #endif diff --git a/lib/Dialect/TritonCPU/IR/CMakeLists.txt b/lib/Dialect/TritonCPU/IR/CMakeLists.txt index 67bf1bb1b9d4..c0b6f0f7be24 100644 --- a/lib/Dialect/TritonCPU/IR/CMakeLists.txt +++ b/lib/Dialect/TritonCPU/IR/CMakeLists.txt @@ -1,5 +1,6 @@ add_triton_library(TritonCPUIR Dialect.cpp + Ops.cpp Types.cpp DEPENDS diff --git a/lib/Dialect/TritonCPU/IR/Dialect.cpp b/lib/Dialect/TritonCPU/IR/Dialect.cpp index acd31c07290f..41a4c62bda45 100644 --- a/lib/Dialect/TritonCPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonCPU/IR/Dialect.cpp @@ -67,9 +67,6 @@ void TritonCPUDialect::initialize() { >(); } -#define GET_OP_CLASSES -#include "triton/Dialect/TritonCPU/IR/Ops.cpp.inc" - // verify TritonCPU ops LogicalResult TritonCPUDialect::verifyOperationAttribute(Operation *op, NamedAttribute attr) { diff --git a/lib/Dialect/TritonCPU/IR/Ops.cpp b/lib/Dialect/TritonCPU/IR/Ops.cpp new file mode 100644 index 000000000000..d626ce3902a9 --- /dev/null +++ b/lib/Dialect/TritonCPU/IR/Ops.cpp @@ -0,0 +1,18 @@ +#include "mlir/IR/Builders.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +#define GET_OP_CLASSES +#include "triton/Dialect/TritonCPU/IR/Ops.cpp.inc" + +// enum attribute definitions +#include "triton/Dialect/TritonCPU/IR/OpsEnums.cpp.inc" + +namespace mlir::triton::cpu { + +LogicalResult PrintOp::verify() { + if (getOperands().size() > 1) + return emitOpError("expects at most one operand"); + return success(); +} + +} // namespace mlir::triton::cpu diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index d6f1a6d82a83..cf73ce52b0b2 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -106,6 +106,7 @@ def make_ttcir(mod, metadata, opt): cpu.passes.ttcpuir.add_convert_scan_op(pm) cpu.passes.ttcpuir.add_convert_cf_ops(pm) cpu.passes.ttcpuir.add_convert_atomic_ops(pm) + cpu.passes.ttcpuir.add_convert_debug_ops(pm) passes.common.add_cse(pm) passes.common.add_symbol_dce(pm) passes.common.add_canonicalizer(pm) diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.h b/third_party/cpu/include/TritonToTritonCPU/Passes.h index 303b99ce3c43..699518361e01 100644 --- a/third_party/cpu/include/TritonToTritonCPU/Passes.h +++ b/third_party/cpu/include/TritonToTritonCPU/Passes.h @@ -28,6 +28,7 @@ std::unique_ptr> createConvertHistogramOp(); std::unique_ptr> createConvertReductionOp(); std::unique_ptr> createConvertScanOp(); std::unique_ptr> createConvertAtomicOps(); +std::unique_ptr> createConvertDebugOps(); #define GEN_PASS_REGISTRATION #include "cpu/include/TritonToTritonCPU/Passes.h.inc" diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.td b/third_party/cpu/include/TritonToTritonCPU/Passes.td index dfac926a9f5b..612ce135cc65 100644 --- a/third_party/cpu/include/TritonToTritonCPU/Passes.td +++ b/third_party/cpu/include/TritonToTritonCPU/Passes.td @@ -142,4 +142,17 @@ def ConvertAtomicOps : Pass<"triton-cpu-convert-atomic-ops", "mlir::ModuleOp"> { "mlir::triton::cpu::TritonCPUDialect"]; } +def ConvertDebugOps : Pass<"triton-cpu-convert-debug-ops", "mlir::ModuleOp"> { + let summary = "Convert Triton debug operations."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createConvertDebugOps()"; + + let dependentDialects = ["mlir::vector::VectorDialect", + "mlir::scf::SCFDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + #endif diff --git a/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp index 2bad397c9b77..c60da23b765a 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp @@ -4,6 +4,7 @@ #include "cpu/include/TritonCPUToLLVM/Passes.h" #include "mlir/Dialect/GPU/IR/GPUOps.h.inc" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" @@ -31,54 +32,6 @@ class TritonLLVMConversionTarget : public ConversionTarget { } }; -// The code for the print is similar to the GPU's TargetInfo.cpp. -LLVM::LLVMFuncOp getPrintfDeclaration(ConversionPatternRewriter &rewriter) { - auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); - StringRef funcName("printf"); - Operation *funcOp = moduleOp.lookupSymbol(funcName); - if (funcOp) - return cast(*funcOp); - - auto *context = rewriter.getContext(); - - // int printf(char* format, ...) - SmallVector argsType{ptr_ty(context)}; - auto funcType = LLVM::LLVMFunctionType::get(i32_ty, argsType, true); - - ConversionPatternRewriter::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(moduleOp.getBody()); - - auto op = rewriter.create(UnknownLoc::get(context), - funcName, funcType); - return op; -} - -void emitPrintf(ConversionPatternRewriter &rewriter, Value formatStrStart, - int /*formatStrByteCount*/, ValueRange args) { - auto loc = UnknownLoc::get(rewriter.getContext()); - SmallVector formatStrAndArgs{formatStrStart}; - for (auto arg : args) { - formatStrAndArgs.push_back(arg); - } - call(getPrintfDeclaration(rewriter), formatStrAndArgs); -} - -Value llPrintf(StringRef msg, ValueRange args, - ConversionPatternRewriter &rewriter, - int *formatStrByteCount = nullptr) { - assert(!msg.empty() && "printf with empty string not supported"); - llvm::SmallString<64> msgNewline(msg); - msgNewline.push_back('\n'); - msgNewline.push_back('\0'); - Value msgValue = - LLVM::addStringToModule(UnknownLoc::get(rewriter.getContext()), rewriter, - "printfFormat_", msgNewline); - emitPrintf(rewriter, msgValue, msgNewline.size_in_bytes(), args); - if (formatStrByteCount) - *formatStrByteCount = msgNewline.size_in_bytes(); - return msgValue; -} - // TODO: This code is the same as the GPU-backend code. Consider refactoring. std::string getFormatSubstr(Value value, bool hex = false, std::optional width = std::nullopt) { @@ -123,44 +76,139 @@ std::string getFormatSubstr(Value value, bool hex = false, return ""; } -// TritonCPU's device_print prints all values in the same line unlike GPUs -// and interpreter where each value is printed in a separate line. -struct PrintOpConversion : public ConvertOpToLLVMPattern { - explicit PrintOpConversion(LLVMTypeConverter &typeConverter) - : mlir::ConvertOpToLLVMPattern(typeConverter) {} +LLVM::LLVMFuncOp getPrintFuncDecl(ConversionPatternRewriter &rewriter, + bool printf) { + auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); + StringRef funcName = printf ? "printf" : "triton_vector_print"; + Operation *funcOp = moduleOp.lookupSymbol(funcName); + if (funcOp) + return cast(*funcOp); + + auto *ctx = rewriter.getContext(); + SmallVector argsType; + if (printf) + argsType = {ptr_ty(ctx)}; + else + argsType = {i32_ty, i32_ty, i32_ty, ptr_ty(ctx), + ptr_ty(ctx), i32_ty, i32_ty, i64_ty}; + + auto funcType = + LLVM::LLVMFunctionType::get(i32_ty, argsType, /*isVarArg*/ printf); + + ConversionPatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + return rewriter.create(UnknownLoc::get(ctx), funcName, + funcType); +} + +void llPrintf(StringRef prefix, std::array pid, + std::optional arg, ConversionPatternRewriter &rewriter, + bool hex = false) { + assert(!prefix.empty() && "printf with empty string not supported"); + auto loc = UnknownLoc::get(rewriter.getContext()); + + std::string formatStr; + llvm::raw_string_ostream os(formatStr); + os << "(" << getFormatSubstr(pid[0]) << ", " << getFormatSubstr(pid[1]) + << ", " << getFormatSubstr(pid[2]) << ")" << prefix; + if (arg.has_value()) + os << getFormatSubstr(arg.value(), hex); + + llvm::SmallString<64> formatStrNewline(formatStr); + formatStrNewline.push_back('\n'); + formatStrNewline.push_back('\0'); + Value formatStrValue = + LLVM::addStringToModule(loc, rewriter, "printfFormat_", formatStrNewline); + + SmallVector allArgs{formatStrValue}; + for (auto elem : pid) + allArgs.push_back(elem); + if (arg.has_value()) + allArgs.push_back(arg.value()); + call(getPrintFuncDecl(rewriter, true), allArgs); +} + +void llVectorPrint(std::array pid, StringRef prefix, Value ptr, + bool isInteger, uint32_t bitWidth, int64_t numElem, + ConversionPatternRewriter &rewriter) { + assert(!prefix.empty()); + auto loc = UnknownLoc::get(rewriter.getContext()); + + llvm::SmallString<64> prefixStr(prefix); + prefixStr.push_back('\0'); + Value prefixValue = + LLVM::addStringToModule(loc, rewriter, "vectorPrintPrefix_", prefixStr); + + SmallVector allArgs; + for (auto elem : pid) + allArgs.push_back(elem); + allArgs.push_back(prefixValue); + allArgs.push_back(ptr); + allArgs.push_back(i32_val(isInteger)); + allArgs.push_back(i32_val(bitWidth)); + allArgs.push_back(i64_val(numElem)); + call(getPrintFuncDecl(rewriter, false), allArgs); +} + +bool usePrintf(triton::cpu::PrintOp op) { + // Simply use printf if no operand or the operand is scalar. + if (op.getNumOperands() == 0) + return true; + + // tt.print is already decomposed to triton_cpu.print per value. + assert(op.getNumOperands() == 1); + Type oprType = op.getOperands()[0].getType(); + return (oprType.isIntOrIndexOrFloat() || isa(oprType)); +} + +struct PrintOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(triton::PrintOp op, OpAdaptor adaptor, + matchAndRewrite(triton::cpu::PrintOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto getPid = [&](int axis) { return getProgramId(op->getParentOfType(), axis); }; - SmallVector values = {getPid(0), getPid(1), getPid(2)}; - - std::string formatStr; - llvm::raw_string_ostream os(formatStr); - os << "(" << getFormatSubstr(values[0]) << ", " - << getFormatSubstr(values[1]) << ", " << getFormatSubstr(values[2]) - << ")" << op.getPrefix(); - - for (size_t i = 0; i < op.getNumOperands(); i++) { - auto elems = unpackLLElements(loc, adaptor.getOperands()[i], rewriter); - if (dyn_cast(op.getOperand(i).getType())) { - llvm_unreachable("Not implemented for tensor types"); + std::array pid = {getPid(0), getPid(1), getPid(2)}; + + if (usePrintf(op)) { + if (op.getNumOperands() == 0) { + llPrintf(op.getPrefix(), pid, std::nullopt, rewriter); + } else { + Value llOpr = adaptor.getOperands()[0]; + llPrintf(op.getPrefix(), pid, llOpr, rewriter, op.getHex()); } - - // Only support scalars for now. - assert(elems.size() == 1); - if (i != 0) { - os << ", "; + } else { + Value llOpr = adaptor.getOperands()[0]; + auto vecShapedType = cast(op.getOperands()[0].getType()); + // Currently, we only support 1D vector printing. + if (vecShapedType.getRank() == 1) { + + // To get the pointer of the vector, create an alloca and store it. + auto ptrType = ptr_ty(rewriter.getContext()); + auto ptr = rewriter.create(loc, ptrType, + llOpr.getType(), i32_val(1)); + rewriter.create(loc, llOpr, ptr); + + // TODO: Consider passing an encoded element type information instead of + // booleans and separate bit width. + llVectorPrint(pid, op.getPrefix(), ptr, + vecShapedType.getElementType().isInteger(), + vecShapedType.getElementTypeBitWidth(), + vecShapedType.getNumElements(), rewriter); + } else { + // TODO: support 2D+ vector printing. + std::string msg{op.getPrefix()}; + llvm::raw_string_ostream os(msg); + os << "<>"; + llPrintf(msg, pid, std::nullopt, rewriter); } - os << getFormatSubstr(elems[0], op.getHex()); - values.push_back(elems[0]); } - llPrintf(formatStr, values, rewriter); rewriter.eraseOp(op); return success(); } diff --git a/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.cpp b/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.cpp index 144cb57b1115..821ea6f954b2 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.cpp @@ -10,11 +10,8 @@ TritonCPUToLLVMTypeConverter::TritonCPUToLLVMTypeConverter( addConversion([&](triton::PointerType type) -> std::optional { return convertTritonPointerType(type); }); - addConversion([this](RankedTensorType tensorTy) -> std::optional { - if (isa(tensorTy.getElementType())) - return VectorType::get(tensorTy.getShape(), - IntegerType::get(tensorTy.getContext(), 64)); - return std::nullopt; + addConversion([this](RankedTensorType type) -> std::optional { + return convertTritonTensorType(type); }); } @@ -41,3 +38,11 @@ Type TritonCPUToLLVMTypeConverter::convertTritonPointerType( } return LLVM::LLVMPointerType::get(ctx); } + +Type TritonCPUToLLVMTypeConverter::convertTritonTensorType( + RankedTensorType type) { + if (isa(type.getElementType())) + return VectorType::get(type.getShape(), + IntegerType::get(type.getContext(), 64)); + llvm_unreachable("No tensor types are expected in TTCIR"); +} diff --git a/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.h b/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.h index 35d74a9ec430..02123796ff37 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.h +++ b/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.h @@ -17,6 +17,7 @@ class TritonCPUToLLVMTypeConverter : public LLVMTypeConverter { const DataLayoutAnalysis *analysis = nullptr); Type convertTritonPointerType(triton::PointerType type); + Type convertTritonTensorType(RankedTensorType type); }; #endif diff --git a/third_party/cpu/lib/TritonCPUTransforms/OptimizeMasks.cpp b/third_party/cpu/lib/TritonCPUTransforms/OptimizeMasks.cpp index 271a8b28559e..d113e6671531 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/OptimizeMasks.cpp +++ b/third_party/cpu/lib/TritonCPUTransforms/OptimizeMasks.cpp @@ -85,9 +85,9 @@ struct CdivToDiv : public OpRewritePattern { arith::ConstantOp addCstDef; Value addOtherVal; - if (addCstDef = addOpDef.getLhs().getDefiningOp()) + if ((addCstDef = addOpDef.getLhs().getDefiningOp())) addOtherVal = addOpDef.getRhs(); - else if (addCstDef = addOpDef.getRhs().getDefiningOp()) + else if ((addCstDef = addOpDef.getRhs().getDefiningOp())) addOtherVal = addOpDef.getLhs(); else return failure(); diff --git a/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt b/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt index b200a47da92d..18e675044881 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt +++ b/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt @@ -1,6 +1,7 @@ add_triton_library(TritonToTritonCPU ConvertAtomicOps.cpp ConvertControlFlowOps.cpp + ConvertDebugOps.cpp ConvertDotOp.cpp ConvertElementwiseOps.cpp ConvertElemManipOps.cpp diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp new file mode 100644 index 000000000000..cf6e6704bc28 --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp @@ -0,0 +1,100 @@ +#include "TypeConverter.h" + +#include "cpu/include/TritonToTritonCPU/Passes.h" + +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/Pass/Pass.h" + +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTDEBUGOPS +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +class DebugOpsConversionTarget : public ConversionTarget { +public: + explicit DebugOpsConversionTarget(MLIRContext &ctx, TypeConverter &converter) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + + addIllegalOp(); + } +}; + +struct PrintOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::PrintOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + + // It lowers to triton_cpu.print after converting tensor types to vectors. + // (tt.print doesn't accept vector types, so we have this intermediate op.) + if (op.getNumOperands() == 0) { + rewriter.create(loc, op.getPrefix(), op.getHex(), + ValueRange{}); + } else { + // triton_cpu.print takes up to one vector or scalar operand. It prints + // each value as a separate print call like the GPU and interpreter. + for (size_t i = 0; i < op.getNumOperands(); i++) { + Value opr = op.getOperands()[i]; + // TODO: Consider using memrefs for general N-dimensional vectors. + rewriter.create(loc, op.getPrefix(), op.getHex(), + rewriter.getRemappedValue(opr)); + } + } + + rewriter.eraseOp(op); + return success(); + } +}; + +struct ConvertDebugOps + : public triton::impl::ConvertDebugOpsBase { + using ConvertDebugOpsBase::ConvertDebugOpsBase; + + ConvertDebugOps() : ConvertDebugOpsBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + TritonToTritonCPUTypeConverter typeConverter; + DebugOpsConversionTarget convTarget(*context, typeConverter); + RewritePatternSet patterns(context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertDebugOps() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/runtime/cpu_runtime.cpp b/third_party/cpu/runtime/cpu_runtime.cpp index 0d69ca6c8ab7..3d232ddb2530 100644 --- a/third_party/cpu/runtime/cpu_runtime.cpp +++ b/third_party/cpu/runtime/cpu_runtime.cpp @@ -1,6 +1,196 @@ +#include +#include #include +#include +#include +#include +#include -void triton_assert(bool cond, char *c) { +#if defined(_MSC_VER) +#define EXPORT __declspec(dllexport) +#elif defined(__GNUC__) +#define EXPORT __attribute__((visibility("default"))) +#else +#define EXPORT +#endif + +namespace { + +// A poor man's Torch-like pretty print for tensors and vectors. +const int MAX_FLOAT_WIDTH = 8; +const int FLOAT_PREC = 4; +const int ELEMS_PER_LINE = 8; + +struct FormatInfo { + bool isInt; + int bitWidth; + int maxIntDigits; + bool hasNegative; + bool scientific; +}; + +template +std::pair +computeDigitInfoHelper(const void *array, size_t index) { + T elem = static_cast(array)[index]; + if (elem == 0) + return {1, false}; + return {static_cast(std::log10(std::abs(elem))) + 1, elem < 0}; +} + +std::pair computeDigitInfo(void *vec, int32_t isInt, + int32_t bitWidth, size_t index) { + + if (isInt == 0) { + if (bitWidth == 32) + return computeDigitInfoHelper(vec, index); + else if (bitWidth == 64) + return computeDigitInfoHelper(vec, index); + else + assert(false && "Unsupported bitWidth"); + } else { + // TODO: Handle signed types? + if (bitWidth == 64) + return computeDigitInfoHelper(vec, index); + else if (bitWidth == 32) + return computeDigitInfoHelper(vec, index); + else if (bitWidth == 16) + return computeDigitInfoHelper(vec, index); + else if (bitWidth == 8) + return computeDigitInfoHelper(vec, index); + else + assert(false && "Unsupported bitWidth"); + } +} + +FormatInfo getFormatInfo(void *vec, bool isInt, int32_t bitWidth, + int64_t numElem) { + // Compute the max/min widths for pretty printing. + int maxIntDigits = 0; + int minIntDigits = std::numeric_limits::max(); + bool hasNegative = false; + for (int64_t i = 0; i < numElem; ++i) { + auto [digits, negative] = computeDigitInfo(vec, isInt, bitWidth, i); + hasNegative |= negative; + maxIntDigits = std::max(maxIntDigits, digits); + minIntDigits = std::min(minIntDigits, digits); + } + // Fallback to the scientific format for certain cases. + bool scientific; + if (isInt) { + scientific = false; + } else { + scientific = maxIntDigits + 2 + (hasNegative ? 1 : 0) > MAX_FLOAT_WIDTH; + scientific |= maxIntDigits - minIntDigits > 3; + } + return {isInt, bitWidth, maxIntDigits, hasNegative, scientific}; +} + +template +void printElementHelper(std::stringstream &ss, const void *array, + size_t index) { + ss << static_cast(array)[index]; +} + +void printElement(std::stringstream &ss, const void *vec, size_t index, + bool isInt, int bitWidth) { + if (isInt == 0) { + switch (bitWidth) { + case 32: + printElementHelper(ss, vec, index); + break; + case 64: + printElementHelper(ss, vec, index); + break; + default: + assert(false && "Unsupported bitWidth"); + } + } else { + switch (bitWidth) { + case 64: + printElementHelper(ss, vec, index); + break; + case 32: + printElementHelper(ss, vec, index); + break; + case 16: + printElementHelper(ss, vec, index); + break; + case 8: + // TODO: Seems like not working well. Need to fix it. + printElementHelper(ss, vec, index); + break; + default: + assert(false && "Unsupported bitWidth"); + } + } +} + +void printFormattedElement(std::stringstream &ss, void *vec, size_t index, + const FormatInfo &formatInfo) { + int padding = 0; + auto [digits, negative] = + computeDigitInfo(vec, formatInfo.isInt, formatInfo.bitWidth, index); + if (!negative && formatInfo.hasNegative) + padding++; + if (formatInfo.scientific) { + ss << std::scientific << std::setw(MAX_FLOAT_WIDTH) + << std::setprecision(FLOAT_PREC) << std::string(padding, ' '); + printElement(ss, vec, index, formatInfo.isInt, formatInfo.bitWidth); + } else { + padding += formatInfo.maxIntDigits - digits; + ss << std::fixed << std::setprecision(FLOAT_PREC) + << std::string(padding, ' '); + printElement(ss, vec, index, formatInfo.isInt, formatInfo.bitWidth); + } +} +} // namespace + +extern "C" { + +EXPORT void triton_assert(bool cond, char *c) { if (!cond) fprintf(stderr, "%s\n", c); } + +// Print the pid prefix like the GPU ad interpreter. And vectors are printed +// similar to Torch's printing like the following: +// (1, 0, 0) x: [ -0.4963, -1.7682, 2.0885, 3.1320, -4.3074, 5.6341, +// -6.4901, 7.8964, -8.4556, -9.6323, -10.3489, -11.4017, +// -12.0223, 13.1689, 14.2939, -15.5185] +// +// TODO: Implement for higher dimension vectors. +EXPORT void triton_vector_print(int32_t pid0, int32_t pid1, int32_t pid2, + const char *prefix, void *vec, int32_t isInt, + int32_t bitWidth, int64_t numElem) { + + FormatInfo formatInfo = getFormatInfo(vec, isInt != 0, bitWidth, numElem); + + std::stringstream ss; + ss << "(" << pid0 << ", " << pid1 << ", " << pid2 << ")" << prefix << "["; + const size_t header = ss.str().size(); + + if (numElem <= ELEMS_PER_LINE) { + for (int i = 0; i < numElem; i++) { + printFormattedElement(ss, vec, i, formatInfo); + if (i != numElem - 1) + ss << ", "; + } + } else { + // TODO: Too many lines? Omit the middle lines. + for (int i = 0; i < numElem; i++) { + printFormattedElement(ss, vec, i, formatInfo); + if (i == numElem - 1) + break; + if (i % ELEMS_PER_LINE == ELEMS_PER_LINE - 1) { + ss << ",\n" << std::string(header, ' '); + } else { + ss << ", "; + } + } + } + ss << "]\n"; + std::cout << ss.str() << std::flush; +} + +} // extern "C" diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index 99975bb2f5d2..90c886b6ec85 100644 --- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -54,6 +54,9 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) { m.def("add_convert_atomic_ops", [](mlir::PassManager &pm) { pm.addPass(mlir::triton::cpu::createConvertAtomicOps()); }); + m.def("add_convert_debug_ops", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createConvertDebugOps()); + }); m.def("add_optimize_masks", [](mlir::PassManager &pm) { pm.addPass(mlir::triton::cpu::createOptimizeMasks()); });