diff --git a/include/imex/Dialect/XeTile/IR/XeTileAttrs.td b/include/imex/Dialect/XeTile/IR/XeTileAttrs.td index e587be538..b6e6735c5 100644 --- a/include/imex/Dialect/XeTile/IR/XeTileAttrs.td +++ b/include/imex/Dialect/XeTile/IR/XeTileAttrs.td @@ -79,9 +79,10 @@ def XeTile_TileAttr : XeTile_Attr<"XeTile", "tile_attr"> { [{ mlir::Type intType = mlir::IntegerType::get($_ctxt, 32); mlir::BoolAttr scatteredAttr = mlir::BoolAttr::get($_ctxt, scattered); + mlir::DenseI64ArrayAttr blkAttr = inner_blocks.empty()? mlir::DenseI64ArrayAttr(): + mlir::DenseI64ArrayAttr::get($_ctxt, inner_blocks); return $_get($_ctxt, sg_map, wg_map, mlir::DenseI32ArrayAttr::get($_ctxt, order), - mlir::DenseI64ArrayAttr::get($_ctxt, inner_blocks), - mlir::IntegerAttr::get(intType, memory_space), scatteredAttr); + blkAttr, mlir::IntegerAttr::get(intType, memory_space), scatteredAttr); }]>, AttrBuilder<(ins CArg<"llvm::ArrayRef", "{1, 0}">:$order, CArg<"int", "0">:$memory_space, CArg<"bool", "false">:$scattered), @@ -90,7 +91,7 @@ def XeTile_TileAttr : XeTile_Attr<"XeTile", "tile_attr"> { mlir::BoolAttr scatteredAttr = mlir::BoolAttr::get($_ctxt, scattered); return $_get($_ctxt, xetile::SubGroupMapAttr(), xetile::WorkGroupMapAttr(), mlir::DenseI32ArrayAttr::get($_ctxt, order), - mlir::DenseI64ArrayAttr::get($_ctxt, {}), + mlir::DenseI64ArrayAttr(), mlir::IntegerAttr::get(intType, memory_space), scatteredAttr); }]>, AttrBuilder<(ins CArg<"xetile::SubGroupMapAttr", "{}">:$sg_map, @@ -101,7 +102,7 @@ def XeTile_TileAttr : XeTile_Attr<"XeTile", "tile_attr"> { mlir::Type intType = mlir::IntegerType::get($_ctxt, 32); mlir::BoolAttr scatteredAttr = mlir::BoolAttr::get($_ctxt, scattered); return $_get($_ctxt, sg_map, wg_map, mlir::DenseI32ArrayAttr::get($_ctxt, order), - mlir::DenseI64ArrayAttr::get($_ctxt, {}), + mlir::DenseI64ArrayAttr(), mlir::IntegerAttr::get(intType, memory_space), scatteredAttr); }]> ]; diff --git a/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp b/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp index 4fa4df3da..940218b3c 100644 --- a/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp +++ b/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp @@ -1090,6 +1090,32 @@ struct TypecastOpPattern : public XeOneToNConversion { } }; +struct SgArithCmpIOpPattern : public XeOneToNConversion { + using XeOneToNConversion::XeOneToNConversion; + + mlir::LogicalResult + matchAndRewrite(mlir::arith::CmpIOp op, OpAdaptor adaptor, + XeOneToNPatternRewriter &rewriter) const override { + auto res = op.getResult(); + auto resType = mlir::dyn_cast(res.getType()); + if (!resType || resType.getRank() != 4) + return mlir::failure(); + + auto lhs = adaptor.getLhs(); + auto rhs = adaptor.getRhs(); + + llvm::SmallVector newOps; + for (auto [l, r] : llvm::zip_equal(lhs, rhs)) { + auto newOp = rewriter.create( + op.getLoc(), op.getPredicate(), l, r); + newOps.push_back(newOp); + } + + rewriter.replaceOp(op, newOps); + return mlir::success(); + } +}; + struct SgBroadcastOpPattern : public XeOneToNConversion { using XeOneToNConversion::XeOneToNConversion; @@ -1256,8 +1282,8 @@ void populateXeTileOpConversionPatterns(imex::XeOneToNTypeConverter &converter, SgVectorSplatOpPattern, SgUpdateTileOffsetOpPattern, SgTransposeOpPattern, SgTransposeOpPattern, SgBroadcastOpPattern, - SgTileReductionOpPattern, SgVectorCreateMaskOpPattern>( - patterns.getContext(), converter, analysis); + SgTileReductionOpPattern, SgVectorCreateMaskOpPattern, + SgArithCmpIOpPattern>(patterns.getContext(), converter, analysis); // Element-wise math operations patterns.insert,