Skip to content

Commit

Permalink
Add lowering pattern and test for arith::AndIOp
Browse files Browse the repository at this point in the history
  • Loading branch information
chencha3 committed Oct 7, 2024
1 parent fe1535e commit daa697e
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 0 deletions.
1 change: 1 addition & 0 deletions lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1164,6 +1164,7 @@ void populateXeTileOpConversionPatterns(imex::XeOneToNTypeConverter &converter,
ElementWiseOpPattern<mlir::math::RsqrtOp, 1>,
ElementWiseOpPattern<mlir::math::ErfOp, 1>,
ElementWiseOpPattern<mlir::arith::AddFOp, 2>,
ElementWiseOpPattern<mlir::arith::AndIOp, 2>,
ElementWiseOpPattern<mlir::arith::RemFOp, 2>,
ElementWiseOpPattern<mlir::arith::DivFOp, 2>,
ElementWiseOpPattern<mlir::arith::MulFOp, 2>,
Expand Down
2 changes: 2 additions & 0 deletions lib/Conversion/XeTileToXeGPU/XeTileToXeGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ class XeTileConversionTarget : public mlir::ConversionTarget {
// Arith ops
addDynamicallyLegalOp<mlir::arith::AddFOp>(
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
addDynamicallyLegalOp<mlir::arith::AndIOp>(
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
addDynamicallyLegalOp<mlir::arith::DivFOp>(
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
addDynamicallyLegalOp<mlir::arith::MulFOp>(
Expand Down
36 changes: 36 additions & 0 deletions test/Conversion/XeTileToXeGPU/elementwise_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -221,4 +221,40 @@
xetile.store_tile %7, %6 : vector<4x2x8x16xi16>, !xetile.tile<32x32xi16, #xetile.tile_attr<inner_blocks = [8, 16]>>
gpu.return
}


gpu.func @sglevel_and_test(%arg0: memref<1x4096xi8>, %arg1: memref<1x4096xi8>, %arg2: memref<1x4096xi8>) {
%c0 = arith.constant 0 : index
%c4096 = arith.constant 4096 : index
%c32 = arith.constant 32 : index
%c1024_i32 = arith.constant 1024 : i32
%thread_id_x = gpu.thread_id x
%thread_id_y = gpu.thread_id y
%block_dim_y = gpu.block_dim y
%0 = arith.muli %thread_id_x, %block_dim_y : index
%1 = arith.addi %0, %thread_id_y : index
%block_id_x = gpu.block_id x
%2 = arith.index_cast %block_id_x : index to i32
%3 = arith.muli %2, %c1024_i32 : i32
%4 = arith.index_cast %3 : i32 to index
%5 = arith.remsi %1, %c32 : index
%6 = arith.muli %5, %c32 : index
%7 = arith.remsi %6, %c4096 : index
%8 = arith.addi %7, %4 : index
%9 = xetile.init_tile %arg0[%c0, %8] : memref<1x4096xi8> -> !xetile.tile<1x32xi8, #xetile.tile_attr<inner_blocks = [1, 32]>>
%10 = xetile.load_tile %9 {padding = 0 : i32} : !xetile.tile<1x32xi8, #xetile.tile_attr<inner_blocks = [1, 32]>> -> vector<1x1x1x32xi8>
%11 = xetile.tile_unpack %10 {inner_blocks = array<i64: 1, 32>} : vector<1x1x1x32xi8> -> vector<1x32xi8>
%12 = xetile.init_tile %arg1[%c0, %8] : memref<1x4096xi8> -> !xetile.tile<1x32xi8, #xetile.tile_attr<inner_blocks = [1, 32]>>
%13 = xetile.load_tile %12 {padding = 0 : i32} : !xetile.tile<1x32xi8, #xetile.tile_attr<inner_blocks = [1, 32]>> -> vector<1x1x1x32xi8>
%14 = xetile.tile_unpack %13 {inner_blocks = array<i64: 1, 32>} : vector<1x1x1x32xi8> -> vector<1x32xi8>
%15 = xetile.tile_pack %11 {inner_blocks = array<i64: 1, 32>} : vector<1x32xi8> -> vector<1x1x1x32xi8>
%16 = xetile.tile_pack %14 {inner_blocks = array<i64: 1, 32>} : vector<1x32xi8> -> vector<1x1x1x32xi8>
//CHECK: %{{.*}} = arith.andi %{{.*}}, %{{.*}} : vector<1x32xi8>
%17 = arith.andi %15, %16 : vector<1x1x1x32xi8>
%18 = xetile.tile_unpack %17 {inner_blocks = array<i64: 1, 32>} : vector<1x1x1x32xi8> -> vector<1x32xi8>
%19 = xetile.init_tile %arg2[%c0, %8] : memref<1x4096xi8> -> !xetile.tile<1x32xi8, #xetile.tile_attr<inner_blocks = [1, 32]>>
%20 = xetile.tile_pack %18 {inner_blocks = array<i64: 1, 32>} : vector<1x32xi8> -> vector<1x1x1x32xi8>
xetile.store_tile %20, %19 : vector<1x1x1x32xi8>, !xetile.tile<1x32xi8, #xetile.tile_attr<inner_blocks = [1, 32]>>
gpu.return
}
}

0 comments on commit daa697e

Please sign in to comment.