diff --git a/docs/rfcs/XeTile.md b/docs/rfcs/XeTile.md index ad72a440d..f05b468e1 100644 --- a/docs/rfcs/XeTile.md +++ b/docs/rfcs/XeTile.md @@ -19,18 +19,19 @@ XeTile provides a middle-level abstraction for matmul operation and sits between | Ops | Syntax | Example | | :--- | :---- | :--- | -|init_tile | operation ::= XeTile.init_tile $base_memref $offset0, $offset1: type($base_memref), index, index, attr-dict-> type($tile, attr-dict) | %block = XeTile.init_tile %base_memref, %tile_offset:2 memref<128x128xbf16> into tile<8x16xbf16> | -|load_tile | operation ::=XeTile.load_tile $tile attr-dict:type($tile) ->type($res) | %vector_a = XeTile.load_tile %tile_a {padding=0} : tile<64x32xbf16> into vector<32x64xbf16>| -|store_tile | operation ::=XeTile.store_tile $value, $tile attr-dict: type($value), type($tile) | XeTile.store_tile %tile_a, %vector_a: vector<64x64xbf16> into tile<64x64xbf16> | -|update_tile_offset | operation ::=XeTile.update_tile_offset $tile, $delta0, $delta1: type($tile), index, index-> type($tile) | %tdesc_updated = XeTile.update_nd_offset %tdesc, %offset_x, offset_y tensor_desc<32x64xbf16>, index, index -> tensor_desc<32x64xbf16> | -|prefetch_tile | operation ::=XeTile.prefetch_tile $tile, attr-dict: type($tile) | XeTile.prefetch_tile %coop_tile: tile<16x32xbf16> | -|tile_mma | operation ::=XeTile.tile_mma $matA, $matB, $matC attr_dict: type($matC), type($matA), type($matB)-> type($res) | %vector_c = XeTile.tile_mma %vector_a, %vector_b, %vector_c : vector<64x32xbf16>, vector<32x128xbf16>, vector<64x128xfloat> into vector<64x128xfloat> | -|atomic_rmw_tile| operation ::=XeTile.atomic_rmw_tile \<$kind\>, $vec, $tile: type($vec), type($tile) -> type($res) | %vector_a = atomic_rmw_tile \ %value, %tile: vector<8x16xbf16>, tile<8x16xbf16> to vector<8x16xbf16> | -|tile_transpose | operation ::=XeTile.tile_transpose $vec $permuation_dims attr_dict: type($vec) -> type($res) | %vector_a = XeTile.tile_transpose %vector_b [1, 0]: vector<64x32xfloat> into vector<32x64xfloat> | -|tile_reduce | operation ::=XeTile.tile_reduce \<$kind\> $src $reduction_dims attr_dict: type($value) -> type($res) | %vector_a = XeTile.tile_reduce \ %vector_b [1]: vector<64x32xfloat> into vector<64x1xfloat> | -|tile_broadcast | operation ::=XeTile.tile_broadcast $src $broadcast_dims attr_dict: type($value) -> type($res) | %vector_a = XeTile.tile_broadcast %vector_b[0]: vector<1x32xfloat> into vector<64x32xfloat> | -|tile_pack* | operation ::=XeTile.tile_pack $matA attr_dict: type($value) -> type($res) | %vector_a = XeTile.tile_pack %vector_b inner_blocks=array : vector<64x32xfloat> into vector<4x2x16x16xfloat> | -|tile_unpack* | operation ::=XeTile.tile_upack $matA attr_dict: type($value) -> type($res) | %vector_a = XeTile.tile_unpack %vector_b inner_blocks=array : vector<1x2x64x16xfloat> into vector<64x32xbf16> | + +|init_tile | operation ::= xetile.init_tile $base_memref $offset0, $offset1: type($base_memref), index, index, attr-dict-> type($tile, attr-dict) | %block = xetile.init_tile %base_memref, %tile_offset:2 memref<128x128xbf16> into tile<8x16xbf16> | +|load_tile | operation ::=xetile.load_tile $tile attr-dict:type($tile) ->type($res) | %vector_a = xetile.load_tile %tile_a {padding=0} : tile<64x32xbf16> into vector<32x64xbf16>| +|store_tile | operation ::=xetile.store_tile $value, $tile attr-dict: type($value), type($tile) | xetile.store_tile %tile_a, %vector_a: vector<64x64xbf16> into tile<64x64xbf16> | +|update_tile_offset | operation ::=xetile.update_tile_offset $tile, $delta0, $delta1: type($tile), index, index-> type($tile) | %tdesc_updated = xetile.update_nd_offset %tdesc, %offset_x, offset_y tensor_desc<32x64xbf16>, index, index -> tensor_desc<32x64xbf16> | +|prefetch_tile | operation ::=xetile.prefetch_tile $tile, attr-dict: type($tile) | xetile.prefetch_tile %coop_tile: tile<16x32xbf16> | +|tile_mma | operation ::=xetile.tile_mma $matA, $matB, $matC attr_dict: type($matC), type($matA), type($matB)-> type($res) | %vector_c = xetile.tile_mma %vector_a, %vector_b, %vector_c : vector<64x32xbf16>, vector<32x128xbf16>, vector<64x128xfloat> into vector<64x128xfloat> | +|atomic_rmw_tile| operation ::=xetile.atomic_rmw_tile \<$kind\>, $vec, $tile: type($vec), type($tile) -> type($res) | %vector_a = xetile.atomic_rmw_tile \ %value, %tile: vector<8x16xbf16>, tile<8x16xbf16> to vector<8x16xbf16> | +|tile_transpose | operation ::=xetile.tile_transpose $vec $permuation_dims attr_dict: type($vec) -> type($res) | %vector_a = xetile.tile_transpose %vector_b [1, 0]: vector<64x32xfloat> into vector<32x64xfloat> | +|tile_reduce | operation ::=xetile.tile_reduce \<$kind\> $src $reduction_dims attr_dict: type($value) -> type($res) | %vector_a = xetile.tile_reduce \ %vector_b [1]: vector<64x32xfloat> into vector<64x1xfloat> | +|tile_broadcast | operation ::=xetile.tile_broadcast $src $broadcast_dims attr_dict: type($value) -> type($res) | %vector_a = xetile.tile_broadcast %vector_b[0]: vector<1x32xfloat> into vector<64x32xfloat> | +|tile_pack* | operation ::=xetile.tile_pack $matA attr_dict: type($value) -> type($res) | %vector_a = xetile.tile_pack %vector_b {inner_blocks=array} : vector<64x32xfloat> into vector<4x2x16x16xfloat> | +|tile_unpack* | operation ::=xetile.tile_upack $matA attr_dict: type($value) -> type($res) | %vector_a = xetile.tile_unpack %vector_b {inner_blocks=array} : vector<1x2x64x16xfloat> into vector<64x32xbf16> | *Operations only used to support internal lowering. @@ -42,17 +43,17 @@ To create a 2D Tile memory descriptor, the user needs to set up a tile (init_til `init_tile` with memref of static shape. Tile uses memref’s shape and strides as base_shape and base_strides. ```mlir - %tile0 = XeTile.init_tile %base_memref, [%tile_offset:2] : + %tile0 = xetile.init_tile %base_memref, [%tile_offset:2] : memref<128x128xbf16> into tile<8x16xbf16> ``` `init_tile` with memref of dynamic shape. The memref has a dynamic shape, so that its shape and strides have to be passed as runtime parameters to init_tile. ```mlir - %tile0 = XeTile.init_tile %base_memref, [%tile_offset:2], [%base_shape:2], [%base_strides:2]: + %tile0 = xetile.init_tile %base_memref, [%tile_offset:2], [%base_shape:2], [%base_strides:2]: memref into tile<8x16xbf16> ``` `init_tile` with an address for the base matrix. This form is to support the use case which doesn’t use a memref to describe the base matrix. ```mlir - %tile0 = XeTile.init_tile %base_addr, [%tile_offset:2], [%base_shape:2], [%base_strides:2]: + %tile0 = xetile.init_tile %base_addr, [%tile_offset:2], [%base_shape:2], [%base_strides:2]: i64 into tile<8x16xbf16> ``` @@ -60,22 +61,21 @@ To create a 2D Tile memory descriptor, the user needs to set up a tile (init_til ```mlir #tile_attr = #xetile.tile_attr - %tile0 = XeTile.init_tile %base_memref, [%tile_offset:2]: + %tile0 = xetile.init_tile %base_memref, [%tile_offset:2]: memref<128x128xbf16, affine_map=<(d0, d1)->(d1, d0)> into tile<64x32xbf16, #tile_attr> ``` - With the tile date type, XeTile supports load_tile, prefetch_tile, and store_tile. `load_tile` loads a tile to a 2D vector, which could be backed by a register region. ```mlir - %vector_a = XeTile.load_tile %tile_a : + %vector_a = xetile.load_tile %tile_a : tile<64x64xbf16> into vector<64x64xb16> ``` Attribute `padding` specifies the padding value for the out-of-boundary access. The default value is zero. ```mlir - %vector_a = XeTile.load_tile %tile_a {padding = 1.0} : + %vector_a = xetile.load_tile %tile_a {padding = 1.0} : tile<64x64xbf16> into vector<64x64xb16> ``` `load_tile` needs to be used with the tile_mma. @@ -83,41 +83,41 @@ Attribute `padding` specifies the padding value for the out-of-boundary access. `load_tile` loads a tile according to the tile's `order` attribute. Regardless of the `order` attribute value, the vector's dimensions must match exactly the Tile's dimensions. ```mlir #tile_attr = #xetile.tile_attr - %vector_a = XeTile.load_tile %tile_a : + %vector_a = xetile.load_tile %tile_a : tile<64x32xbf16, #tile_attr> into vector<64x32xb16> ``` `store_tile` stores a vector to memory. Padding attributes are not supported. ```mlir - XeTile.store_tile %tile_a, %vector_a : + xetile.store_tile %tile_a, %vector_a : vector<64x64xbf16> into tile<64x64xbf16> ``` `store_tile` stores a tile according to the tile's `order` attribute. Regardless of the `order` attribute value, the vector's dimensions must match exactly the Tile's dimensions. ```mlir #tile_attr = #xetile.tile_attr - %vector_a = XeTile.store_tile %tile_a : + %vector_a = xetile.store_tile %tile_a : vector<64x32xb16> to tile<64x32xbf16, #tile_attr> ``` `prefetch_tile` prefetches the tile to cache. Just like memref.preftech, the locality hint ranges from locality<0> (no locality) to locality<3> (extremely local keep in cache). ```mlir - XeTile.prefetch_tile %coop_tile locality<3>: tile<8x32xbf16> + xetile.prefetch_tile %coop_tile locality<3>: tile<8x32xbf16> ``` `tile_mma` represents the matrix multiplication on 2D vectors. The semantics can be represented by vector.contract, so tile_mma works more like a syntax sugar. This also means that the code can be lowered to vector.contract and mapped to HW without DPAS support nicely. ```mlir - %vector_c = XeTile.tile_mma %vector_b, %vector_a, %vector_c: + %vector_c = xetile.tile_mma %vector_b, %vector_a, %vector_c: vector<64x32xbf16>, vector<32x128xbf16>, vector<64x128xfloat> into vector<64x128xfloat> ``` A `tile_mma` variant without vector_c initialization. ```mlir - %vector_c = XeTile.tile_mma %vector_a, %vector_b : + %vector_c = xetile.tile_mma %vector_a, %vector_b : vector<64x32xbf16>, vector<32x128xbf16> into vector<64x128xfloat> ``` `update_tile_offset` updates tile with offset_x and offset_y, to move the current tile to a new position. These offsets are relative offset to the current position and counted in the number of elements. Usually only one value is needed to update since the tile is only moving along the K dimension. Users should avoid initializing new tiles repeatedly. For best performance, the user should only initialize one tile as a base tile and update the tile offset to move to a new tile. ```mlir - %tile_updated = XeTile.update_tile_offset %tile, %offset_x, offset_y : + %tile_updated = xetile.update_tile_offset %tile, %offset_x, offset_y : tile<64x64xbf16>, index, index into tile <64x64xbf16> ``` @@ -125,23 +125,23 @@ A `tile_mma` variant without vector_c initialization. `atomic_rmw_tile` atomically reads, modifies, and writes back data to the memory specified by the tile. ```mlir - %ret_value = XeTile.atomic_rmw %value, %tile: + %ret_value = xetile.atomic_rmw %value, %tile: vector<8x16xbf16>, tile<8x16xbf16> to vector<8x16xbf16> ``` -XeTile.atomic_rmw reuses the arith dialect attribute, mlir::arith::AtomicRMWKindAttr. +xetile.atomic_rmw reuses the arith dialect attribute, mlir::arith::AtomicRMWKindAttr. `tile_transpose` transpose a 2D vector. It has the same semantics as the vector.transpose, but restricts the vector dimension to 2D. ```mlir - %vector_a = XeTile.tile_transpose [1, 0] %vector_b: vector<64x32xfloat> into vector<32x64xfloat> + %vector_a = xetile.tile_transpose [1, 0] %vector_b: vector<64x32xfloat> into vector<32x64xfloat> ``` `tile_reduce` performs a reduction operation over a 2D vector. The result is a 2D vector with the size of reduced axis being 1. It has the same semantics as the vector.multi_dimesnion, but restricts the vector dimension to 2D. The reduce operation are the same as vector.multi_dimension:add/mul/minsi/minui/maxsi/maxui /and/or/xor for integers, and add/mul/minnumf/maxnumf/minimumf /maximumf for floats. ```mlir - %vector_a = XeTile.tile_reduce %vector_b [1]: vector<64x32xfloat> into vector<64x1xfloat> + %vector_a = xetile.tile_reduce %vector_b [1]: vector<64x32xfloat> into vector<64x1xfloat> ``` `tile_broadcast` broadcast from 1D vector to a 2D vector. ```mlir - %vector_a = XeTile.tile_broadcast %vector_b [0]: vector<1x32xfloat> into vector<64x32xfloat> + %vector_a = xetile.tile_broadcast %vector_b [0]: vector<1x32xfloat> into vector<64x32xfloat> ``` ## Internal Operations to support gradual lowering @@ -151,20 +151,20 @@ The 2D XeTile IR needs to be lowered in an intermediate form to support `blockin ```mlir #tile_attr = #xetile.tile_attr - %tile0 = XeTile.init_tile %base_memref, [%tile_offset:2]: + %tile0 = xetile.init_tile %base_memref, [%tile_offset:2]: memref<128x128xbf16> into tile<64x32xbf16, #tile_attr> ``` `load_tile` loads a 2D tile with an `inner_block` attribute to 4D vector. ```mlir #tile_attr = #xetile.tile_attr - %vector_a = XeTile.load_tile %tile_a : + %vector_a = xetile.load_tile %tile_a : tile<64x32xbf16, #tile_attr> into vector<4x2x16x16xb16> ``` `store_tile` stores a 4D vector to a 2D tile with an `inner_block`. ```mlir #tile_attr = #xetile.tile_attr - XeTile.store_tile %vector_a, %tile_a : + xetile.store_tile %vector_a, %tile_a : vector<4x2x16x16xb16> into tile<64x32xbf16, #tile_attr> ``` `atomic_rmw_tile` performs atomic operation on 4D vectors. @@ -176,18 +176,18 @@ The 2D XeTile IR needs to be lowered in an intermediate form to support `blockin With the data being presented as 4D vector, all the vector based XeTile operations are required to support blocking. `tile_mma` works on 4D vectors. Since dimension 1 is split into dimensions 1 and 3, the reduction of matrix multiplication is along these two dimensions. ```mlir - %vector_c = XeTile.tile_mma %vector_a, %vector_b, %vector_c : + %vector_c = xetile.tile_mma %vector_a, %vector_b, %vector_c : vector<8x4x8x8xbf16>, vector<4x8x8x16xbf16>, vector<8x8x8x16xfloat> into vector<8x8x8x16xfloat> ``` `tile_reduce` follows the vector.multi-reduction semantics and can be applied to 4D vector. The tile_reduce on 4D vector is an internal operation and only used in the transformation passes to support gradual lowering. ```mlir - %vector_a = XeTile.tile_reduce %vector_b [1, 3]: vector<8x4x8x16xfloat> into vector<8x1x8x1float> + %vector_a = xetile.tile_reduce %vector_b [1, 3]: vector<8x4x8x16xfloat> into vector<8x1x8x1float> ``` `tile_broadcast` broadcast 4D vector. The input is expected to be first reshaped from 1D vector to 2D vector, and then blocked to 4D. ```mlir - %vector_a = XeTile.tile_broadcast %vector_b [1, 3]: vector<8x1x8x1xfloat> into vector<8x4x8x16xfloat> + %vector_a = xetile.tile_broadcast %vector_b [1, 3]: vector<8x1x8x1xfloat> into vector<8x4x8x16xfloat> ``` `tile_transpose` doesn't have support 4D vector. The transpose is usually implemented by saving and restoring from the share local memory. To support this, we relax the restriction of tile_load and tile_store so that they can load 2D from share local memory. @@ -197,17 +197,33 @@ With the data being presented as 4D vector, all the vector based XeTile operatio `tile_pack` packs a 2D vector, representing the loaded value from 2D tile, to a 4D vector with an inner block size. The 4D vector was introduced to support blocking to fit the hardware matrix operation sizes.  The blocking follows an implicit rule: out_dim[0] = in_dim[0]/inner_blocks[0] , out_dim[1] = in_dim[1]/inner_blocks[1], out_dim[2] = inner_blocks[0], and out_dim[3] = inner_blocks[1]. The dim[2] and dim[3] of result 4D vector must be same as the size of `inner_blocks` attribute. ```mlir - %0 = XeTile.tile_pack %1 {inner_blocks = array} + %0 = xetile.tile_pack %1 {inner_blocks = array} : vector<64x32xf32> -> vector<4x2x16x16xf32> ``` `tile_unpack` unpacks a 4D blocked vector back to original unpacked 2D vector. `tile_unpack` ```mlir - %0 = XeTile.tile_unpack %1 {inner_blocks = array} + %0 = xetile.tile_unpack %1 {inner_blocks = array} : vector<1x2x64x16xf32> -> vector<64x32xf32> ``` The tile_pack and tile_unpack operation is similar to pack and unpack operation of tensor dialect. The source vector must be a 2D dimension vector, and no permutation is allowed for the result 4D vector, so effectively the blocking effect is identical to tensor pack/unpack operation with inner_dims_pos = [0,1] inner_dims_pos = [0, 1]. +## support for load_gather and store_scatter (experimental) +`init_tile` can create a tile with each element's address being explictly specified. The tile is created with a base address and offsets for all elements to be loaded. The result tile has a `scatter` attribute to distinguish it from the regular tile. +```mlir + %tile0 = xetile.init_tile %base_addr, %tile_offsets: + i64, vector<1x256xindex> into tile<1x256xbf16, #scatter> +``` +`load_gather` (aka. load) loads data with prepared tile and mask. Attribute `padding` specifies the padding value for the out-of-boundary access. The default value is zero. +```mlir + %vector_a = xetile.load_gather %tile_0, %mask, {padding = 1.0} : + tile<1x256xbf16, #scatter> into vector<1x256xbf16> +``` +`store_scatter` stores a 2d vector to a 2D tile with `scatter` attribute. +```mlir + xetile.store_scatter %vector_a, %mask, %tile_0 : + vector<1x256xbf16> into tile<1x256xbf16, #scatter> +``` ## Workgroup Level XeTile extension (experimental) `xetile.wg_map` mapping attribute allows XeTile operation to work at the workgroup level. XeTile operations work by default at the subgroup level without wg_map attribute. With wg_map attributes, XeTile operations can be applied to workgroup-level tile sizes. The attribute `xetile.wg_map` guides the lowering from the workgroup level to the subgroup level by specifying how the data is distributed across parallel subgroups. It gives the user full control over the lowering process so that the user can tune the block size for both the workgroup and subgroup for optimal performance. @@ -215,7 +231,7 @@ The tile_pack and tile_unpack operation is similar to pack and unpack operation Below is an example. ```mlir #wg_map_a = #xetile.wg_map - #tile_attr = #xetile.tile_attr + #tile_attr = #xetile.tile_attr > %wg_tile = xetile.init_tile %A[%m, %c0] : memref<1024x1024xf16> -> !xetile.tile<128x128xf16, #tile_attr> ``` @@ -237,56 +253,57 @@ With the `xetile.wg_map` attribute being included in the tile data type, the til The proposal is to attach the `xetile.wg_map` attribute to the vector based XeTile operations as illustrated below. The attribute applies only to the output value of each operation. The input values `xetile.wg_map` are determined by their respective defining operations. | Ops | Syntax | Example | | :--- | :---- | :--- | -|tile_mma | operation ::=XeTile.tile_mma $matA, $matB, $matC attr_dict: type($matA), type($matB), type($matC)-> type($res) | %vector_c = XeTile.tile_mma %vector_a, %vector_b, %vector_c {#mp_c} : vector<64x32xbf16>, vector<32x128xbf16>, vector<64x128xfloat> into vector<64x128xfloat> | -|tile_transpose | operation ::=XeTile.tile_transpose $permuation_dims attr_dict $vec : type($vec) -> type($res) | %vector_a = XeTile.tile_transpose %vector_b {#mp_a}: vector<64x32xfloat> into vector<32x64xfloat> | -|tile_reduce | operation ::=XeTile.tile_reduce $kind $src $reduction_dims attr_dict: type($value) -> type($res) | %vector_a = XeTile.tile_reduce %vector_b [1] {#mp_a}: vector<64x32xfloat> into vector<64x1xfloat> | -|tile_broadcast | operation ::=XeTile.tile_broadcast $src $broadcast_dims attr_dict : type($value) -> type($res) | %vector_a = XeTile.tile_broadcast %vector_b [0] {#mp_a}: vector<1x32xfloat> into vector<64x32xfloat> | -|tile_conv_layout | operation ::=XeTile.conv_layout $src attr_dict: type($value) -> type($res) | %vector_a = XeTile.tile_conv_layout %vector_b {#mp_a} : vector<256x256xfloat> into vector<256x256xfloat> | + +|tile_mma | operation ::= xetile.tile_mma $matA, $matB, $matC attr_dict: type($matA), type($matB), type($matC)-> type($res) | %vector_c = xetile.tile_mma %vector_a, %vector_b, %vector_c {#mp_c} : vector<64x32xbf16>, vector<32x128xbf16>, vector<64x128xfloat> into vector<64x128xfloat> | +|tile_transpose | operation ::= xetile.tile_transpose $permuation_dims attr_dict $vec : type($vec) -> type($res) | %vector_a = xetile.tile_transpose %vector_b {#mp_a}: vector<64x32xfloat> into vector<32x64xfloat> | +|tile_reduce | operation ::= xetile.tile_reduce $kind $src $reduction_dims attr_dict: type($value) -> type($res) | %vector_a = xetile.tile_reduce %vector_b [1] {#mp_a}: vector<64x32xfloat> into vector<64x1xfloat> | +|tile_broadcast | operation ::= xetile.tile_broadcast $src $broadcast_dims attr_dict : type($value) -> type($res) | %vector_a = xetile.tile_broadcast %vector_b [0] {#mp_a}: vector<1x32xfloat> into vector<64x32xfloat> | +|tile_conv_layout | operation ::= xetile.conv_layout $src attr_dict: type($value) -> type($res) | %vector_a = xetile.tile_conv_layout %vector_b {#mp_a} : vector<256x256xfloat> into vector<256x256xfloat> | With the `wg_map` attribute attached for the output vector, `tile_mma` does a matrix multiplication at a work group level vector. ```mlir #wg_map_d = #xetile.wg_map - %vector_d = XeTile.tile_mma %vector_a, %vector_b, %vector_c {#wg_map_d}: + %vector_d = xetile.tile_mma %vector_a, %vector_b, %vector_c {#wg_map_d}: vector<256x256xfloat>, vector<256x32xbf16>, vector<32x256xbf16> into vector<256x256xfloat> ``` The `wg_map` attribute of input vector operands can be derived from the wg_map_d. They must have the same sg_layout, and sg_data for m and n dimenion must be same as wg_map_d, and sg_data for k dimension must be same as operand A and B. These attributes may be retrieved from their producer ops, and the retrieved attributes must be consistent with the derived ones. Below is the derived wg_map for the three vector operands in the example above. ```mlir - #wg_map_a = #xetile.wg_map - #wg_map_b = #xetile.wg_map - #wg_map_c = #xetile.wg_map + #wg_map_a = #xetile.wg_map //wg_map for %vector_a + #wg_map_b = #xetile.wg_map //wg_map for %vector_b + #wg_map_c = #xetile.wg_map //wg_map for %vector_c ``` `tile_reduce` with `wg_map` does the reduction over a workgroup level vector. ```mlir #wg_map_a = #xetile.wg_map - %vector_a = XeTile.tile_reduce %vector_b [1] {#wg_map_a}: vector<256x128xfloat> into vector<256x1xfloat> + %vector_a = xetile.tile_reduce %vector_b [1] {#wg_map_a}: vector<256x128xfloat> into vector<256x1xfloat> ``` The `wg_map` attribute of the input vector can be derived from the wg_map_a. sg_layout must be same, sg_data for the dimension being reduced must be same as the input vector, and the other dimension must be same as the wg_map_a. The input vector's wg_map attribute may be retrieved from its producer op, and the retrieved attribute must be consistent with the derived one. Below is the derived wg_map for the input vector in the example above. ```mlir - #wg_map_b = #xetile.wg_map + #wg_map_b = #xetile.wg_map //wg_map for %vector_b ``` `tile_broadcast` with `wg_map` attribute broadcast at workgroup level. ```mlir #wg_map_a = #xetile.wg_map - %vector_a = XeTile.tile_broadcast %vector_b [1] {#wg_map_a}: vector<256x1xfloat> into vector<256x256xfloat> + %vector_a = xetile.tile_broadcast %vector_b [1] {#wg_map_a}: vector<256x1xfloat> into vector<256x256xfloat> ``` The `wg_map` attribute of the input vector can be derived from the wg_map_a. sg_layout must be same, sg_data for the dimension being broadcast must be "1", and the other dimension must be same as the wg_map_a. The input vector's wg_map attribute may be retrieved from its producer op, and the retrieved attribute must be consistent with the derived one. Below is the derived wg_map for the input vector in the example above. ```mlir - #wg_map_b = #xetile.wg_map + #wg_map_b = #xetile.wg_map //wg_map for %vector_b ``` `tile_transpose` with `wg_map` attribute transpose a workgroup level vector. ```mlir #wg_map_a = #xetile.wg_map - %vector_a = XeTile.tile_transpose %vector_b {#wg_map_a}: vector<512x128xfloat> into vector<128x512xfloat> + %vector_a = xetile.tile_transpose %vector_b {#wg_map_a}: vector<512x128xfloat> into vector<128x512xfloat> ``` The `wg_map` attribute of the input vector can be derived from the wg_map_a. The two dimension of sg_layout and sg_data must be swapped. The input vector's wg_map attribute may be retrieved from its producer op, and the retrieved attribute must be consistent with the derived one. Below is the derived wg_map for the input vector in the example above. ```mlir - #wg_map_b = #xetile.wg_map + #wg_map_b = #xetile.wg_map //wg_map for %vector_b ``` The tile_transpose can be implemented by saving and restoring from the shared local memory. It can be conceptually viewed as a composition of two operations: 1) store the vector to to shared memory with the #wg_map_b mapping assuming row_major and 2) use wg_map_a mapping to load the data from shared memory to vector assuming column_major. To support this, we relax the restriction of tile_load and tile_store so that they can load 2D from share local memory. @@ -298,24 +315,24 @@ Example with the wg_map specified for both input and output operands. ```mlir #wg_map_b = #xetile.wg_map // used for cooperative load/prefetch #wg_map_a = #xetile.wg_map // used as mma's input matrix A - %vector_a = XeTile.tile_conv_layout %vector_b {#wg_map_a #wg_map_b}: vector<256x256xfloat> into vector<256x256xfloat> + %vector_a = xetile.tile_conv_layout %vector_b {#wg_map_a #wg_map_b}: vector<256x256xfloat> into vector<256x256xfloat> ``` Example without the wg_map specified for the input operand. ```mlir #wg_map_a = #xetile.wg_map // used as mma's input matrix A - %vector_a = XeTile.tile_conv_layout %vector_b {#wg_map_a}: vector<256x256xfloat> into vector<256x256xfloat> + %vector_a = xetile.tile_conv_layout %vector_b {#wg_map_a}: vector<256x256xfloat> into vector<256x256xfloat> ``` The tile_conv_layout could be implemented by saving and restoring from the shared local memory. It can be conceptually viewed as a composition of two operations: 1) store the vector to to shared memory with the #wg_map_b mapping assuming row_major and 2) use wg_map_a mapping to load the data from shared memory to vector assuming same row_major. To support this, we relax the restriction of tile_load and tile_store so that they can load 2D from share local memory. ## Alternative design considerations -The alternative design of tile data type is to reuse the memref data type. The memref data type needs to be enhanced to allow attributes. So the XeTile's tile data type can be expressed with memref associated with Tile attributes. XeTile.wg_map and XeTile.sg_map are examples of these attributes. +The alternative design of tile data type is to reuse the memref data type. The memref data type needs to be enhanced to allow attributes. So the XeTile's tile data type can be expressed with memref associated with Tile attributes. xetile.wg_map and xetile.sg_map are examples of these attributes. -## Appendix 1 - use case for XeTile.order attribute and tile_transpose +## Appendix 1 - use case for xetile.order attribute and tile_transpose -XeTile.tile describes a 2D block in memory . The default layout of XeTile.tile is raw-major contiguous. So tile[i][j] refers to the position i*stride_i + j in the associated memory. The stride_j must be 1 since it is contiguous. This maps well the underlying 2d block loader, which loads data in raw-major layout only and no stride in innermost dimension. -Below is the example code for the most common use case of XeTile.tile. +xetile.tile describes a 2D block in memory . The default layout of xetile.tile is raw-major contiguous. So tile[i][j] refers to the position i*stride_i + j in the associated memory. The stride_j must be 1 since it is contiguous. This maps well the underlying 2d block loader, which loads data in raw-major layout only and no stride in innermost dimension. +Below is the example code for the most common use case of xetile.tile. ```mlir BF16 A[M][K], B[K][N], C[M][N]; // C = MM(A, B) For i = 0, M-1, M_tile Do @@ -328,9 +345,9 @@ Below is the example code for the most common use case of XeTile.tile. %vb = load_tile %b : vector<32x64x bf16>; %vc = tile_mma %va, %vb : vector<64x32xbf16>, vector<32x64x bf16> into vector<64x64xbf16>; ``` -The order attribute was introduced to support a second use case where the user has a row-major matrix, but likes to view it as col major. One example is the Triton flash attention code using the order attribute introduced by Triton block pointer programming (such programming mixes the row-major and column-major). With the col major view, the user can swap the i, j in the program. To support such a programming style, we introduced the order attribute to XeTile.tile data type. It provides an abstraction on top of row-major only XeGPU ops. +The order attribute was introduced to support a second use case where the user has a row-major matrix, but likes to view it as col major. One example is the Triton flash attention code using the order attribute introduced by Triton block pointer programming (such programming mixes the row-major and column-major). With the col major view, the user can swap the i, j in the program. To support such a programming style, we introduced the order attribute to xetile.tile data type. It provides an abstraction on top of row-major only XeGPU ops. -This is a use case for the order attribute of XeTile.tile. In this use case, the matrix B has a transposed memory layout to start with, for example BT [N,K] instead of B[K, N]. But the user likes to use the original program to index the matrix as if it is B[K, N], the order attribute is introduced to support this programming. User can flip the 2d block offset and size, and swap the stride from [K, 1] to [1, K]. +This is a use case for the order attribute of xetile.tile. In this use case, the matrix B has a transposed memory layout to start with, for example BT [N,K] instead of B[K, N]. But the user likes to use the original program to index the matrix as if it is B[K, N], the order attribute is introduced to support this programming. User can flip the 2d block offset and size, and swap the stride from [K, 1] to [1, K]. ```mlir BF16 A[M][K], BT[N, K], C[M][N]; // C = MM(A, BT) For i = 0, M-1, M_tile Do @@ -396,9 +413,9 @@ For i = 0, M-1, M_tile Do %vb = vector.transpose%bt : vector<64x32xbf16> to vector<32x64xbf16>; %vc = vector.contract %va, %vb : vector<64x32xbf16>, vector<32x64x bf16> into vector<64x64xbf16>; ``` -The vector/memref dialect code example can be lowered to XeTile using simple one-to-one mapping: subview maps to init_tile, transfer_read to load_tile, and contract to tile_mma. To lower the subview op to init_tile, the lowering first identifies what "layout" the input memref has, then decide whether to use the order attribute for the tile created by init_tile. The tile should have a consistent layout view with the given memref. Since Memref stride and affine_map is very generic, we limit the XeTile.tile to only support memref with the plain view (row-major) or the transposed view (col-major). +The vector/memref dialect code example can be lowered to XeTile using simple one-to-one mapping: subview maps to init_tile, transfer_read to load_tile, and contract to tile_mma. To lower the subview op to init_tile, the lowering first identifies what "layout" the input memref has, then decide whether to use the order attribute for the tile created by init_tile. The tile should have a consistent layout view with the given memref. Since Memref stride and affine_map is very generic, we limit the xetile.tile to only support memref with the plain view (row-major) or the transposed view (col-major). -The XeTile.tile order attribute needs to be consistent as the base memref’s memory layout. +The xetile.tile order attribute needs to be consistent as the base memref’s memory layout. Correct lowering - ```mlir init_tile: %0[0, 0]: memref<1024x1024xf16> -> tile<64x32xf16, order=[1, 0]>