Commit e5dfe6b by Jianhui-Li, Oct 7, 2024: Simplify wg-level XeTile op to take wg_map attribute for the output vector operand only (#872). 1 changed file: docs/rfcs/XeTile.md (55 additions, 36 deletions).

Below is an example.
%wg_tile = xetile.init_tile %A[%m, %c0] : memref<1024x1024xf16> -> !xetile.tile<128x128xf16, #tile_attr>
```
Within the `xetile.wg_map`, `sg_layout` specifies the subgroup layout, and `sg_data` specifies the tile size owned by each subgroup. The tile created by init_tile is a workgroup-level tile. In the example above, sg_layout [2,2] means that each workgroup has 4 subgroups arranged in 2 rows and 2 columns. sg_layout is always mapped to the linear subgroup id in row-major order. sg_data [32,128] means that each subgroup works on a [32, 128] submatrix. The data elements assigned to each subgroup thread must be contiguous.

For each dimension, the size of `sg_layout` multiplied by `sg_data` must be divisible by the wg_tile size, or vice versa. The wg_tile is distributed to sg_data x sg_layout in a round-robin fashion. If sg_data[i] x sg_layout[i] < wg_tile[i], there is data left after all subgroups are assigned for the first round; in this case, we continue to assign the remaining data starting from the first subgroup until the data is completely assigned. If sg_data[i] x sg_layout[i] >= wg_tile[i], we may have used up all the data before all subgroups are assigned; in this case, we wrap around the wg_tile and continue the assignment, and the remaining subgroups along that dimension share the same data.

For example, for the tile size [128, 128] and sg_data [32, 128]: along the second dimension there is no data left to assign after the first subgroup, so the assignment wraps around to the beginning of the tile and continues. Along the first dimension there is data left after the first round of distribution, so the assignment moves to the next subtile and continues. As a result, the tile is sliced into four subtiles of size [32,128], with the following mapping for sg_layout [2,2]:

| subtiles | 2D subgroup id | Linearized subgroup id |
| :--- | :---- | :---- |
| [ 0:31, 0:127] | [0, 0] , [0, 1] | 0 , 1 |
| [ 32:63, 0:127] | [1, 0] , [1, 1] | 2 , 3 |
| [ 64:95, 0:127] | [0, 0] , [0, 1] | 0 , 1 |
| [96:127, 0:127] | [1, 0] , [1, 1] | 2 , 3 |
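The round-robin distribution rule above can be sketched as a small helper. This is an illustration only, not part of XeTile; `subtile_origins` is a hypothetical name, and the sketch assumes row-major subgroup ids and the divisibility constraint stated above.

```python
# Sketch of the round-robin wg_tile distribution rule, assuming row-major
# linearization of subgroup ids and that, for each dimension,
# sg_layout * sg_data divides the wg_tile size or vice versa.

def subtile_origins(wg_tile, sg_layout, sg_data, sg_id):
    """Return the origins of the subtiles owned by one subgroup."""
    row, col = divmod(sg_id, sg_layout[1])          # row-major subgroup id
    stride = [sg_layout[d] * sg_data[d] for d in (0, 1)]
    rounds = [max(1, wg_tile[d] // stride[d]) for d in (0, 1)]
    origins = []
    for i in range(rounds[0]):                      # extra rounds when data is left
        for j in range(rounds[1]):
            r = (row * sg_data[0] + i * stride[0]) % wg_tile[0]  # wrap around
            c = (col * sg_data[1] + j * stride[1]) % wg_tile[1]
            origins.append((r, c))
    return origins

# Subgroup [0, 0] (linear id 0) owns the subtiles starting at rows 0 and 64.
print(subtile_origins([128, 128], [2, 2], [32, 128], 0))  # [(0, 0), (64, 0)]
```

Running this for subgroup ids 0..3 reproduces the table above, including the column-wise sharing caused by sg_data[1] x sg_layout[1] exceeding wg_tile[1].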

With the `xetile.wg_map` attribute included in the tile data type, the tile memory related operations (xxx_tile) can be distributed to subgroups. The vector based operations (tile_xxx) require extra handling, since we can't attach the `xetile.wg_map` attribute to the MLIR vector data type.

The proposal is to attach the `xetile.wg_map` attribute to the vector based XeTile operations as illustrated below. The attribute applies only to the output value of each operation. The input values' `xetile.wg_map` attributes are determined by their respective defining operations.
| Ops | Syntax | Example |
| :--- | :---- | :--- |
|tile_mma | operation ::=XeTile.tile_mma $matA, $matB, $matC attr_dict: type($matA), type($matB), type($matC)-> type($res) | %vector_c = XeTile.tile_mma %vector_a, %vector_b, %vector_c {#mp_c} : vector<64x32xbf16>, vector<32x128xbf16>, vector<64x128xfloat> into vector<64x128xfloat> |
|tile_transpose | operation ::=XeTile.tile_transpose $permutation_dims attr_dict $vec : type($vec) -> type($res) | %vector_a = XeTile.tile_transpose %vector_b {#mp_a}: vector<64x32xfloat> into vector<32x64xfloat> |
|tile_reduce | operation ::=XeTile.tile_reduce $kind $src $reduction_dims attr_dict: type($value) -> type($res) | %vector_a = XeTile.tile_reduce <add> %vector_b [1] {#mp_a}: vector<64x32xfloat> into vector<64x1xfloat> |
|tile_broadcast | operation ::=XeTile.tile_broadcast $src $broadcast_dims attr_dict : type($value) -> type($res) | %vector_a = XeTile.tile_broadcast %vector_b [0] {#mp_a}: vector<1x32xfloat> into vector<64x32xfloat> |
|tile_conv_layout | operation ::=XeTile.tile_conv_layout $src attr_dict: type($value) -> type($res) | %vector_a = XeTile.tile_conv_layout %vector_b {#mp_a} : vector<256x256xfloat> into vector<256x256xfloat> |

With the `wg_map` attribute attached to the output vector, `tile_mma` does a matrix multiplication on a workgroup-level vector.
```mlir
#wg_map_d = #xetile.wg_map<sg_layout = [8, 4], sg_data = [32, 64]>
%vector_d = XeTile.tile_mma %vector_a, %vector_b, %vector_c {#wg_map_d}:
vector<256x32xbf16>, vector<32x256xbf16>, vector<256x256xfloat>
into vector<256x256xfloat>
```
The `wg_map` attributes of the input vector operands can be derived from wg_map_d: they must have the same sg_layout; sg_data for the m and n dimensions must match wg_map_d; and sg_data for the k dimension must be the same for operands A and B (the full k size in this example). These attributes may be retrieved from their producer ops, and the retrieved attributes must be consistent with the derived ones. Below are the derived wg_map attributes for the three vector operands in the example above.
```mlir
#wg_map_a = #xetile.wg_map<sg_layout = [8, 4], sg_data = [32, 32]>
#wg_map_b = #xetile.wg_map<sg_layout = [8, 4], sg_data = [32, 64]>
#wg_map_c = #xetile.wg_map<sg_layout = [8, 4], sg_data = [32, 64]>
```
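The derivation rule can be sketched as a hypothetical helper (not part of XeTile), modeling a wg_map as a (sg_layout, sg_data) pair and assuming the k dimension's sg_data is the full K size, as in the example above:

```python
# Sketch: derive the input operands' wg_map from tile_mma's output map.
# Assumes each subgroup holds the full K extent of A and B.

def derive_mma_input_maps(wg_map_d, K):
    """Return (wg_map_a, wg_map_b) given the output map and the k size."""
    layout, (m, n) = wg_map_d
    wg_map_a = (layout, (m, K))   # A: same m tile as the output, full k
    wg_map_b = (layout, (K, n))   # B: full k, same n tile as the output
    return wg_map_a, wg_map_b

# Reproduces #wg_map_a and #wg_map_b from the example (K = 32).
a_map, b_map = derive_mma_input_maps(((8, 4), (32, 64)), 32)
```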

`tile_reduce` with `wg_map` does the reduction over a workgroup-level vector.
```mlir
#wg_map_a = #xetile.wg_map<sg_layout = [32, 1], sg_data = [8, 1]>
%vector_a = XeTile.tile_reduce <add> %vector_b [1] {#wg_map_a}: vector<256x128xfloat> into vector<256x1xfloat>
```
The `wg_map` attribute of the input vector can be derived from wg_map_a: sg_layout must be the same; sg_data for the dimension being reduced must match the input vector's full size along that dimension; and the other dimension must match wg_map_a. The input vector's wg_map attribute may be retrieved from its producer op, and the retrieved attribute must be consistent with the derived one. Below is the derived wg_map for the input vector in the example above.
```mlir
#wg_map_b = #xetile.wg_map<sg_layout = [32, 1], sg_data = [8, 128]>
```
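This rule can be sketched as a small hypothetical helper (not part of XeTile), again modeling a wg_map as a (sg_layout, sg_data) pair:

```python
# Sketch: derive the input vector's wg_map from tile_reduce's output map.

def derive_reduce_input_map(wg_map_out, src_shape, reduction_dim):
    layout, sg_data = wg_map_out
    data = list(sg_data)
    data[reduction_dim] = src_shape[reduction_dim]  # full reduced extent
    return (layout, tuple(data))

# Reproduces #wg_map_b from the example above.
in_map = derive_reduce_input_map(((32, 1), (8, 1)), (256, 128), 1)
```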

`tile_broadcast` with the `wg_map` attribute broadcasts at the workgroup level.
```mlir
#wg_map_a = #xetile.wg_map<sg_layout = [16, 1], sg_data = [16, 256]>
%vector_a = XeTile.tile_broadcast %vector_b [1] {#wg_map_a}: vector<256x1xfloat> into vector<256x256xfloat>
```
The `wg_map` attribute of the input vector can be derived from wg_map_a: sg_layout must be the same; sg_data for the dimension being broadcast must be 1; and the other dimension must match wg_map_a. The input vector's wg_map attribute may be retrieved from its producer op, and the retrieved attribute must be consistent with the derived one. Below is the derived wg_map for the input vector in the example above.
```mlir
#wg_map_b = #xetile.wg_map<sg_layout = [16, 1], sg_data = [16, 1]>
```
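As with reduce, the broadcast derivation can be sketched as a hypothetical helper (not XeTile API):

```python
# Sketch: derive the input vector's wg_map from tile_broadcast's output map.

def derive_broadcast_input_map(wg_map_out, broadcast_dim):
    layout, sg_data = wg_map_out
    data = list(sg_data)
    data[broadcast_dim] = 1   # the broadcast source has extent 1 per subgroup
    return (layout, tuple(data))

# Reproduces #wg_map_b from the example above.
in_map = derive_broadcast_input_map(((16, 1), (16, 256)), 1)
```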

`tile_transpose` with the `wg_map` attribute transposes a workgroup-level vector.
```mlir
#wg_map_a = #xetile.wg_map<sg_layout = [4, 8], sg_data = [32, 64]>
%vector_a = XeTile.tile_transpose %vector_b {#wg_map_a}: vector<512x128xfloat> into vector<128x512xfloat>
```

The `wg_map` attribute of the input vector can be derived from wg_map_a: the two dimensions of sg_layout and sg_data must be swapped. The input vector's wg_map attribute may be retrieved from its producer op, and the retrieved attribute must be consistent with the derived one. Below is the derived wg_map for the input vector in the example above.
```mlir
#wg_map_b = #xetile.wg_map<sg_layout = [8, 4], sg_data = [64, 32]>
```
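The swap rule is mechanical; a hypothetical sketch (not XeTile API):

```python
# Sketch: derive the input vector's wg_map from tile_transpose's output map.

def derive_transpose_input_map(wg_map_out):
    (l0, l1), (d0, d1) = wg_map_out
    return ((l1, l0), (d1, d0))   # swap both sg_layout and sg_data dims

# Reproduces #wg_map_b from the example above.
in_map = derive_transpose_input_map(((4, 8), (32, 64)))
```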
The tile_transpose can be implemented by saving to and restoring from shared local memory. It can be conceptually viewed as a composition of two operations: 1) store the vector to shared memory with the #wg_map_b mapping assuming row_major, and 2) load the data from shared memory to a vector with the wg_map_a mapping assuming column_major. To support this, we relax the restriction of tile_load and tile_store so that they can load 2D tiles from shared local memory.

An optimization is to analyze the load op that produces %vector_b and arrange its mapping so that each subgroup thread loads its corresponding subgroup tile, and then either fold the transpose into the load op or perform an in-register transpose.

`tile_conv_layout` with `wg_map` attributes remaps a workgroup-level vector to subgroup threads. The second `wg_map` attribute is optional and describes the input operand. The input vector's wg_map attribute may be retrieved from its producer op, and the retrieved attribute must be consistent with the second `wg_map` attribute if it is present.

Example with the wg_map specified for both input and output operands.
```mlir
#wg_map_b = #xetile.wg_map<sg_layout = [8, 4], sg_data = [32, 64]> // used for cooperative load/prefetch
#wg_map_a = #xetile.wg_map<sg_layout = [32, 1], sg_data = [8, 256]> // used as mma's input matrix A
%vector_a = XeTile.tile_conv_layout %vector_b {#wg_map_a #wg_map_b}: vector<256x256xfloat> into vector<256x256xfloat>
```
Example without the wg_map specified for the input operand.
```mlir
#wg_map_a = #xetile.wg_map<sg_layout = [32, 1], sg_data = [8, 256]> // used as mma's input matrix A
%vector_a = XeTile.tile_conv_layout %vector_b {#wg_map_a}: vector<256x256xfloat> into vector<256x256xfloat>
```
The tile_conv_layout could be implemented by saving to and restoring from shared local memory. It can be conceptually viewed as a composition of two operations: 1) store the vector to shared memory with the #wg_map_b mapping assuming row_major, and 2) load the data from shared memory to a vector with the wg_map_a mapping assuming the same row_major order. To support this, we relax the restriction of tile_load and tile_store so that they can load 2D tiles from shared local memory.
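What changes under tile_conv_layout is only the ownership of elements, not their values. A hypothetical illustration (not XeTile API; `owned_block` is an assumed helper) of the block each subgroup owns before and after the remap in the example, assuming row-major subgroup ids and no wrap-around:

```python
# Sketch: origin and size of the block a subgroup owns under a wg_map,
# assuming row-major subgroup ids and a single distribution round.

def owned_block(sg_layout, sg_data, sg_id):
    r, c = divmod(sg_id, sg_layout[1])
    return (r * sg_data[0], c * sg_data[1]), tuple(sg_data)

# Subgroup 5 owns a 32x64 block under #wg_map_b (cooperative load layout)
# and an 8x256 block under #wg_map_a (mma input layout).
before = owned_block((8, 4), (32, 64), 5)    # ((32, 64), (32, 64))
after = owned_block((32, 1), (8, 256), 5)    # ((40, 0), (8, 256))
```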


## Alternative design considerations

