Commit

new env syntax
fschlimb committed Nov 15, 2024
1 parent 4396b53 commit 631a8e3
Showing 3 changed files with 39 additions and 37 deletions.
4 changes: 2 additions & 2 deletions include/imex/Dialect/NDArray/IR/NDArrayOps.h
@@ -84,7 +84,7 @@ namespace ndarray {
 /// @return true if given NDArrayTYpe has this specific environment attribute
 template <typename T> bool hasEnv(const ::mlir::RankedTensorType &t) {
   auto encoding = t.getEncoding();
-  if (auto envs = ::mlir::dyn_cast<EnvironmentAttr>(encoding)) {
+  if (auto envs = ::mlir::dyn_cast<EnvsAttr>(encoding)) {
     for (auto a : envs.getEnvs()) {
       if (::mlir::isa<T>(a)) {
         return true;
@@ -103,7 +103,7 @@ inline bool hasGPUEnv(const ::mlir::Type &t) {
 inline ::imex::region::GPUEnvAttr getGPUEnv(const ::mlir::Type &t) {
   if (auto tt = ::mlir::dyn_cast<::mlir::RankedTensorType>(t)) {
     auto encoding = tt.getEncoding();
-    if (auto envs = ::mlir::dyn_cast<EnvironmentAttr>(encoding)) {
+    if (auto envs = ::mlir::dyn_cast<EnvsAttr>(encoding)) {
       for (auto a : envs.getEnvs()) {
         if (auto g = ::mlir::dyn_cast<::imex::region::GPUEnvAttr>(a)) {
           return g;
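For orientation, a minimal usage sketch of the helpers touched above. The wrapper function needsGPURegion and its body are assumptions for illustration only; hasGPUEnv and getGPUEnv are the functions shown in the hunks, assumed to live in the ::imex::ndarray namespace as the surrounding code suggests.

#include "imex/Dialect/NDArray/IR/NDArrayOps.h"

// Hypothetical caller (not part of this commit): decide whether a value should
// be wrapped in a GPU env_region based on its tensor encoding. Only the
// dyn_cast target inside these helpers changed (EnvironmentAttr -> EnvsAttr),
// so callers like this one are unaffected by the rename.
static bool needsGPURegion(::mlir::Type resultType) {
  if (!::imex::ndarray::hasGPUEnv(resultType))
    return false;
  ::imex::region::GPUEnvAttr env = ::imex::ndarray::getGPUEnv(resultType);
  return static_cast<bool>(env); // non-null iff a #region.gpu_env is attached
}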
4 changes: 2 additions & 2 deletions include/imex/Dialect/NDArray/IR/NDArrayOps.td
@@ -59,8 +59,8 @@ def NDArray_Dialect : Dialect {
   let useDefaultAttributePrinterParser = true;
 }
 
-def NDArray_EnvironmentAttr : AttrDef<NDArray_Dialect, "Environment"> {
-  let mnemonic = "environment";
+def NDArray_EnvsAttr : AttrDef<NDArray_Dialect, "Envs"> {
+  let mnemonic = "envs";
   let parameters = (ins ArrayRefParameter<"::mlir::Attribute">:$envs);
   let assemblyFormat = "`<` $envs `>`";
 }
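The user-visible effect of the rename is the assembly syntax: with the mnemonic "envs", the encoding wrapper is printed and parsed as #ndarray.envs<...> rather than #ndarray.environment<...>. A minimal sketch of the new spelling, mirroring the #GPUENV alias used in the updated test below (the function itself is illustrative only, not taken from the commit):

#GPUENV = #ndarray.envs<#region.gpu_env<device = "XeGPU">>

func.func @illustration(%a: tensor<33xi64, #GPUENV>) -> tensor<33xi64, #GPUENV> {
  return %a : tensor<33xi64, #GPUENV>
}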
68 changes: 35 additions & 33 deletions test/Dialect/NDArray/Transforms/AddGPURegions.mlir
@@ -1,44 +1,46 @@
 // RUN: imex-opt --split-input-file --add-gpu-regions %s -verify-diagnostics -o -| FileCheck %s
 
+#GPUENV = #ndarray.envs<#region.gpu_env<device = "XeGPU">>
+
 func.func @test_region(%arg0: i64, %arg1: i64, %arg2: i64) -> i64 {
   %c0 = arith.constant 0 : index
   %c3 = arith.constant 3 : index
   %c33 = arith.constant 33 : i64
   %c22 = arith.constant 22 : index
   %v = arith.constant 55 : i64
   %s = arith.index_cast %arg0 : i64 to index
-  %0 = ndarray.linspace %arg0 %arg1 %c33 false {device = "XeGPU", team = 1 : i64} : (i64, i64, i64) -> tensor<33xi64, #region.gpu_env<device = "XeGPU">>
-  %1 = ndarray.create %c22 value %v {dtype = 2 : i8, device = "XeGPU", team = 1 : i64} : (index, i64) -> tensor<?xi64, #region.gpu_env<device = "XeGPU">>
-  %10 = ndarray.subview %0[%c0][22][%c3] : tensor<33xi64, #region.gpu_env<device = "XeGPU">> to tensor<?xi64, #region.gpu_env<device = "XeGPU">>
-  %20 = ndarray.ewbin %10, %1 {op = 0 : i32} : (tensor<?xi64, #region.gpu_env<device = "XeGPU">>, tensor<?xi64, #region.gpu_env<device = "XeGPU">>) -> tensor<?xi64, #region.gpu_env<device = "XeGPU">>
-  %21 = ndarray.reduction %20 {op = 4 : i32} : tensor<?xi64, #region.gpu_env<device = "XeGPU">> -> tensor<i64, #region.gpu_env<device = "XeGPU">>
-  %30 = builtin.unrealized_conversion_cast %21 : tensor<i64, #region.gpu_env<device = "XeGPU">> to i64
-  ndarray.delete %0 : tensor<33xi64, #region.gpu_env<device = "XeGPU">>
-  ndarray.delete %1 : tensor<?xi64, #region.gpu_env<device = "XeGPU">>
+  %0 = ndarray.linspace %arg0 %arg1 %c33 false {device = "XeGPU", team = 1 : i64} : (i64, i64, i64) -> tensor<33xi64, #GPUENV>
+  %1 = ndarray.create %c22 value %v {dtype = 2 : i8, device = "XeGPU", team = 1 : i64} : (index, i64) -> tensor<?xi64, #GPUENV>
+  %10 = ndarray.subview %0[%c0][22][%c3] : tensor<33xi64, #GPUENV> to tensor<?xi64, #GPUENV>
+  %20 = ndarray.ewbin %10, %1 {op = 0 : i32} : (tensor<?xi64, #GPUENV>, tensor<?xi64, #GPUENV>) -> tensor<?xi64, #GPUENV>
+  %21 = ndarray.reduction %20 {op = 4 : i32} : tensor<?xi64, #GPUENV> -> tensor<i64, #GPUENV>
+  %30 = builtin.unrealized_conversion_cast %21 : tensor<i64, #GPUENV> to i64
+  ndarray.delete %0 : tensor<33xi64, #GPUENV>
+  ndarray.delete %1 : tensor<?xi64, #GPUENV>
   return %30 : i64
 }
 // CHECK-LABEL: func.func @test_region
-// CHECK: [[V0:%.*]] = region.env_region #region.gpu_env<device = "XeGPU"> -> tensor<33xi64, #region.gpu_env<device = "XeGPU">> {
+// CHECK: [[V0:%.*]] = region.env_region #GPUENV -> tensor<33xi64, #GPUENV> {
 // CHECK-NEXT: ndarray.linspace
 // CHECK-NEXT: region.env_region_yield
-// CHECK: [[V1:%.*]] = region.env_region #region.gpu_env<device = "XeGPU"> -> tensor<?xi64, #region.gpu_env<device = "XeGPU">> {
+// CHECK: [[V1:%.*]] = region.env_region #GPUENV -> tensor<?xi64, #GPUENV> {
 // CHECK-NEXT: ndarray.create
 // CHECK-NEXT: region.env_region_yield
-// CHECK: [[V2:%.*]] = region.env_region #region.gpu_env<device = "XeGPU"> -> tensor<?xi64, #region.gpu_env<device = "XeGPU">> {
+// CHECK: [[V2:%.*]] = region.env_region #GPUENV -> tensor<?xi64, #GPUENV> {
 // CHECK-NEXT: ndarray.subview [[V0]]
 // CHECK-NEXT: region.env_region_yield
-// CHECK: [[V3:%.*]] = region.env_region #region.gpu_env<device = "XeGPU"> -> tensor<?xi64, #region.gpu_env<device = "XeGPU">> {
+// CHECK: [[V3:%.*]] = region.env_region #GPUENV -> tensor<?xi64, #GPUENV> {
 // CHECK-NEXT: ndarray.ewbin [[V2]], [[V1]]
-// CHECK: [[V4:%.*]] = region.env_region #region.gpu_env<device = "XeGPU"> -> tensor<i64, #region.gpu_env<device = "XeGPU">> {
+// CHECK: [[V4:%.*]] = region.env_region #GPUENV -> tensor<i64, #GPUENV> {
 // CHECK-NEXT: ndarray.reduction [[V3]]
 // CHECK-NEXT: region.env_region_yield
 // CHECK-NEXT: }
 // CHECK-NEXT: [[V5:%.*]] = builtin.unrealized_conversion_cast
-// CHECK: region.env_region #region.gpu_env<device = "XeGPU"> {
-// CHECK-NEXT: ndarray.delete [[V0]] : tensor<33xi64, #region.gpu_env<device = "XeGPU">>
+// CHECK: region.env_region #GPUENV {
+// CHECK-NEXT: ndarray.delete [[V0]] : tensor<33xi64, #GPUENV>
 // CHECK-NEXT: }
-// CHECK-NEXT: region.env_region #region.gpu_env<device = "XeGPU"> {
-// CHECK-NEXT: ndarray.delete [[V1]] : tensor<?xi64, #region.gpu_env<device = "XeGPU">>
+// CHECK-NEXT: region.env_region #GPUENV {
+// CHECK-NEXT: ndarray.delete [[V1]] : tensor<?xi64, #GPUENV>
 // CHECK-NEXT: }
 // CHECK-NEXT: return [[V5]]

@@ -48,34 +50,34 @@ func.func @test_copy() -> tensor<33xi64> {
   %c0 = arith.constant 0 : i64
   %c3 = arith.constant 3 : i64
   %c33 = arith.constant 33 : i64
-  %0 = ndarray.linspace %c0 %c3 %c33 false {device = "XeGPU", team = 1 : i64} : (i64, i64, i64) -> tensor<33xi64, #region.gpu_env<device = "XeGPU">>
-  %1 = ndarray.copy %0 : tensor<33xi64, #region.gpu_env<device = "XeGPU">> -> tensor<33xi64>
-  %2 = ndarray.copy %1 : tensor<33xi64> -> tensor<33xi64, #region.gpu_env<device = "XeGPU">>
-  %3 = ndarray.copy %2 : tensor<33xi64, #region.gpu_env<device = "XeGPU">> -> tensor<33xi64, #region.gpu_env<device = "XeGPU">>
-  %4 = ndarray.copy %3 : tensor<33xi64, #region.gpu_env<device = "XeGPU">> -> tensor<33xi64>
+  %0 = ndarray.linspace %c0 %c3 %c33 false {device = "XeGPU", team = 1 : i64} : (i64, i64, i64) -> tensor<33xi64, #GPUENV>
+  %1 = ndarray.copy %0 : tensor<33xi64, #GPUENV> -> tensor<33xi64>
+  %2 = ndarray.copy %1 : tensor<33xi64> -> tensor<33xi64, #GPUENV>
+  %3 = ndarray.copy %2 : tensor<33xi64, #GPUENV> -> tensor<33xi64, #GPUENV>
+  %4 = ndarray.copy %3 : tensor<33xi64, #GPUENV> -> tensor<33xi64>
   %5 = ndarray.copy %4 : tensor<33xi64> -> tensor<33xi64>
   return %5 : tensor<33xi64>
 }
 // CHECK-LABEL: func.func @test_copy() -> tensor<33xi64> {
-// CHECK: region.env_region #region.gpu_env<device = "XeGPU"> -> tensor<33xi64, #region.gpu_env<device = "XeGPU">> {
+// CHECK: region.env_region #GPUENV -> tensor<33xi64, #GPUENV> {
 // CHECK: ndarray.linspace
-// CHECK-SAME: -> tensor<33xi64, #region.gpu_env<device = "XeGPU">>
+// CHECK-SAME: -> tensor<33xi64, #GPUENV>
 // CHECK: region.env_region_yield
-// CHECK-SAME: tensor<33xi64, #region.gpu_env<device = "XeGPU">>
+// CHECK-SAME: tensor<33xi64, #GPUENV>
 // CHECK: ndarray.copy
-// CHECK-SAME: tensor<33xi64, #region.gpu_env<device = "XeGPU">> -> tensor<33xi64>
-// CHECK: region.env_region #region.gpu_env<device = "XeGPU"> -> tensor<33xi64, #region.gpu_env<device = "XeGPU">> {
+// CHECK-SAME: tensor<33xi64, #GPUENV> -> tensor<33xi64>
+// CHECK: region.env_region #GPUENV -> tensor<33xi64, #GPUENV> {
 // CHECK: ndarray.copy
-// CHECK-SAME: tensor<33xi64> -> tensor<33xi64, #region.gpu_env<device = "XeGPU">>
+// CHECK-SAME: tensor<33xi64> -> tensor<33xi64, #GPUENV>
 // CHECK: region.env_region_yield
-// CHECK-SAME: tensor<33xi64, #region.gpu_env<device = "XeGPU">>
-// CHECK: region.env_region #region.gpu_env<device = "XeGPU"> -> tensor<33xi64, #region.gpu_env<device = "XeGPU">> {
+// CHECK-SAME: tensor<33xi64, #GPUENV>
+// CHECK: region.env_region #GPUENV -> tensor<33xi64, #GPUENV> {
 // CHECK: ndarray.copy
-// CHECK-SAME: tensor<33xi64, #region.gpu_env<device = "XeGPU">> -> tensor<33xi64, #region.gpu_env<device = "XeGPU">>
+// CHECK-SAME: tensor<33xi64, #GPUENV> -> tensor<33xi64, #GPUENV>
 // CHECK: region.env_region_yield
-// CHECK-SAME: tensor<33xi64, #region.gpu_env<device = "XeGPU">>
+// CHECK-SAME: tensor<33xi64, #GPUENV>
 // CHECK: ndarray.copy
-// CHECK-SAME: tensor<33xi64, #region.gpu_env<device = "XeGPU">> -> tensor<33xi64>
+// CHECK-SAME: tensor<33xi64, #GPUENV> -> tensor<33xi64>
 // CHECK: ndarray.copy
 // CHECK-SAME: tensor<33xi64> -> tensor<33xi64>
 // CHECK: return
