Skip to content

Commit

Permalink
[cpu] Rework device_print with triton_cpu.print and 1D vector printing (#99)
Browse files Browse the repository at this point in the history

* [cpu] Rework device_print with 1D vector printing

* Update minor comments

* Apply suggestions from code review

Co-authored-by: Jez Ng <[email protected]>

* A few fixes upon the previous suggestions

* Refactoring + update comments

---------

Co-authored-by: Jez Ng <[email protected]>
  • Loading branch information
minjang and int3 authored Aug 14, 2024
1 parent 6d27cdc commit 94b6403
Show file tree
Hide file tree
Showing 16 changed files with 487 additions and 85 deletions.
20 changes: 20 additions & 0 deletions include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,24 @@ def TTC_PtrToMemRefOp : TTC_Op<"ptr_to_memref", [NoMemoryEffect]> {
let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)";
}

def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">;

def TTC_PrintOp : TTC_Op<"print", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
  let summary = "Print at most a single scalar or vector (converted from tensor) on each line";

  let description = [{
    Prints an optional prefix followed by at most one value. Tensor operands
    are expected to have already been converted to vector types, so the
    operand, if present, is a scalar (float, int, or pointer) or a vector of
    floats/ints. The `hex` attribute requests hexadecimal formatting.
  }];

  let arguments = (ins StrAttr:$prefix, BoolAttr:$hex,
                   Variadic<AnyTypeOf<[TT_Float, TT_Int, TT_Ptr, TTC_Vector]>>:$val);

  let assemblyFormat = [{
    $prefix attr-dict (`:` $val^ `:` type($val))?
  }];

  // Verifier enforces the at-most-one-operand contract (see PrintOp::verify).
  let hasVerifier = 1;
}

#endif
3 changes: 3 additions & 0 deletions include/triton/Dialect/TritonCPU/IR/TritonCPUTypes.td
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef TRITONCPU_TYPES
#define TRITONCPU_TYPES

include "triton/Dialect/Triton/IR/TritonTypes.td"
include "triton/Dialect/TritonCPU/IR/TritonCPUDialect.td"
include "mlir/IR/AttrTypeBase.td"

Expand All @@ -23,4 +24,6 @@ def TTC_TokenType : TTC_TypeDef<"Token", "token"> {
let skipDefaultBuilders = 1;
}

def TTC_Vector : VectorOf<[TT_Float, TT_Int]>;

#endif
1 change: 1 addition & 0 deletions lib/Dialect/TritonCPU/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
add_triton_library(TritonCPUIR
Dialect.cpp
Ops.cpp
Types.cpp

DEPENDS
Expand Down
3 changes: 0 additions & 3 deletions lib/Dialect/TritonCPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,6 @@ void TritonCPUDialect::initialize() {
>();
}

#define GET_OP_CLASSES
#include "triton/Dialect/TritonCPU/IR/Ops.cpp.inc"

// verify TritonCPU ops
LogicalResult TritonCPUDialect::verifyOperationAttribute(Operation *op,
NamedAttribute attr) {
Expand Down
18 changes: 18 additions & 0 deletions lib/Dialect/TritonCPU/IR/Ops.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#include "mlir/IR/Builders.h"
#include "triton/Dialect/TritonCPU/IR/Dialect.h"

#define GET_OP_CLASSES
#include "triton/Dialect/TritonCPU/IR/Ops.cpp.inc"

// enum attribute definitions
#include "triton/Dialect/TritonCPU/IR/OpsEnums.cpp.inc"

namespace mlir::triton::cpu {

// Enforces the op contract: triton_cpu.print takes at most one value
// (tt.print with multiple values is decomposed into one op per value).
LogicalResult PrintOp::verify() {
  if (getNumOperands() <= 1)
    return success();
  return emitOpError("expects at most one operand");
}

} // namespace mlir::triton::cpu
1 change: 1 addition & 0 deletions third_party/cpu/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def make_ttcir(mod, metadata, opt):
cpu.passes.ttcpuir.add_convert_scan_op(pm)
cpu.passes.ttcpuir.add_convert_cf_ops(pm)
cpu.passes.ttcpuir.add_convert_atomic_ops(pm)
cpu.passes.ttcpuir.add_convert_debug_ops(pm)
passes.common.add_cse(pm)
passes.common.add_symbol_dce(pm)
passes.common.add_canonicalizer(pm)
Expand Down
1 change: 1 addition & 0 deletions third_party/cpu/include/TritonToTritonCPU/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ std::unique_ptr<OperationPass<ModuleOp>> createConvertHistogramOp();
std::unique_ptr<OperationPass<ModuleOp>> createConvertReductionOp();
std::unique_ptr<OperationPass<ModuleOp>> createConvertScanOp();
std::unique_ptr<OperationPass<ModuleOp>> createConvertAtomicOps();
std::unique_ptr<OperationPass<ModuleOp>> createConvertDebugOps();

#define GEN_PASS_REGISTRATION
#include "cpu/include/TritonToTritonCPU/Passes.h.inc"
Expand Down
13 changes: 13 additions & 0 deletions third_party/cpu/include/TritonToTritonCPU/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -142,4 +142,17 @@ def ConvertAtomicOps : Pass<"triton-cpu-convert-atomic-ops", "mlir::ModuleOp"> {
"mlir::triton::cpu::TritonCPUDialect"];
}

def ConvertDebugOps : Pass<"triton-cpu-convert-debug-ops", "mlir::ModuleOp"> {
  let summary = "Convert Triton debug operations.";
  let description = [{
    Lowers Triton debug operations (e.g. device prints) to their TritonCPU
    counterparts, converting tensor operands into vector values suitable for
    the CPU lowering pipeline.
  }];
  let constructor = "mlir::triton::cpu::createConvertDebugOps()";

  let dependentDialects = ["mlir::vector::VectorDialect",
                           "mlir::scf::SCFDialect",
                           "mlir::triton::TritonDialect",
                           "mlir::triton::cpu::TritonCPUDialect"];
}

#endif
196 changes: 122 additions & 74 deletions third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "cpu/include/TritonCPUToLLVM/Passes.h"

#include "mlir/Dialect/GPU/IR/GPUOps.h.inc"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
Expand Down Expand Up @@ -31,54 +32,6 @@ class TritonLLVMConversionTarget : public ConversionTarget {
}
};

// The code for the print is similar to the GPU's TargetInfo.cpp.
LLVM::LLVMFuncOp getPrintfDeclaration(ConversionPatternRewriter &rewriter) {
auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
StringRef funcName("printf");
Operation *funcOp = moduleOp.lookupSymbol(funcName);
if (funcOp)
return cast<LLVM::LLVMFuncOp>(*funcOp);

auto *context = rewriter.getContext();

// int printf(char* format, ...)
SmallVector<Type> argsType{ptr_ty(context)};
auto funcType = LLVM::LLVMFunctionType::get(i32_ty, argsType, true);

ConversionPatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(moduleOp.getBody());

auto op = rewriter.create<LLVM::LLVMFuncOp>(UnknownLoc::get(context),
funcName, funcType);
return op;
}

void emitPrintf(ConversionPatternRewriter &rewriter, Value formatStrStart,
int /*formatStrByteCount*/, ValueRange args) {
auto loc = UnknownLoc::get(rewriter.getContext());
SmallVector<Value> formatStrAndArgs{formatStrStart};
for (auto arg : args) {
formatStrAndArgs.push_back(arg);
}
call(getPrintfDeclaration(rewriter), formatStrAndArgs);
}

Value llPrintf(StringRef msg, ValueRange args,
ConversionPatternRewriter &rewriter,
int *formatStrByteCount = nullptr) {
assert(!msg.empty() && "printf with empty string not supported");
llvm::SmallString<64> msgNewline(msg);
msgNewline.push_back('\n');
msgNewline.push_back('\0');
Value msgValue =
LLVM::addStringToModule(UnknownLoc::get(rewriter.getContext()), rewriter,
"printfFormat_", msgNewline);
emitPrintf(rewriter, msgValue, msgNewline.size_in_bytes(), args);
if (formatStrByteCount)
*formatStrByteCount = msgNewline.size_in_bytes();
return msgValue;
}

// TODO: This code is the same as the GPU-backend code. Consider refactoring.
std::string getFormatSubstr(Value value, bool hex = false,
std::optional<int> width = std::nullopt) {
Expand Down Expand Up @@ -123,44 +76,139 @@ std::string getFormatSubstr(Value value, bool hex = false,
return "";
}

// TritonCPU's device_print prints all values in the same line unlike GPUs
// and interpreter where each value is printed in a separate line.
struct PrintOpConversion : public ConvertOpToLLVMPattern<triton::PrintOp> {
explicit PrintOpConversion(LLVMTypeConverter &typeConverter)
: mlir::ConvertOpToLLVMPattern<triton::PrintOp>(typeConverter) {}
// Returns (declaring at module scope if not yet present) the runtime print
// helper: the C library's variadic `printf` when `isPrintf` is true, otherwise
// `triton_vector_print`, which prints a 1D vector given its buffer address and
// element-type description.
// NOTE: renamed the parameter from `printf` to `isPrintf` — shadowing the
// libc function name made call sites like `getPrintFuncDecl(rewriter, true)`
// needlessly confusing.
LLVM::LLVMFuncOp getPrintFuncDecl(ConversionPatternRewriter &rewriter,
                                  bool isPrintf) {
  auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
  StringRef funcName = isPrintf ? "printf" : "triton_vector_print";
  // Reuse an existing declaration if one was already emitted.
  Operation *funcOp = moduleOp.lookupSymbol(funcName);
  if (funcOp)
    return cast<LLVM::LLVMFuncOp>(*funcOp);

  auto *ctx = rewriter.getContext();
  SmallVector<Type> argsType;
  if (isPrintf) {
    // int printf(char *format, ...)
    argsType = {ptr_ty(ctx)};
  } else {
    // (pid0, pid1, pid2, prefix, vecPtr, isInteger, bitWidth, numElem) —
    // mirrors the argument order assembled in llVectorPrint.
    argsType = {i32_ty, i32_ty, i32_ty, ptr_ty(ctx),
                ptr_ty(ctx), i32_ty, i32_ty, i64_ty};
  }

  // Only printf is variadic.
  auto funcType =
      LLVM::LLVMFunctionType::get(i32_ty, argsType, /*isVarArg*/ isPrintf);

  // Insert the declaration at the start of the module body.
  ConversionPatternRewriter::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(moduleOp.getBody());

  return rewriter.create<LLVM::LLVMFuncOp>(UnknownLoc::get(ctx), funcName,
                                           funcType);
}

// Emits a printf call producing "(pid0, pid1, pid2)<prefix>[<arg>]\n".
// The format string is interned as a module-level global; `hex` requests
// hexadecimal formatting for the optional scalar argument.
void llPrintf(StringRef prefix, std::array<Value, 3> pid,
              std::optional<Value> arg, ConversionPatternRewriter &rewriter,
              bool hex = false) {
  assert(!prefix.empty() && "printf with empty string not supported");
  auto loc = UnknownLoc::get(rewriter.getContext());

  // Build the format string directly into a stack buffer.
  llvm::SmallString<64> fmt;
  llvm::raw_svector_ostream os(fmt);
  os << "(";
  for (size_t i = 0; i < pid.size(); ++i) {
    if (i != 0)
      os << ", ";
    os << getFormatSubstr(pid[i]);
  }
  os << ")" << prefix;
  if (arg)
    os << getFormatSubstr(*arg, hex);
  fmt.push_back('\n');
  fmt.push_back('\0');

  Value fmtValue =
      LLVM::addStringToModule(loc, rewriter, "printfFormat_", fmt);

  // printf(format, pid0, pid1, pid2[, arg])
  SmallVector<Value> callArgs{fmtValue};
  callArgs.append(pid.begin(), pid.end());
  if (arg)
    callArgs.push_back(*arg);
  call(getPrintFuncDecl(rewriter, true), callArgs);
}

// Emits a call to the triton_vector_print runtime helper: the vector is passed
// by pointer along with its element-type description (integer-ness, bit
// width) and element count.
void llVectorPrint(std::array<Value, 3> pid, StringRef prefix, Value ptr,
                   bool isInteger, uint32_t bitWidth, int64_t numElem,
                   ConversionPatternRewriter &rewriter) {
  assert(!prefix.empty());
  auto loc = UnknownLoc::get(rewriter.getContext());

  // Intern the NUL-terminated prefix as a module-level global string.
  llvm::SmallString<64> prefixBuf(prefix);
  prefixBuf.push_back('\0');
  Value prefixGlobal =
      LLVM::addStringToModule(loc, rewriter, "vectorPrintPrefix_", prefixBuf);

  // triton_vector_print(pid0, pid1, pid2, prefix, ptr, isInteger, bitWidth,
  //                     numElem) — must match getPrintFuncDecl's signature.
  SmallVector<Value> callArgs(pid.begin(), pid.end());
  callArgs.append({prefixGlobal, ptr, i32_val(isInteger), i32_val(bitWidth),
                   i64_val(numElem)});
  call(getPrintFuncDecl(rewriter, false), callArgs);
}

// Decides whether plain printf suffices for this op: yes when there is no
// operand or the single operand is a scalar (int/index/float or pointer);
// vector operands go through triton_vector_print instead.
bool usePrintf(triton::cpu::PrintOp op) {
  if (op.getNumOperands() == 0)
    return true;

  // tt.print is already decomposed to triton_cpu.print per value.
  assert(op.getNumOperands() == 1);
  Type operandTy = op.getOperands()[0].getType();
  if (operandTy.isIntOrIndexOrFloat())
    return true;
  return isa<triton::PointerType>(operandTy);
}

struct PrintOpConversion : public ConvertOpToLLVMPattern<triton::cpu::PrintOp> {
using ConvertOpToLLVMPattern<triton::cpu::PrintOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(triton::PrintOp op, OpAdaptor adaptor,
matchAndRewrite(triton::cpu::PrintOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();

auto getPid = [&](int axis) {
return getProgramId(op->getParentOfType<LLVM::LLVMFuncOp>(), axis);
};
SmallVector<Value> values = {getPid(0), getPid(1), getPid(2)};

std::string formatStr;
llvm::raw_string_ostream os(formatStr);
os << "(" << getFormatSubstr(values[0]) << ", "
<< getFormatSubstr(values[1]) << ", " << getFormatSubstr(values[2])
<< ")" << op.getPrefix();

for (size_t i = 0; i < op.getNumOperands(); i++) {
auto elems = unpackLLElements(loc, adaptor.getOperands()[i], rewriter);
if (dyn_cast<RankedTensorType>(op.getOperand(i).getType())) {
llvm_unreachable("Not implemented for tensor types");
std::array<Value, 3> pid = {getPid(0), getPid(1), getPid(2)};

if (usePrintf(op)) {
if (op.getNumOperands() == 0) {
llPrintf(op.getPrefix(), pid, std::nullopt, rewriter);
} else {
Value llOpr = adaptor.getOperands()[0];
llPrintf(op.getPrefix(), pid, llOpr, rewriter, op.getHex());
}

// Only support scalars for now.
assert(elems.size() == 1);
if (i != 0) {
os << ", ";
} else {
Value llOpr = adaptor.getOperands()[0];
auto vecShapedType = cast<ShapedType>(op.getOperands()[0].getType());
// Currently, we only support 1D vector printing.
if (vecShapedType.getRank() == 1) {

// To get the pointer of the vector, create an alloca and store it.
auto ptrType = ptr_ty(rewriter.getContext());
auto ptr = rewriter.create<LLVM::AllocaOp>(loc, ptrType,
llOpr.getType(), i32_val(1));
rewriter.create<LLVM::StoreOp>(loc, llOpr, ptr);

// TODO: Consider passing an encoded element type information instead of
// booleans and separate bit width.
llVectorPrint(pid, op.getPrefix(), ptr,
vecShapedType.getElementType().isInteger(),
vecShapedType.getElementTypeBitWidth(),
vecShapedType.getNumElements(), rewriter);
} else {
// TODO: support 2D+ vector printing.
std::string msg{op.getPrefix()};
llvm::raw_string_ostream os(msg);
os << "<<not implemented for '" << llOpr.getType() << "'>>";
llPrintf(msg, pid, std::nullopt, rewriter);
}
os << getFormatSubstr(elems[0], op.getHex());
values.push_back(elems[0]);
}

llPrintf(formatStr, values, rewriter);
rewriter.eraseOp(op);
return success();
}
Expand Down
15 changes: 10 additions & 5 deletions third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,8 @@ TritonCPUToLLVMTypeConverter::TritonCPUToLLVMTypeConverter(
addConversion([&](triton::PointerType type) -> std::optional<Type> {
return convertTritonPointerType(type);
});
addConversion([this](RankedTensorType tensorTy) -> std::optional<Type> {
if (isa<PointerType>(tensorTy.getElementType()))
return VectorType::get(tensorTy.getShape(),
IntegerType::get(tensorTy.getContext(), 64));
return std::nullopt;
addConversion([this](RankedTensorType type) -> std::optional<Type> {
return convertTritonTensorType(type);
});
}

Expand All @@ -41,3 +38,11 @@ Type TritonCPUToLLVMTypeConverter::convertTritonPointerType(
}
return LLVM::LLVMPointerType::get(ctx);
}

// Converts a ranked tensor type for the LLVM lowering. Only tensors of
// pointers are expected to survive to this point; they become vectors of i64
// (pointers carried as 64-bit integers). Anything else is a pipeline bug.
Type TritonCPUToLLVMTypeConverter::convertTritonTensorType(
    RankedTensorType type) {
  if (!isa<PointerType>(type.getElementType()))
    llvm_unreachable("No tensor types are expected in TTCIR");
  auto *ctx = type.getContext();
  return VectorType::get(type.getShape(), IntegerType::get(ctx, 64));
}
1 change: 1 addition & 0 deletions third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class TritonCPUToLLVMTypeConverter : public LLVMTypeConverter {
const DataLayoutAnalysis *analysis = nullptr);

Type convertTritonPointerType(triton::PointerType type);
Type convertTritonTensorType(RankedTensorType type);
};

#endif
4 changes: 2 additions & 2 deletions third_party/cpu/lib/TritonCPUTransforms/OptimizeMasks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,9 @@ struct CdivToDiv : public OpRewritePattern<arith::DivSIOp> {

arith::ConstantOp addCstDef;
Value addOtherVal;
if (addCstDef = addOpDef.getLhs().getDefiningOp<arith::ConstantOp>())
if ((addCstDef = addOpDef.getLhs().getDefiningOp<arith::ConstantOp>()))
addOtherVal = addOpDef.getRhs();
else if (addCstDef = addOpDef.getRhs().getDefiningOp<arith::ConstantOp>())
else if ((addCstDef = addOpDef.getRhs().getDefiningOp<arith::ConstantOp>()))
addOtherVal = addOpDef.getLhs();
else
return failure();
Expand Down
1 change: 1 addition & 0 deletions third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
add_triton_library(TritonToTritonCPU
ConvertAtomicOps.cpp
ConvertControlFlowOps.cpp
ConvertDebugOps.cpp
ConvertDotOp.cpp
ConvertElementwiseOps.cpp
ConvertElemManipOps.cpp
Expand Down
Loading

0 comments on commit 94b6403

Please sign in to comment.