Skip to content

Commit

Permalink
[XeTile] Add support for scattered ops (#932)
Browse files Browse the repository at this point in the history
Add support for scattered ops
- extend attr with scattered flag
- update init_tile to create scattered tile
- update update_tile_offset to work on scattered tile
- add LoadGather/StoreScatter to load/store on scattered tile
  • Loading branch information
chencha3 authored Oct 17, 2024
1 parent 8ae485b commit c4210b3
Show file tree
Hide file tree
Showing 64 changed files with 422 additions and 372 deletions.
33 changes: 16 additions & 17 deletions docs/rfcs/XeTile.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ XeTile provides a middle-level abstraction for matmul operation and sits between
|init_tile | operation ::= xetile.init_tile $base_memref $offset0, $offset1: type($base_memref), index, index, attr-dict-> type($tile, attr-dict) | %block = xetile.init_tile %base_memref, %tile_offset:2 memref<128x128xbf16> into tile<8x16xbf16> |
|load_tile | operation ::=xetile.load_tile $tile attr-dict:type($tile) ->type($res) | %vector_a = xetile.load_tile %tile_a {padding=0} : tile<64x32xbf16> into vector<32x64xbf16>|
|store_tile | operation ::=xetile.store_tile $value, $tile attr-dict: type($value), type($tile) | xetile.store_tile %tile_a, %vector_a: vector<64x64xbf16> into tile<64x64xbf16> |
|update_tile_offset | operation ::=xetile.update_tile_offset $tile, $delta0, $delta1: type($tile), index, index-> type($tile) | %tdesc_updated = xetile.update_nd_offset %tdesc, %offset_x, offset_y tensor_desc<32x64xbf16>, index, index -> tensor_desc<32x64xbf16> |
|update_tile_offset | operation ::=xetile.update_tile_offset $tile, $delta0, $delta1: type($tile) | %tile_updated = xetile.update_tile_offset %tile, %offset_x, %offset_y : tile<32x64xbf16> |
|prefetch_tile | operation ::=xetile.prefetch_tile $tile, attr-dict: type($tile) | xetile.prefetch_tile %coop_tile: tile<16x32xbf16> |
|tile_mma | operation ::=xetile.tile_mma $matA, $matB, $matC attr_dict: type($matC), type($matA), type($matB)-> type($res) | %vector_c = xetile.tile_mma %vector_a, %vector_b, %vector_c : vector<64x32xbf16>, vector<32x128xbf16>, vector<64x128xfloat> into vector<64x128xfloat> |
|atomic_rmw_tile| operation ::=xetile.atomic_rmw_tile \<$kind\>, $vec, $tile: type($vec), type($tile) -> type($res) | %vector_a = xetile.atomic_rmw_tile \<add\> %value, %tile: vector<8x16xbf16>, tile<8x16xbf16> to vector<8x16xbf16> |
Expand Down Expand Up @@ -129,8 +129,7 @@ A `tile_mma` variant without vector_c initialization.
```
`update_tile_offset` updates the tile with offset_x and offset_y to move the current tile to a new position. These offsets are relative to the current position and are counted in number of elements. Usually only one value needs to be updated, since the tile typically moves only along the K dimension. Users should avoid initializing new tiles repeatedly: for best performance, initialize one tile as a base tile and update its offset to move to a new tile.
```mlir
%tile_updated = xetile.update_tile_offset %tile, %offset_x, offset_y :
tile<64x64xbf16>, index, index into tile <64x64xbf16>
%tile_updated = xetile.update_tile_offset %tile, %offset_x, %offset_y : tile<64x64xbf16>
```


Expand Down Expand Up @@ -476,10 +475,10 @@ func.func @test_gemm(%a : memref<4096x4096xf16>,
     prefetch_tile %1 : tile<256x32xf16, #mp_a_pfh>               // sg_layout=[32,1]
          prefetch_tile %2  : tile<32x256xf16, #mp_a_pfh>              // sg_layout=[4,8]
          %6 = tile_mma %4, %5 {#mp_a #mp_b #mp_c} %4, %10 : (vector<256x32xf16>, vector<32x256xf16>) -> vector<256x256xf32> //sg_layout=[8,4]
          %1 = update_tile_offset   %1, %c0, %c32 :  tile<256x32xf16, #mp_a> -> tile<256x32xf16, #mp_a>
          %2 = update_tile_offset   %2, %c32, %c0 :  tile<32x256xf16, #mp_b> -> tile<256x32xf16, #mp_b>
          %1p = update_tile_offset   %1p, %c0, %c32 :  tile<256x32xf16, #mp_a_pft> -> tile<256x32xf16, #mp_a_pft>
          %2p = update_tile_offset   %2p, %c32, %c0 :  tile<32x256xf16, #mp_b_pft> -> tile<32x256xf16, #mp_b_pft>
          %1 = update_tile_offset   %1, %c0, %c32 :  tile<256x32xf16, #mp_a>
          %2 = update_tile_offset   %2, %c32, %c0 :  tile<32x256xf16, #mp_b>
          %1p = update_tile_offset   %1p, %c0, %c32 :  tile<256x32xf16, #mp_a_pft>
          %2p = update_tile_offset   %2p, %c32, %c0 :  tile<32x256xf16, #mp_b_pft>
        } 
  store_tile %3, %6: (tile<256x256xf32, #mp_c>, vector<256x256xf32>)           // sg_layout=[8, 4]
   } 
Expand Down Expand Up @@ -530,10 +529,10 @@ func.func @test_gemm(%a : memref<4096x4096xf16>,
     prefetch_tile %1 : tile<256x32xf16, #mp_a_pfh>               // sg_layout=[32,1]
          prefetch_tile %2  : tile<256x32xf16, #mp_a_pfh>              // sg_layout=[32,1]
          %6 = tile_mma %4, %5 {#mp_a #mp_b #mp_c} : (vector<256x32xf16>, vector<32x256xf16>) -> vector<256x256xf32> //sg_layout=[8,4]
          %1 = update_tile_offset   %1, %c0, %c32 :  tile<256x32xf16, #mp_a> -> tile<256x32xf16, #mp_a>
          %2 = update_tile_offset   %2, %c0, %c32 :  tile<256x32xf16, #mp_bt> -> tile<256x32xf16, #mp_bt>
          %1p = update_tile_offset   %1p, %c0, %c32 :  tile<256x32xf16, #mp_a_pft> -> tile<256x32xf16, #mp_a_pft>
          %2p = update_tile_offset   %2p, %c32, %c0 :  tile<256x32xf16, #mp_bt_pft> -> tile<256x32xf16, #mp_bt_pft>
          %1 = update_tile_offset   %1, %c0, %c32 :  tile<256x32xf16, #mp_a>
          %2 = update_tile_offset   %2, %c0, %c32 :  tile<256x32xf16, #mp_bt>
          %1p = update_tile_offset   %1p, %c0, %c32 :  tile<256x32xf16, #mp_a_pft>
          %2p = update_tile_offset   %2p, %c32, %c0 :  tile<256x32xf16, #mp_bt_pft>
        } 
        %12  = load_tile %7  : tile<1x256xf32, #mp_bcast> -> vector<1x256xf16>     // sg_layout=[8, 4], sg_data=[1,64]
Expand Down Expand Up @@ -631,12 +630,12 @@ func.func @test_gemm(%a : memref<4096x4096xf16>,
%slm_offset = add %slm_offset, %c32
%slm_offset = mod %slm_offset, %c128
%a1_load = update_tile_offset %a1_load, %c0, %slm_offset : tile<256x32xf16, #mp_a> -> tile<256x32xf16, #mp_a>
%b1_load = update_tile_offset %b1_load, %slm_offset, %c0 : tile<32x256xf16, #mp_b> -> tile<256x32xf16, #mp_b>
%a4_glb = update_tile_offset %a4_glb, %c0, %c32 : tile<256x32xf16, #mp_a_pft> -> tile<256x32xf16, #mp_a_pft>
%b4_glb = update_tile_offset %b4_glb, %c32, %c0 : tile<32x256xf16, #mp_b_pft> -> tile<32x256xf16, #mp_b_pft>
%a4_slm’ = update_tile_offset %a4_slm, %c0, %slm_offset: tile<256x32xf16, #mp_a_pft> -> tile<256x32xf16, #mp_a_pft>
%b4_slm’ = update_tile_offset %b4_slm, %slm_offset, %c0 : tile<32x256xf16, #mp_b_pft> -> tile<32x256xf16,#mp_b_pft>
%a1_load = update_tile_offset %a1_load, %c0, %slm_offset : tile<256x32xf16, #mp_a>
%b1_load = update_tile_offset %b1_load, %slm_offset, %c0 : tile<32x256xf16, #mp_b>
%a4_glb = update_tile_offset %a4_glb, %c0, %c32 : tile<256x32xf16, #mp_a_pft>
%b4_glb = update_tile_offset %b4_glb, %c32, %c0 : tile<32x256xf16, #mp_b_pft>
%a4_slm’ = update_tile_offset %a4_slm, %c0, %slm_offset: tile<256x32xf16, #mp_a_pft>
%b4_slm’ = update_tile_offset %b4_slm, %slm_offset, %c0 : tile<32x256xf16, #mp_b_pft>
%c_r = tile_mma %a1_rr, %b1_rr #mp_a #mp_b #mp_c:
(vector<256x32xf16>, vector<32x256xf16>) -> vector<256x256xf32> // sg_layout=[8,8], sg_data=[32,32]
Expand Down
23 changes: 14 additions & 9 deletions include/imex/Dialect/XeTile/IR/XeTileAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def XeTile_TileAttr : XeTile_Attr<"XeTile", "tile_attr"> {
OptionalParameter<"xetile::WorkGroupMapAttr">:$wg_map,
DefaultValuedParameter<"mlir::DenseI32ArrayAttr", "mlir::DenseI32ArrayAttr::get($_ctxt, {1, 0})">:$order,
OptionalParameter<"mlir::DenseI64ArrayAttr">:$inner_blocks,
OptionalParameter<"mlir::Attribute">:$memory_space
OptionalParameter<"mlir::Attribute">:$memory_space,
OptionalParameter<"mlir::BoolAttr">:$scattered
);
let assemblyFormat = "`<` struct(params) `>`";
let genVerifyDecl = true;
Expand All @@ -73,31 +74,35 @@ def XeTile_TileAttr : XeTile_Attr<"XeTile", "tile_attr"> {
CArg<"xetile::WorkGroupMapAttr", "{}">:$wg_map,
CArg<"llvm::ArrayRef<int32_t>", "{1, 0}">:$order,
CArg<"llvm::ArrayRef<int64_t>", "{}">:$inner_blocks,
CArg<"int", "0">:$memory_space),
CArg<"int", "0">:$memory_space,
CArg<"bool", "false">:$scattered),
[{
mlir::Type intType = mlir::IntegerType::get($_ctxt, 32);
mlir::BoolAttr scatteredAttr = mlir::BoolAttr::get($_ctxt, scattered);
return $_get($_ctxt, sg_map, wg_map, mlir::DenseI32ArrayAttr::get($_ctxt, order),
mlir::DenseI64ArrayAttr::get($_ctxt, inner_blocks),
mlir::IntegerAttr::get(intType, memory_space));
mlir::IntegerAttr::get(intType, memory_space), scatteredAttr);
}]>,
AttrBuilder<(ins CArg<"llvm::ArrayRef<int32_t>", "{1, 0}">:$order,
CArg<"int", "0">:$memory_space),
CArg<"int", "0">:$memory_space, CArg<"bool", "false">:$scattered),
[{
mlir::Type intType = mlir::IntegerType::get($_ctxt, 32);
mlir::BoolAttr scatteredAttr = mlir::BoolAttr::get($_ctxt, scattered);
return $_get($_ctxt, xetile::SubGroupMapAttr(), xetile::WorkGroupMapAttr(),
mlir::DenseI32ArrayAttr::get($_ctxt, order),
mlir::DenseI64ArrayAttr::get($_ctxt, {}),
mlir::IntegerAttr::get(intType, memory_space));
mlir::IntegerAttr::get(intType, memory_space), scatteredAttr);
}]>,
AttrBuilder<(ins CArg<"xetile::SubGroupMapAttr", "{}">:$sg_map,
CArg<"xetile::WorkGroupMapAttr", "{}">:$wg_map,
CArg<"llvm::ArrayRef<int32_t>", "{1, 0}">:$order,
CArg<"int", "0">:$memory_space),
CArg<"int", "0">:$memory_space, CArg<"bool", "false">:$scattered),
[{
mlir::Type intType = mlir::IntegerType::get($_ctxt, 32);
return $_get($_ctxt, sg_map, wg_map, mlir::DenseI32ArrayAttr::get($_ctxt, order),
mlir::Type intType = mlir::IntegerType::get($_ctxt, 32);
mlir::BoolAttr scatteredAttr = mlir::BoolAttr::get($_ctxt, scattered);
return $_get($_ctxt, sg_map, wg_map, mlir::DenseI32ArrayAttr::get($_ctxt, order),
mlir::DenseI64ArrayAttr::get($_ctxt, {}),
mlir::IntegerAttr::get(intType, memory_space));
mlir::IntegerAttr::get(intType, memory_space), scatteredAttr);
}]>
];
}
Expand Down
86 changes: 66 additions & 20 deletions include/imex/Dialect/XeTile/IR/XeTileOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class XeTile_Op<string mnemonic, list<Trait> traits = []> :
Op<XeTile_Dialect, mnemonic, traits>;

def XeTile_InitTileOp : XeTile_Op<"init_tile", [Pure, AttrSizedOperandSegments,
ViewLikeOpInterface, OffsetSizeAndStrideOpInterface]> {
OffsetSizeAndStrideOpInterface]> {
let summary = "Describes an XeTile with reference to a base memref";
let description = [{
The "init_tile" operation is used to describe a 2D region (i.e. tile) in global memory.
Expand Down Expand Up @@ -109,20 +109,19 @@ def XeTile_InitTileOp : XeTile_Op<"init_tile", [Pure, AttrSizedOperandSegments,
Variadic<Index>: $offsets,
Variadic<Index>: $sizes,
Variadic<Index>: $strides,
DenseI64ArrayAttr: $const_offsets,
OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
OptionalAttr<DenseI64ArrayAttr>: $const_sizes,
OptionalAttr<DenseI64ArrayAttr>: $const_strides
);
OptionalAttr<DenseI64ArrayAttr>: $const_strides,
Optional<VectorOfRankAndType<[1,2], [Index]>>: $indices);

let results = (outs XeTile: $tile);


let assemblyFormat = [{
$source ``
custom<DynamicIndexList>($offsets, $const_offsets)
$source (`,` $indices^):(``custom<DynamicIndexList>($offsets, $const_offsets))?
(`,` custom<DynamicIndexList>($sizes, $const_sizes)^
`,` custom<DynamicIndexList>($strides, $const_strides))?
attr-dict `:` type($source) `->` qualified(type($tile))
attr-dict `:` type($source) (`,` type($indices)^)? `->` qualified(type($tile))
}];

let builders = [
Expand All @@ -135,7 +134,11 @@ def XeTile_InitTileOp : XeTile_Op<"init_tile", [Pure, AttrSizedOperandSegments,
"mlir::Value":$source,
"llvm::ArrayRef<mlir::OpFoldResult>":$offsets,
"llvm::ArrayRef<mlir::OpFoldResult>":$sizes,
"llvm::ArrayRef<mlir::OpFoldResult>":$strides)>
"llvm::ArrayRef<mlir::OpFoldResult>":$strides)>,
// creating init_tile op for scattered operation
OpBuilder<(ins "xetile::TileType":$resultType,
"mlir::TypedValue<mlir::MemRefType>": $source,
"mlir::TypedValue<mlir::VectorType>": $indices)>
];

let extraClassDeclaration = [{
Expand Down Expand Up @@ -202,11 +205,15 @@ def XeTile_InitTileOp : XeTile_Op<"init_tile", [Pure, AttrSizedOperandSegments,

/// Get static offsets.
llvm::ArrayRef<int64_t> getStaticOffsets() {
return getConstOffsets();
if (getConstOffsets().has_value())
return getConstOffsets().value();
return llvm::ArrayRef<int64_t>();
}

/// Get the static sizes.
llvm::ArrayRef<int64_t> getStaticSizes() {
if (getIndices())
return llvm::ArrayRef<int64_t>();
if (getConstSizes().has_value())
return getConstSizes().value();
// At this point, the source must be a memref with static shape.
Expand All @@ -216,6 +223,8 @@ def XeTile_InitTileOp : XeTile_Op<"init_tile", [Pure, AttrSizedOperandSegments,

/// Get the static strides.
llvm::ArrayRef<int64_t> getStaticStrides() {
if (getIndices())
return llvm::ArrayRef<int64_t>();
if (getConstStrides().has_value())
return getConstStrides().value();
// At this point, the source must be a memref with static shape.
Expand Down Expand Up @@ -250,6 +259,11 @@ def XeTile_InitTileOp : XeTile_Op<"init_tile", [Pure, AttrSizedOperandSegments,
/// Return the expected rank of each of the`static_offsets`,
/// `static_shape` and `static_strides` attributes.
std::array<unsigned, 3> getArrayAttrMaxRanks() {
// for scattered tile, the static_offsets, static_shape and
// static_strides are not used. Their ranks are expected to be 0.
if (getIndices())
return {0, 0, 0};

unsigned rank;
if (auto ty = llvm::dyn_cast<mlir::MemRefType>(getSourceType())) {
rank = ty.getRank();
Expand Down Expand Up @@ -457,7 +471,8 @@ def XeTile_TileMMAOp : XeTile_Op<"tile_mma", []> {
let hasVerifier = 1;
}

def XeTile_UpdateTileOffsetOp : XeTile_Op<"update_tile_offset", []> {
def XeTile_UpdateTileOffsetOp : XeTile_Op<"update_tile_offset", [AttrSizedOperandSegments,
AllTypesMatch<["tile", "result"]>]> {
let summary = "update the offsets of a tile";
let description = [{
"update_tile_offset" operation is used for iterating over the tiles. It takes in a
Expand All @@ -470,27 +485,22 @@ def XeTile_UpdateTileOffsetOp : XeTile_Op<"update_tile_offset", []> {

Example 1:
```mlir
xetile.update_tile_offset %tile, [%offset_x, %offset_y]
: tile<32x32xf32>, index, index
xetile.update_tile_offset %tile, [%offset_x, %offset_y] : tile<32x32xf32>
```
}];

let arguments = (ins
XeTile: $tile,
Index: $offset_x,
Index: $offset_y
);
Optional<Index>: $offset_x,
Optional<Index>: $offset_y,
Optional<FixedVectorOfRankAndType<[1], [Index]>>:$indices);

let results = (outs
XeTile: $result
);

let assemblyFormat = [{
$tile `,` ` ` `[` $offset_x `,` ` ` $offset_y `]` ` ` attr-dict `:`
qualified(type($tile)) `,`
qualified(type($offset_x)) `,`
qualified(type($offset_y)) ` `
`->` qualified(type($result))
$tile`,`` `(``$indices^):(`[` $offset_x `,` $offset_y `]` )? attr-dict `:` qualified(type($tile)) (`,` type($indices)^)?
}];
}

Expand Down Expand Up @@ -617,6 +627,42 @@ def XeTile_BroadcastOp: XeTile_Op<"broadcast", []> {
let hasVerifier = 1;
}

// Gather-style load (mnemonic `load`) that reads through a scattered tile.
// Trait constraints declared below:
//   - `tile` and `value` must have the same element type;
//   - `tile`, `value` and `mask` must have the same shape.
def XeTile_LoadGatherOp: XeTile_Op<"load", [AllElementTypesMatch<["tile", "value"]>,
AllShapesMatch<["tile", "value", "mask"]>]> {
let summary = "load a set of scattered data points from memory.";
let description = [{
The `load` operation is used to load data with scattered tile (each element in the tile
is interpreted as location of the data). the `mask` operand masks out memory access so
that it is safe to pass out-of-boundary addresses/offsets as long as they are masked.
In this case, the value specified in the padding attribute will be returned. The default
padding value is zero.}];

// Operands: the tile to read through, a mask of type XeTile_MaskType, and an
// optional `padding` attribute (see the description above for its role).
let arguments = (ins XeTile: $tile,
XeTile_MaskType: $mask,
OptionalAttr<XeTile_PaddingValueAttr>: $padding);
// Single result: the loaded data, constrained to XeTile_1DOr2DVector.
let results = (outs XeTile_1DOr2DVector: $value);
// Textual form: `$tile, $mask attr-dict : tile-type, mask-type -> value-type`.
let assemblyFormat = [{
$tile `` `,` $mask attr-dict `:` qualified(type($tile)) `` `,` type($mask) `->` type($value)
}];
}

// Scatter-style store (mnemonic `store`) that writes through a scattered tile.
// Trait constraints declared below:
//   - `value` and `tile` must have the same element type;
//   - `value`, `tile` and `mask` must have the same shape.
def XeTile_StoreScatterOp: XeTile_Op<"store", [AllElementTypesMatch<["value", "tile"]>,
AllShapesMatch<["value", "tile", "mask"]>]> {
// The summary previously read "load ...", copied from the gather op; this op stores.
let summary = "store a set of data to scattered memory locations.";
let description = [{
The `store` operation is used to store data into scattered tile (each element in the tile
is interpreted as location, one location per data element). the `mask` operand masks out
memory access so that it is safe to pass out-of-boundary addresses/offsets as long as they
are masked.
}];
// Operands: the data to store (XeTile_1DOr2DVector), the destination tile, and
// a mask of type XeTile_MaskType. The op declares no results.
let arguments = (ins XeTile_1DOr2DVector: $value,
XeTile: $tile,
XeTile_MaskType: $mask);
// Textual form: `$value, $tile, $mask attr-dict : value-type, tile-type, mask-type`.
let assemblyFormat = [{
$value `,` $tile `,` $mask attr-dict `:` type($value) `,` qualified(type($tile)) `,` type($mask)
}];
}



#endif // _XETILE_OPS_TD_INCLUDED_
18 changes: 14 additions & 4 deletions include/imex/Dialect/XeTile/IR/XeTileTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,13 @@ def XeTile : XeTile_Type<"Tile", "tile", [ShapedTypeInterface],
return 0;
}

/// Returns the `scattered` flag carried by this tile's XeTileAttr encoding.
/// When the encoding is absent or is not an XeTileAttr, a default-constructed
/// BoolAttr is returned instead.
mlir::BoolAttr getScatterAttr() {
  if (auto tileAttr = llvm::dyn_cast_if_present<xetile::XeTileAttr>(getEncoding()))
    return tileAttr.getScattered();
  return mlir::BoolAttr();
}

}];

let assemblyFormat = "`<` custom<XeTileType>($shape, $elementType, $encoding) `>`";
Expand All @@ -151,19 +158,22 @@ def XeTile_FloatType : AnyTypeOf<[F16, F32, F64, BF16, TF32]>;
// Define the scalar type for XeTile
def XeTile_ScalarType : AnyTypeOf<[XeTile_IntType, XeTile_FloatType]>;

// define a 1D or 2D memref of XeTile scalar type
def XeTile_DynamicMemref : Non0RankedMemRefOf<[XeTile_ScalarType]>;

// define the source type for XeTile init_tile
def XeTile_BaseAddrType : AnyTypeOf<[XeTile_DynamicMemref, UI64, UI32, I64, I32]>;
def XeTile_BaseAddrType : AnyTypeOf<[MemRefOf<[XeTile_ScalarType]>, UI64, UI32, I64, I32]>;

// input and output types needed for pack and unpack ops
// def XeTile_1DVector : VectorOfRankAndType<[1], [XeTile_ScalarType]>;
def XeTile_2DVector : VectorOfRankAndType<[2], [XeTile_ScalarType]>;
def XeTile_4DVector : VectorOfRankAndType<[4], [XeTile_ScalarType]>;

// define the value type for XeTile load_gather and store_scatter op
def XeTile_1DOr2DVector: VectorOfRankAndType<[1, 2], [XeTile_ScalarType]>;

// define the value type for XeTile load_tile and store_tile op
def XeTile_2DOr4DVector: VectorOfRankAndType<[2, 4], [XeTile_ScalarType]>;

def XeTile_MaskType: VectorOfRankAndType<[1, 2], [I1]>;

// define the attribute type allowed for padding values for load op
def XeTile_PaddingValueAttr : AnyAttrOf<[I32Attr, F32Attr]>;

Expand Down
Loading

0 comments on commit c4210b3

Please sign in to comment.