Skip to content

Commit

Permalink
cast-index pass: Use index.casts instead of index.castu.
Browse files Browse the repository at this point in the history
  • Loading branch information
silee2 committed Oct 10, 2024
1 parent 9c85f18 commit 78b5506
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 20 deletions.
4 changes: 2 additions & 2 deletions lib/Transforms/CastIndex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ struct CastIndexPass : public imex::impl::CastIndexBase<CastIndexPass> {
// Replace index type operands with cast op from
// index to i32 type.
if (oper.getType().isIndex()) {
auto newOp = builder.create<index::CastUOp>(
auto newOp = builder.create<index::CastSOp>(
o->getLoc(), builder.getI32Type(), oper);
o->setOperand(idx, newOp);
}
Expand All @@ -81,7 +81,7 @@ struct CastIndexPass : public imex::impl::CastIndexBase<CastIndexPass> {
res.setType(builder.getI32Type());
builder.setInsertionPointAfter(o);
// Cast i32 type back to index type
auto newRes = builder.create<index::CastUOp>(
auto newRes = builder.create<index::CastSOp>(
o->getLoc(), builder.getIndexType(), res);
// Replace all uase of result with new cast op
res.replaceAllUsesExcept(newRes, newRes);
Expand Down
36 changes: 18 additions & 18 deletions test/Transforms/cast-index.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -16,35 +16,35 @@ module @castindex attributes {gpu.container_module} {
%4 = gpu.thread_id y
// CHECK: %[[TID_Z:.*]] = gpu.thread_id z
%5 = gpu.thread_id z
// CHECK: %[[VAR_0_1:.*]] = index.castu %[[BID_X]] : index to i32
// CHECK: %[[VAR_3_1:.*]] = index.castu %[[TID_X]] : index to i32
// CHECK: %[[VAR_0_1:.*]] = index.casts %[[BID_X]] : index to i32
// CHECK: %[[VAR_3_1:.*]] = index.casts %[[TID_X]] : index to i32
// CHECK: %[[VAR_6:.*]] = arith.divui %[[VAR_0_1]], %[[VAR_3_1]] : i32
// CHECK: %[[RE_6:.*]] = index.castu %[[VAR_6]] : i32 to index
// CHECK: %[[RE_6:.*]] = index.casts %[[VAR_6]] : i32 to index
%6 = arith.divui %0, %3 : index
// CHECK: %[[VAR_1_1:.*]] = index.castu %[[BID_Y]] : index to i32
// CHECK: %[[VAR_3_2:.*]] = index.castu %[[TID_X]] : index to i32
// CHECK: %[[VAR_1_1:.*]] = index.casts %[[BID_Y]] : index to i32
// CHECK: %[[VAR_3_2:.*]] = index.casts %[[TID_X]] : index to i32
// CHECK: %[[VAR_7:.*]] = arith.remui %[[VAR_1_1]], %[[VAR_3_2]] : i32
// CHECK: %[[RE_7:.*]] = index.castu %[[VAR_7]] : i32 to index
// CHECK: %[[RE_7:.*]] = index.casts %[[VAR_7]] : i32 to index
%7 = arith.remui %1, %3 : index
// CHECK: %[[VAR_2_1:.*]] = index.castu %[[BID_Z]] : index to i32
// CHECK: %[[VAR_5_1:.*]] = index.castu %[[TID_Z]] : index to i32
// CHECK: %[[VAR_2_1:.*]] = index.casts %[[BID_Z]] : index to i32
// CHECK: %[[VAR_5_1:.*]] = index.casts %[[TID_Z]] : index to i32
// CHECK: %[[VAR_8:.*]] = arith.muli %[[VAR_2_1]], %[[VAR_5_1]] : i32
// CHECK: %[[RE_8:.*]] = index.castu %[[VAR_8]] : i32 to index
// CHECK: %[[RE_8:.*]] = index.casts %[[VAR_8]] : i32 to index
%8 = arith.muli %2, %5 : index
// CHECK: %[[VAR_2_2:.*]] = index.castu %[[BID_Z]] : index to i32
// CHECK: %[[VAR_4_1:.*]] = index.castu %[[TID_Y]] : index to i32
// CHECK: %[[VAR_2_2:.*]] = index.casts %[[BID_Z]] : index to i32
// CHECK: %[[VAR_4_1:.*]] = index.casts %[[TID_Y]] : index to i32
// CHECK: %[[VAR_9:.*]], %[[VAR_19:.*]] = arith.mulsi_extended %[[VAR_2_2]], %[[VAR_4_1]] : i32
// CHECK: %[[RE_9:.*]] = index.castu %[[VAR_9]] : i32 to index
// CHECK: %[[RE_9:.*]] = index.casts %[[VAR_9]] : i32 to index
%9, %19 = arith.mulsi_extended %2, %4 : index
// CHECK: %[[VAR_0_2:.*]] = index.castu %[[BID_X]] : index to i32
// CHECK: %[[VAR_4_2:.*]] = index.castu %[[TID_Y]] : index to i32
// CHECK: %[[VAR_0_2:.*]] = index.casts %[[BID_X]] : index to i32
// CHECK: %[[VAR_4_2:.*]] = index.casts %[[TID_Y]] : index to i32
// CHECK: %[[VAR_10:.*]] = arith.divsi %[[VAR_0_2]], %[[VAR_4_2]] : i32
// CHECK: %[[RE_10:.*]] = index.castu %[[VAR_10]] : i32 to index
// CHECK: %[[RE_10:.*]] = index.casts %[[VAR_10]] : i32 to index
%10 = arith.divsi %0, %4 : index
// CHECK: %[[VAR_1_2:.*]] = index.castu %[[BID_Y]] : index to i32
// CHECK: %[[VAR_4_3:.*]] = index.castu %[[TID_Y]] : index to i32
// CHECK: %[[VAR_1_2:.*]] = index.casts %[[BID_Y]] : index to i32
// CHECK: %[[VAR_4_3:.*]] = index.casts %[[TID_Y]] : index to i32
// CHECK: %[[VAR_11:.*]] = arith.remsi %[[VAR_1_2]], %[[VAR_4_3]] : i32
// CHECK: %[[RE_11:.*]] = index.castu %[[VAR_11]] : i32 to index
// CHECK: %[[RE_11:.*]] = index.casts %[[VAR_11]] : i32 to index
%11 = arith.remsi %1, %4 : index
// CHECK: %[[VAR_12:.*]] = memref.load %[[ARG_0:.*]][%[[RE_6]], %[[RE_7]]] : memref<4x5xf16>
%12 = memref.load %arg0[%6, %7] : memref<4x5xf16>
Expand Down

0 comments on commit 78b5506

Please sign in to comment.