diff --git a/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp b/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp index cbe6960be..4fa4df3da 100644 --- a/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp +++ b/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp @@ -1258,6 +1258,8 @@ void populateXeTileOpConversionPatterns(imex::XeOneToNTypeConverter &converter, SgTransposeOpPattern, SgBroadcastOpPattern, SgTileReductionOpPattern, SgVectorCreateMaskOpPattern>( patterns.getContext(), converter, analysis); + + // Element-wise math operations patterns.insert, ElementWiseOpPattern, ElementWiseOpPattern, @@ -1267,16 +1269,30 @@ void populateXeTileOpConversionPatterns(imex::XeOneToNTypeConverter &converter, ElementWiseOpPattern, ElementWiseOpPattern, ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, + ElementWiseOpPattern>( + patterns.getContext(), converter, analysis); + + // Element-wise arithmetic operations + patterns.insert, + ElementWiseOpPattern, ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, ElementWiseOpPattern, + ElementWiseOpPattern, ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, ElementWiseOpPattern, + ElementWiseOpPattern, + ElementWiseOpPattern, ElementWiseOpPattern, - ElementWiseOpPattern, ElementWiseOpPattern>( patterns.getContext(), converter, analysis); patterns.insert, diff --git a/lib/Conversion/XeTileToXeGPU/XeTileToXeGPU.cpp b/lib/Conversion/XeTileToXeGPU/XeTileToXeGPU.cpp index 0564fbb4d..c74d9cf3d 100644 --- a/lib/Conversion/XeTileToXeGPU/XeTileToXeGPU.cpp +++ b/lib/Conversion/XeTileToXeGPU/XeTileToXeGPU.cpp @@ -91,12 +91,20 @@ class XeTileConversionTarget : public mlir::ConversionTarget { // Arith ops addDynamicallyLegalOp( [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); + addDynamicallyLegalOp( + [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); addDynamicallyLegalOp( [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); addDynamicallyLegalOp( [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); + addDynamicallyLegalOp( + [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); + addDynamicallyLegalOp( + [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); addDynamicallyLegalOp( [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); + addDynamicallyLegalOp( + [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); addDynamicallyLegalOp( [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); addDynamicallyLegalOp( @@ -105,16 +113,28 @@ class XeTileConversionTarget : public mlir::ConversionTarget { [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); addDynamicallyLegalOp( [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); + addDynamicallyLegalOp( + [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); addDynamicallyLegalOp( [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); + addDynamicallyLegalOp( + [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); + addDynamicallyLegalOp( + [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); addDynamicallyLegalOp( [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( + addDynamicallyLegalOp( [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); - addDynamicallyLegalOp( + addDynamicallyLegalOp( + [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); + addDynamicallyLegalOp( [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); addDynamicallyLegalOp( [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); + addDynamicallyLegalOp( + [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); + addDynamicallyLegalOp( + [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); addDynamicallyLegalOp( [&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); }); addDynamicallyLegalOp( diff --git a/test/Conversion/XeTileToXeGPU/elementwise_ops.mlir b/test/Conversion/XeTileToXeGPU/elementwise_ops.mlir index 18898ab83..a88da2a34 100644 --- a/test/Conversion/XeTileToXeGPU/elementwise_ops.mlir +++ b/test/Conversion/XeTileToXeGPU/elementwise_ops.mlir @@ -22,6 +22,38 @@ gpu.return } + gpu.func @arith_binary_ops_int() { + %0 = arith.constant dense<1>: vector<4x4x16x16xi16> + %1 = arith.constant dense<2>: vector<64x4x1x16xi16> + %2 = xetile.tile_unpack %0 {inner_blocks = array}: vector<4x4x16x16xi16> -> vector<64x64xi16> + %3 = xetile.tile_pack %2 {inner_blocks = array}: vector<64x64xi16> -> vector<64x4x1x16xi16> + // CHECK-COUNT-256: arith.addi {{.*}}, {{.*}} : vector<1x16xi16> + // CHECK-COUNT-256: arith.subi + // CHECK-COUNT-256: arith.muli + // CHECK-COUNT-256: arith.maxsi + // CHECK-COUNT-256: arith.maxui + // CHECK-COUNT-256: arith.minsi + // CHECK-COUNT-256: arith.minui + // CHECK-COUNT-256: arith.divsi + // CHECK-COUNT-256: arith.divui + // CHECK-COUNT-256: arith.remsi + // CHECK-COUNT-256: arith.remui + // CHECK-COUNT-256: arith.andi + %result = arith.addi %3, %1 : vector<64x4x1x16xi16> + %subi_result = arith.subi %3, %1 : vector<64x4x1x16xi16> + %muli_result = arith.muli %subi_result, %1 : vector<64x4x1x16xi16> + %maxsi_result = arith.maxsi %muli_result, %1 : vector<64x4x1x16xi16> + %maxui_result = arith.maxui %muli_result, %1 : vector<64x4x1x16xi16> + %minsi_result = arith.minsi %maxsi_result, %muli_result : vector<64x4x1x16xi16> + %minui_result = arith.minui %maxui_result, %muli_result : vector<64x4x1x16xi16> + %divsi_result = arith.divsi %minui_result, %1 : vector<64x4x1x16xi16> + %divui_result = arith.divui %minui_result, %1 : vector<64x4x1x16xi16> + %remsi_result = arith.remsi %minsi_result, %divsi_result : vector<64x4x1x16xi16> + %remui_result = arith.remui %minui_result, %divui_result : vector<64x4x1x16xi16> + %and_result = arith.andi %remsi_result, %remui_result : vector<64x4x1x16xi16> + gpu.return + } + gpu.func @arith_xori_ops() { %0 = arith.constant dense<1>: vector<4x4x16x16xi16> %1 = arith.constant dense<2>: vector<64x4x1x16xi16> diff --git a/test/Integration/Dialect/XeTile/eltwise_int_ops.mlir b/test/Integration/Dialect/XeTile/eltwise_int_ops.mlir new file mode 100644 index 000000000..93b557277 --- /dev/null +++ b/test/Integration/Dialect/XeTile/eltwise_int_ops.mlir @@ -0,0 +1,78 @@ +// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +module @eltwise_int attributes {gpu.container_module} { + memref.global "private" constant @__constant_5_1024x1024xi32 : memref<1024x1024xi32> = dense<5> + memref.global "private" constant @__constant_2_1024x1024xi32 : memref<1024x1024xi32> = dense<2> + + func.func @eltwise_int_test(%arg0: memref<1024x1024xi32>, %arg1: memref<1024x1024xi32>) -> memref<1024x1024xi32> attributes {llvm.emit_c_interface} { + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + + %arg0_gpu = gpu.alloc host_shared () : memref<1024x1024xi32> + memref.copy %arg0, %arg0_gpu : memref<1024x1024xi32> to memref<1024x1024xi32> + + %arg1_gpu = gpu.alloc host_shared () : memref<1024x1024xi32> + memref.copy %arg1, %arg1_gpu : memref<1024x1024xi32> to memref<1024x1024xi32> + + %result = gpu.alloc host_shared () : memref<1024x1024xi32> + + gpu.launch_func @eltwise_int::@eltwise_int blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%arg0_gpu : memref<1024x1024xi32>, %arg1_gpu : memref<1024x1024xi32>, %result : memref<1024x1024xi32>) + + gpu.dealloc %arg0_gpu : memref<1024x1024xi32> + gpu.dealloc %arg1_gpu : memref<1024x1024xi32> + return %result : memref<1024x1024xi32> + + } + + gpu.module @eltwise_int attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.func @eltwise_int(%arg0: memref<1024x1024xi32>, %arg1: memref<1024x1024xi32>, %arg2: memref<1024x1024xi32>) kernel attributes {VectorComputeFunctionINTEL, known_block_size = array, known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + + %m = arith.muli %block_id_x, %c16 : index + %n = arith.muli %block_id_y, %c32 : index + + %1 = xetile.init_tile %arg0[%m, %n] : memref<1024x1024xi32> -> !xetile.tile<16x32xi32> + %2 = xetile.load_tile %1: !xetile.tile<16x32xi32> -> vector<16x32xi32> + %3 = xetile.init_tile %arg1[%m, %n] : memref<1024x1024xi32> -> !xetile.tile<16x32xi32> + %4 = xetile.load_tile %3: !xetile.tile<16x32xi32> -> vector<16x32xi32> + %result_add = arith.addi %2, %4: vector<16x32xi32> //=7 + %result_sub = arith.subi %2, %4: vector<16x32xi32> //=3 + %result_mul = arith.muli %result_add, %result_sub: vector<16x32xi32> //=21 + %result_sdiv = arith.divsi %result_mul, %result_add: vector<16x32xi32> //=3 + %result_udiv = arith.divui %result_mul, %result_add: vector<16x32xi32> //=3 + %result_srem = arith.remsi %result_sdiv, %result_mul: vector<16x32xi32> //=3 + %result_urem = arith.remui %result_udiv, %result_srem: vector<16x32xi32> //=0 + %result = arith.addi %result_srem, %result_urem: vector<16x32xi32> //=3 + %store_tile = xetile.init_tile %arg2[%m, %n] : memref<1024x1024xi32> -> !xetile.tile<16x32xi32> + xetile.store_tile %result, %store_tile: vector<16x32xi32>, !xetile.tile<16x32xi32> + gpu.return + } + } + + func.func @main() attributes {llvm.emit_c_interface} { + %A = memref.get_global @__constant_5_1024x1024xi32 : memref<1024x1024xi32> + %B = memref.get_global @__constant_2_1024x1024xi32 : memref<1024x1024xi32> + + %c0_i32 = arith.constant 0 : i32 + + %result = call @eltwise_int_test(%A, %B) : (memref<1024x1024xi32>, memref<1024x1024xi32>) -> memref<1024x1024xi32> + %result_cast = memref.cast %result : memref<1024x1024xi32> to memref<*xi32> + // CHECK: Unranked Memref base@ = {{(0x)?[-9a-f]*}} + // CHECK-COUNT-1048576: 3 + call @printMemrefI32(%result_cast) : (memref<*xi32>) -> () + + return + } + func.func private @printMemrefI32(memref<*xi32>) attributes {llvm.emit_c_interface} +}