From d47050d04d299dab799f8a1c6eb7e5d85bda535f Mon Sep 17 00:00:00 2001 From: Giuseppe Rossini Date: Fri, 11 Feb 2022 11:09:33 +0000 Subject: [PATCH] Refactoring ALP This is a big refactor of our experimental codebase to follow latest developments: - Generate functions through the Python DSL interface - Use the trasnformation infra common in the sandbox - Generalize tuner/mlirc to be program agnostic (although transformations are still GEMM-specific) - Add a python transformation to save the IR - Make `configure.py` compatible with python 3.7 --- configure.py | 5 + experimental/alp/README.md | 111 +++++++ experimental/alp/alp/library/blas.py | 138 --------- experimental/alp/alp/mlirc.py | 60 ---- experimental/alp/alp/tuner.py | 149 ---------- experimental/alp/alp/utils.py | 105 ------- .../alp/include/alp/Transforms/Passes.h | 7 +- .../alp/include/alp/Transforms/Passes.td | 16 +- .../alp/lib/AlpRuntime/alp_runtime.cpp | 27 +- .../alp/lib/Transforms/CMakeLists.txt | 1 + .../lib/Transforms/extract_kernel_pass.cpp | 212 ++++++++------ .../lib/Transforms/for_to_dowhile_loop.cpp | 148 ++++++++++ .../alp/lib/Transforms/legalize_vector.cpp | 191 ++++++------ .../lib/Transforms/modulo_scheduling_pass.cpp | 277 +++++++++--------- experimental/alp/python/alp/__init__.py | 0 .../alp/python/alp/backend/__init__.py | 0 .../alp/python/alp/backend/codegen.py | 25 ++ experimental/alp/python/alp/backend/mlirc.py | 198 +++++++++++++ .../alp/python/alp/backend/transforms.py | 61 ++++ experimental/alp/python/alp/backend/tuner.py | 146 +++++++++ experimental/alp/python/alp/backend/utils.py | 114 +++++++ .../alp/python/alp/benchmark/__init__.py | 0 .../alp/python/alp/benchmark/blas/__init__.py | 0 .../alp/python/alp/benchmark/blas/gemm.py | 141 +++++++++ .../alp/python/alp/benchmark/infra.py | 73 +++++ experimental/alp/python/alp/test/__init__.py | 0 .../alp/python/alp/test/blas/__init__.py | 0 experimental/alp/python/alp/test/blas/gemm.py | 162 ++++++++++ experimental/alp/python/alp/test/infra.py | 77 +++++ .../alp/python/alp/transition/__init__.py | 0 .../python/alp/transition/blas/__init__.py | 0 .../alp/python/alp/transition/blas/gemm.py | 177 +++++++++++ experimental/alp/python/test/gemm.py | 137 +++++++++ experimental/alp/test/test-alp-legalize.mlir | 35 +++ .../alp/test/test-alp-modulo-scheduling.mlir | 53 ++++ 35 files changed, 2069 insertions(+), 777 deletions(-) delete mode 100644 experimental/alp/alp/library/blas.py delete mode 100644 experimental/alp/alp/mlirc.py delete mode 100644 experimental/alp/alp/tuner.py delete mode 100644 experimental/alp/alp/utils.py create mode 100644 experimental/alp/lib/Transforms/for_to_dowhile_loop.cpp create mode 100644 experimental/alp/python/alp/__init__.py create mode 100644 experimental/alp/python/alp/backend/__init__.py create mode 100644 experimental/alp/python/alp/backend/codegen.py create mode 100644 experimental/alp/python/alp/backend/mlirc.py create mode 100644 experimental/alp/python/alp/backend/transforms.py create mode 100644 experimental/alp/python/alp/backend/tuner.py create mode 100644 experimental/alp/python/alp/backend/utils.py create mode 100644 experimental/alp/python/alp/benchmark/__init__.py create mode 100644 experimental/alp/python/alp/benchmark/blas/__init__.py create mode 100644 experimental/alp/python/alp/benchmark/blas/gemm.py create mode 100644 experimental/alp/python/alp/benchmark/infra.py create mode 100644 experimental/alp/python/alp/test/__init__.py create mode 100644 experimental/alp/python/alp/test/blas/__init__.py create mode 100644 experimental/alp/python/alp/test/blas/gemm.py create mode 100644 experimental/alp/python/alp/test/infra.py create mode 100644 experimental/alp/python/alp/transition/__init__.py create mode 100644 experimental/alp/python/alp/transition/blas/__init__.py create mode 100644 experimental/alp/python/alp/transition/blas/gemm.py create mode 100644 experimental/alp/python/test/gemm.py create mode 100644 experimental/alp/test/test-alp-legalize.mlir create mode 100644 experimental/alp/test/test-alp-modulo-scheduling.mlir diff --git a/configure.py b/configure.py index 6c7c67335175..18f02be4fae6 100755 --- a/configure.py +++ b/configure.py @@ -214,6 +214,11 @@ def main(args): "mlir-cpu-runner", "mlir_runner_utils", "mlir_c_runner_utils", \ "mlir_async_runtime_copy", "llvm-mca", "llvm-objdump", "llc", "opt", \ "FileCheck"] + + if args.enable_alp: + cmake_args.append("clang") + cmake_args.append("clang-cpp") + print(f"-- Performing initial build: {' '.join(cmake_args)}") subprocess.check_call(cmake_args, cwd=build_dir) diff --git a/experimental/alp/README.md b/experimental/alp/README.md index e69de29bb2d1..6528a563df3b 100644 --- a/experimental/alp/README.md +++ b/experimental/alp/README.md @@ -0,0 +1,111 @@ +# How to enable and use alp +This is a very simple set of instructions to enable and work with alp from iree-llvm-sandbox. We use the following environment variables defaults in these instructions: + +* `IREE_LLVM_SANDBOX_SOURCE_DIR`: path to the source of the iree-llvm-sandbox +* `IREE_LLVM_SANDBOX_BUILD_DIR`: path to the source of the iree-llvm-sandbox +* `LLVM_SOURCE_DIR`: path to the source of the llvm-project folder + +We also need to set the correct `$PYTHONPATH` to enable the python infrastructure: +``` +$ export PYTHONPATH=$IREE_LLVM_SANDBOX_SOURCE_DIR/build/tools/sandbox/python_package:$IREE_LLVM_SANDBOX_SOURCE_DIR/python/examples/:$LLVM_SOURCE_DIR/mlir/python:$IREE_LLVM_SANDBOX_SOURCE_DIR/experimental/alp/python + ``` + +## Download LLVM +You should clone LLVM and point it to the commit indicated in: `${IREE_LLVM_SOURCE_DIR}/pinned-llvm-version` +``` +$ git clone https://github.com/llvm/llvm-project.git +$ cd llvm-project +$ git checkout `cat ${IREE_LLVM_SANDBOX_SOURCE_DIR}/pinned-llvm-version` +``` + +## Download LLVM [Internal development] +For internal development you should clone directly from the codehub mirror: +``` +$ git clone ssh://git@codehub-dg-y.huawei.com:2222/boole-compiler/uk-team/llvm-project.git +$ git checkout main +``` + +## Build all together +This needs to be run from $IREE_LLVM_SANDBOX_SOURCE_DIR. I am pointing out instructions for AArch64 + ALP: +``` +$ python3 ./configure.py --target=AArch64 --llvm-path=$LLVM_SOURCE_DIR --alp +``` + Please note that the supported `cmake` version is >= 3.21.0 + +After this command, if you only want to rebuild, you can simply do: +``` +$ cmake --build $IREE_LLVM_SANDBOX_SOURCE_DIR/build --target tools/sandbox/all mlir-opt mlir-translate mlir_runner_utils mlir_c_runner_utils llvm-mca llvm-objdump llc opt + ``` + +## Use the tool +Given a generic MLIR program, `prog.mlir`, we can compile it in the following way: +``` +$ python3 -m alp.backend.mlirc --input-file=prog.mlir ... # transformation flags +``` +This will create an assembly file `prog.s`. In order to run it, we have two options: +a) Link the assembly to a C++ program (see Transition Path below), link and run +b) Write a benchmark program in MLIR and execute it through the python framework. + +In this section, we will show-case option b) using GEMM as an example. In the following we assume that the current folder is `$IREE_LLVM_SANDBOX_SOURCE_DIR/experimental/alp` + +### Generate the target program +Our transition python module is supposed to generate MLIR program for known library functions. To generate GEMM, you can run: +``` +$ python3 -m alp.transition.blas.gemm --M 2048 --N 2048 --K 2048 --trA +``` +This will generate a `gemm.mlir` program in the current folder which is supposed to execute a matrix multiply operation `C += A*B` where `A` is pre-transposed. You can also generate a dynamic sized GEMM by not specifying any of the sizes. For instance: +``` +$ python3 -m alp.transition.blas.gemm --trA +``` +Generates a fully dynamic GEMM implementation where the sizes are read dynamically from the inputs. + +### Compile the program +We can compile `gemm.mlir` in the following way: + +``` +$ python3 -m alp.backend.mlirc --input-file=gemm.mlir --tile-sizes 2048 512 128 --register-tile-sizes 8 8 1 --reorder-tile-sizes 0 2 1 --reorder-register-tile-sizes 0 1 2 --unroll-vector-transfers --split-vector-transfers-to none --hoist-packing 4 3 0 --modulo-scheduling --ms-unroll=2 --transpose-packing 0 0 0 --verbosity-level=4 +``` + +A file `gemm.s` should be created in your current folder. + +### Benchmark the program +Our infrastructure provides the possibility to generate a benchmark MLIR file, compile it, link it with the target assembly file and run it. This is what you have to do: + +``` +python3 -m alp.benchmark.blas.gemm --asm-program=gemm.s --M=2048 --N=2048 --K=2048 --trA +``` + +Please note that in this case we need to provide information to the benchmark about what we want to run. If you want to re-run the benchmark you can either issue the same command again, or you can simply run the executable `gemm.bench.out` that has been created in your current folder. You may also want to just generate the benchmark program, and in this case you should simply run: +``` +python3 -m alp.benchmark.blas.gemm --M=2048 --N=2048 --K=2048 --trA +``` + +Also, you can have a look at the `gemm.bench.mlir` file that has been generated within your current folder. + +### Test the program +You can finally test that the transformed program is correct. The command is very similar to the ones using for benchmarking: + +``` +python3 -m alp.test.blas.gemm --asm-program=gemm.s --M=2048 --N=2048 --K=2048 --trA +``` +Please note that we are using a naive algorithm to compute the matrix multiply, and this might take some time to finish. +### Smoke test +``` +$ cd $IREE_LLVM_SANDBOX_SOURCE_DIR/experimental/alp +$ make check +``` +## Use the tuner +### Download OpenTuner +OpenTuner should come as a prebuild package installable directly from `pip3`: +``` +$ pip3 install --user opentuner +``` + +### Tune a gemm program +The tuner is the real backend compiler, since it issues the transformations to apply to the program via `mlirc`. To run the tuner needs: +* The MLIR program to compile +* The MLIR benchmark to execute the program + +``` +python3 -m alp.backend.tuner --input-file gemm.mlir --benchmark gemm.bench.mlir +``` diff --git a/experimental/alp/alp/library/blas.py b/experimental/alp/alp/library/blas.py deleted file mode 100644 index 8ec4948943f5..000000000000 --- a/experimental/alp/alp/library/blas.py +++ /dev/null @@ -1,138 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -typest = """ -!memref_type_A = type tensor<_K_x_M_xf32> -!memref_type_B = type tensor<_K_x_N_xf32> -!memref_type_C = type tensor<_M_x_N_xf32> -""" - -types = """ -!memref_type_A = type tensor<_M_x_K_xf32> -!memref_type_B = type tensor<_K_x_N_xf32> -!memref_type_C = type tensor<_M_x_N_xf32> -""" - -init_tensors = """ - %A0 = linalg.init_tensor [_M_,_K_] : !memref_type_A - %B0 = linalg.init_tensor [_K_,_N_] : !memref_type_B - %C = linalg.init_tensor [_M_, _N_] : !memref_type_C -""" - -init_tensors_t = """ - %A0 = linalg.init_tensor [_K_,_M_] : !memref_type_A - %B0 = linalg.init_tensor [_K_,_N_] : !memref_type_B - %C = linalg.init_tensor [_M_, _N_] : !memref_type_C -""" - - -gemm_benchmark = f""" -func @main() -> i32 {{ - call @print_pid() : () -> () - __INIT_TENSORS__ - - %elem = arith.constant 1.0 : f32 - %A = linalg.fill(%elem, %A0) : f32, !memref_type_A -> !memref_type_A - %B = linalg.fill(%elem, %B0) : f32, !memref_type_B -> !memref_type_B - - %out = call @gemm(%A, %B, %C) : (!memref_type_A, !memref_type_B, !memref_type_C) -> !memref_type_C - %reps = arith.constant _REPS_ : index - %t_start = call @rtclock() : () -> f64 - affine.for %arg0 = 0 to %reps {{ - call @gemm(%A, %B, %C) : (!memref_type_A, !memref_type_B, !memref_type_C) -> !memref_type_C - }} - %t_end = call @rtclock() : () -> f64 - %repsi = arith.index_cast %reps : index to i64 - %repsf = arith.sitofp %repsi: i64 to f64 - %t_tot = arith.subf %t_end, %t_start : f64 - %t = arith.divf %t_tot, %repsf : f64 - - call @print_time(%t) : (f64) -> () - - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - - %M = tensor.dim %C, %c0 : !memref_type_C - %N = tensor.dim %C, %c1 : !memref_type_C - %K = tensor.dim %A, %c0 : !memref_type_A - - %Mi32 = arith.index_cast %M: index to i64 - %Ni32 = arith.index_cast %N: index to i64 - %Ki32 = arith.index_cast %K: index to i64 - - %c2 = arith.constant 2 : i64 - %f1 = arith.muli %Mi32, %Ni32 : i64 - %f2 = arith.muli %f1, %Ki32 : i64 - %f3 = arith.muli %c2, %f2 : i64 - - // 2*M*N*K. - %num_flops_f = arith.sitofp %f3: i64 to f64 - %flops = arith.divf %num_flops_f, %t : f64 - call @print_flops(%flops) : (f64) -> () - - %i0 = arith.constant 0 : i32 - return %i0 : i32 -}} - - -func private @print_flops(f64) -func private @print_time(f64) -func private @printNewline() -func private @print_pid() -func private @rtclock() -> f64 -func private @print_memref_f32(memref<*xf32>) -func private @gemm(%A : !memref_type_A, %B : !memref_type_B, %C : !memref_type_C) -> !memref_type_C -""" - - - - -GEMM = """ -func @gemm(%A : !memref_type_A {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %B : !memref_type_B {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %C : !memref_type_C {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> !memref_type_C { - %0 = linalg.generic - {indexing_maps = [affine_map<(m, n, k) -> (m, k)>, - affine_map<(m, n, k) -> (k, n)>, - affine_map<(m, n, k) -> (m, n)>], - iterator_types = ["parallel", "parallel", "reduction"]} - - ins(%A, %B: !memref_type_A, !memref_type_B) - outs(%C: !memref_type_C) { - ^bb0(%a: f32, %b: f32, %c: f32) : - %d = arith.mulf %a, %b: f32 - %e = arith.addf %c, %d: f32 - linalg.yield %e : f32 - } -> !memref_type_C - return %0 : !memref_type_C - } -""" - -GEMM_T = """ -func @gemm(%A : !memref_type_A {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %B : !memref_type_B {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %C : !memref_type_C {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> !memref_type_C { - %0 = linalg.generic - {indexing_maps = [affine_map<(m, n, k) -> (k, m)>, - affine_map<(m, n, k) -> (k, n)>, - affine_map<(m, n, k) -> (m, n)>], - iterator_types = ["parallel", "parallel", "reduction"]} - - ins(%A, %B: !memref_type_A, !memref_type_B) - outs(%C: !memref_type_C) { - ^bb0(%a: f32, %b: f32, %c: f32) : - %d = arith.mulf %a, %b: f32 - %e = arith.addf %c, %d: f32 - linalg.yield %e : f32 - } -> !memref_type_C - return %0 : !memref_type_C - } -""" - -def gemm(trA): - if trA: - bench = gemm_benchmark.replace("__INIT_TENSORS__", str(init_tensors_t)) - return (typest + bench, typest + GEMM_T) - else: - bench = gemm_benchmark.replace("__INIT_TENSORS__", str(init_tensors)) - return (types + bench, types+ GEMM) diff --git a/experimental/alp/alp/mlirc.py b/experimental/alp/alp/mlirc.py deleted file mode 100644 index f46fdaf5c9b7..000000000000 --- a/experimental/alp/alp/mlirc.py +++ /dev/null @@ -1,60 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -import sys -import argparse -from .utils import parse, run_command, print_command -from .compile_op import build_mlir - - -if __name__ == "__main__": - parser = argparse.ArgumentParser("mlirc") - - # GEMM size - parser.add_argument("--M", type=int) - parser.add_argument("--N", type=int) - parser.add_argument("--K", type=int) - - # Outer tiling - parser.add_argument("--tile-sizes", nargs='+', type=int) - parser.add_argument("--reorder-tile-sizes", nargs='+', type=int) - - # Inner tiling - parser.add_argument("--register-tile-sizes", nargs='+', type=int) - parser.add_argument("--reorder-register-tile-sizes", nargs='+', type=int) - parser.add_argument("--hoist-packing", nargs='+', type=int) - - # Vector lowering - parser.add_argument("--unroll-vector-transfers", action="store_true") - parser.add_argument("--split-vector-transfers-to") - - # micro-kernel transforms - parser.add_argument("--extract-micro-kernel", action="store_true") - parser.add_argument("--modulo-scheduling", action="store_true") - - # Verbosity - parser.add_argument("--verbose", action="store_true") - parser.add_argument("--verbosity-level", type=int, default=0) - parser.add_argument("--reps", type=int, default=1) - - args = parser.parse_args() - - stringify = lambda l : ','.join([str(e) for e in l]) - options = { "tile_sizes" : stringify(args.tile_sizes), - "register_tile_sizes" : stringify(args.register_tile_sizes), - "split_vector_transfers_to" : args.split_vector_transfers_to, - "unroll_vector_transfers" : args.unroll_vector_transfers, - "reorder_tile_sizes": stringify(args.reorder_tile_sizes), - "reorder_register_tile_sizes": stringify(args.reorder_register_tile_sizes), - "hoist_packing": stringify(args.hoist_packing), - "extract_micro_kernel": args.extract_micro_kernel, - "modulo_scheduling": args.modulo_scheduling, - "verbosity_level" : 0, - "reps": args.reps - } - - if (args.verbose): - options["verbosity_level"]=1 - if (args.verbosity_level > 0): - options["verbosity_level"]=args.verbosity_level - build_mlir("gemm", args.M, args.N, args.K, options) diff --git a/experimental/alp/alp/tuner.py b/experimental/alp/alp/tuner.py deleted file mode 100644 index d281fbda2e90..000000000000 --- a/experimental/alp/alp/tuner.py +++ /dev/null @@ -1,149 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#!/usr/bin/env python -import opentuner -from opentuner import ConfigurationManipulator -from opentuner.search.manipulator import IntegerParameter, PowerOfTwoParameter, EnumParameter, BooleanParameter -from opentuner import MeasurementInterface -from opentuner import Result -import sys - -from .utils import parse - -max_flops = 0 -class MLIRFlagsTuner(MeasurementInterface): - - def manipulator(self): - """ - Define the search space by creating a - ConfigurationManipulator - """ - manipulator = ConfigurationManipulator() - - manipulator.add_parameter( - PowerOfTwoParameter('mr', 4, 4)) - - manipulator.add_parameter( - PowerOfTwoParameter('nr', 16, 16)) - - manipulator.add_parameter( - PowerOfTwoParameter('kr', 16, 64)) - - manipulator.add_parameter( - PowerOfTwoParameter('kc', 64, 128)) - - manipulator.add_parameter( - PowerOfTwoParameter('mc', 256, 2048)) - - manipulator.add_parameter( - PowerOfTwoParameter('nc', 64, 2048)) - - manipulator.add_parameter( - IntegerParameter('ha', 4 , 4)) - - manipulator.add_parameter( - IntegerParameter('hb', 3 , 3)) - - return manipulator - - def run(self, desired_result, input, limit): - global max_flops - - - """ - Compile and run a given configuration then - return performance - """ - - cfg = desired_result.configuration.data - - mr = cfg['mr'] - nr = cfg['nr'] - kr = cfg['kr'] - kc = cfg['kc'] - mc = cfg['mc'] - nc = cfg['nc'] - ha = cfg['ha'] - hb = cfg['hb'] - # reordering = cfg['reorder'] - - M = self.args.M - N = self.args.N - K = self.args.K - - # mr = min(mr,mc) - # nr = min(nr,nc) - # kr = min(kr, kc) - # kr = kc - - cfg['mr'] = mr - cfg['nr'] = nr - cfg['kr'] = kr - reordering = "Afirst" - - if reordering == "Afirst": - reorder_inner = "0 1 2" - reorder_outer = "0 2 1" - else: - reorder_inner = "1 0 2" - reorder_outer = "1 2 0" - - hoisting_params = f"{ha} {hb} 0" - cmd = ['python3 -m alp.mlirc'] - cmd.append(f'--M {M}') - cmd.append(f'--N {N}') - cmd.append(f'--K {K}') - - cmd.append(f"--tile-sizes {mc} {nc} {kc}") - cmd.append(f"--register-tile-sizes {mr} {nr} {kr}") - cmd.append(f"--reorder-tile-sizes {reorder_outer}") - cmd.append(f"--reorder-register-tile-sizes {reorder_inner}") - - #if cfg['unrollVectorTransfers']: - cmd.append(f"--unroll-vector-transfers") - cmd.append(f"--split-vector-transfers-to none") # {cfg['splitVectorTransfersTo']}") - cmd.append(f"--hoist-packing {hoisting_params}") - - compile_result = self.call_program(' '.join(cmd)) - - - if compile_result['returncode'] != 0: - return Result(time=sys.maxsize) - - assert compile_result['returncode'] == 0 - - run_cmd = './exec_matmul' - run_result = self.call_program(run_cmd, limit=0.7) - - if run_result['returncode'] != 0: - return Result(time=sys.maxsize) - - assert run_result['returncode'] == 0 - - secs, flops = parse(run_result['stderr']) - - if(flops>max_flops): - s = ' '.join([str(elem) for elem in cmd]) - max_flops=flops - - - return Result(time=1/flops) - - def save_final_config(self, configuration): - """called at the end of tuning""" - print("Optimal block size written to mmm_final_config.json:", configuration.data) - M = self.args.M - N = self.args.N - K = self.args.K - self.manipulator().save_to_file(configuration.data, - f'mmm_final_config_{M}_{N}_{K}.json') - - -if __name__ == '__main__': - argparser = opentuner.default_argparser() - argparser.add_argument("--M", type=int) - argparser.add_argument("--N", type=int) - argparser.add_argument("--K", type=int) - MLIRFlagsTuner.main(argparser.parse_args()) diff --git a/experimental/alp/alp/utils.py b/experimental/alp/alp/utils.py deleted file mode 100644 index 02ed775241c2..000000000000 --- a/experimental/alp/alp/utils.py +++ /dev/null @@ -1,105 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -import subprocess -import os -import numpy as np -from subprocess import PIPE, Popen - -def run_command(cmd): - #print(cmd) - output = subprocess.check_output(' '.join(cmd), shell=True) - return output.decode('ascii') - -def print_command(cmd): - print(' '.join(cmd)) - -def run_and_save(cmd, original_ir, new_ir): - out = run_command(cmd + [original_ir]) - f = open(f"{new_ir}", "w") - # Save the command that generated the IR - f.write("//"+' '.join(cmd)+"\n") - # Save the IR - f.write(out) - f.close() - -def add_extension(fname, ext): - orig_ext = os.path.splitext(fname)[1] - newfilename = os.path.splitext(fname)[0] + "." + ext + orig_ext - return newfilename - -def parse(out): - secs = 0 - flops = 0 - lines = out.split('\n') - for l in lines: - if not l: - continue - [a,b]= l.split() - if b == "secs": - secs = float(a) - if b == "GFLOPS": - flops = float(a) - return (secs, flops) - -def analytical_model(hw, Sdata): - # Analyitical model for GEMM - # https://www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf - - # Vector unit properties - Nvec = hw["Nvec"] - Lvfma = hw["Lvfma"] - Nvfma = hw["Nvfma"] - - # Determine mr/nr - K = Nvec*Nvfma*Lvfma - mr = np.ceil((np.sqrt(K)/Nvec))*Nvec - nr = np.ceil(K/mr) - - # L1 properties - SL1 = hw["SL"][0]*1024 - WL1 = hw["WL"][0] - - # L2 properties - SL2 = hw["SL"][1] *1024 - WL2 = hw["WL"][1] - - if "CL" in hw: - CL1 = hw["CL"][0] - CL2 = hw["CL"][1] - NL1 = SL1/(WL1*CL1) - NL2 = SL2/(WL2*CL2) - elif "NL" in hw: - NL1 = hw["NL"][0] - NL2 = hw["NL"][1] - CL1 = SL1/(WL1*NL1) - CL2 = SL2/(WL2*NL2) - - # if L3 properties are specified, then determine nc - if hw["num_caches"] == 3: - SL3 = hw["SL"][2] * 1024 - WL3 = hw["WL"][2] - - if "CL" in hw: - CL3 = hw["CL"][2] - NL3 = SL3/(WL3*CL3) - elif "NL" in hw: - NL3 = hw["NL"][2] - CL3 = SL3/(WL3*NL3) - - # Determine kc - CAr = np.floor((WL1-1)/(1+nr/mr)) - kc = (CAr*NL1*CL1)/(mr*Sdata) - - # Determine mc - CBr2 = np.ceil(nr*kc*Sdata/(NL2*CL2)) - mc = ( (WL2-1-CBr2)*NL2*CL2/(kc*Sdata)) - - # Determine nc - if hw["num_caches"] == 3: - CAc3 = np.ceil(mc*kc*Sdata/(NL3*CL3)) - nc = ((WL3-CAc3-1)*NL3*CL3)/(kc*Sdata) - else: - nc = -1 - - return (mc, nc, kc, mr, nr) diff --git a/experimental/alp/include/alp/Transforms/Passes.h b/experimental/alp/include/alp/Transforms/Passes.h index 15648e9e9da5..06d50c87e063 100644 --- a/experimental/alp/include/alp/Transforms/Passes.h +++ b/experimental/alp/include/alp/Transforms/Passes.h @@ -24,10 +24,13 @@ std::unique_ptr> createExtractKernelPass(); std::unique_ptr> createExtractKernelTailPass(); /// Create a pass to modulo-schedule the kernel -std::unique_ptr createModuloSchedulingPass(); +std::unique_ptr> createModuloSchedulingPass(); /// Create a pass to legalize vectors in a given function -std::unique_ptr createLegalizePass(); +std::unique_ptr> createLegalizePass(); + +/// Create a pass to transform a For-Loop to a Do-While loop +std::unique_ptr> createForToDoWhileLoopPass(); //===----------------------------------------------------------------------===// // Registration diff --git a/experimental/alp/include/alp/Transforms/Passes.td b/experimental/alp/include/alp/Transforms/Passes.td index 0a3fe5bc0ec9..a7947f36d5ce 100644 --- a/experimental/alp/include/alp/Transforms/Passes.td +++ b/experimental/alp/include/alp/Transforms/Passes.td @@ -29,19 +29,21 @@ def ExtractKernelTailPass: Pass<"alp-extract-kernel-tail", "ModuleOp"> { ]; } -def ModuloSchedulingPass: FunctionPass<"alp-modulo-scheduling"> { +def ModuloSchedulingPass: Pass<"alp-modulo-scheduling", "FuncOp"> { let summary = "Pass to modulo-schedule a loop."; let constructor = "mlir::createModuloSchedulingPass()"; let options = [ Option<"unrolling", "unrolling", /*type*/"int", /*default=*/"2", "Unrolling level before scheduling the loop.">, + Option<"distance", "distance", /*type*/"int", /*default=*/"1", + "Unrolling level before scheduling the loop.">, Option<"interleave", "interleave", /*type*/"bool", /*default=*/"false", "interleave the kernel computation while modulo scheduling.">, ]; } -def LegalizePass: FunctionPass<"alp-legalize"> { +def LegalizePass: Pass<"alp-legalize", "FuncOp"> { let summary = "Pass to legalize vector operations."; let constructor = "mlir::createLegalizePass()"; let options = [ @@ -52,4 +54,14 @@ def LegalizePass: FunctionPass<"alp-legalize"> { ]; } +def ForToDoWhileLoop: Pass<"alp-for-to-dowhile", "FuncOp"> { + let summary = "Pass to legalize vector operations."; + let constructor = "mlir::createForToDoWhileLoopPass()"; + let options = [ + Option<"anchorFuncOpName", "anchor-func", "std::string", /*default=*/"\"kernel\"", + "Which func op is the anchor to latch on.">, + ]; +} + + #endif // ALP_LLVM_SANDBOX_PASSES diff --git a/experimental/alp/lib/AlpRuntime/alp_runtime.cpp b/experimental/alp/lib/AlpRuntime/alp_runtime.cpp index 564e504d89b2..ad01f8ced0f6 100644 --- a/experimental/alp/lib/AlpRuntime/alp_runtime.cpp +++ b/experimental/alp/lib/AlpRuntime/alp_runtime.cpp @@ -8,16 +8,39 @@ #include #include #include +#include #include /** Additional runtime functions used in Alp */ /// Print time (passed as a double constant) extern "C" void print_time(double time_s) { - fprintf(stderr, "%lf secs\n", time_s); + fprintf(stdout, "%lf secs\n", time_s); } /// Print the pid of the current application (for profiling purposes) extern "C" void print_pid() { int pid = getpid(); - fprintf(stderr, "pid: %i\n", pid); + fprintf(stdout, "pid: %i\n", pid); } + +/// Prints GFLOPS rating. +extern "C" void print_flops(double flops) { + fprintf(stdout, "%lf GFLOPS\n", flops / 1.0E9); +} + +extern "C" void printF32(float f) { fprintf(stdout, "%g", f); } +extern "C" void printNewline() { fputc('\n', stdout); } + +/// Returns the number of seconds since Epoch 1970-01-01 00:00:00 +0000 (UTC). +extern "C" double rtclock() { +#ifndef _WIN32 + struct timeval tp; + int stat = gettimeofday(&tp, NULL); + if (stat != 0) + fprintf(stdout, "Error returning time from gettimeofday: %d\n", stat); + return (tp.tv_sec + tp.tv_usec * 1.0e-6); +#else + fprintf(stderr, "Timing utility not implemented on Windows\n"); + return 0.0; +#endif // _WIN32 +} \ No newline at end of file diff --git a/experimental/alp/lib/Transforms/CMakeLists.txt b/experimental/alp/lib/Transforms/CMakeLists.txt index 53cb9beeebf2..8a6755c99069 100644 --- a/experimental/alp/lib/Transforms/CMakeLists.txt +++ b/experimental/alp/lib/Transforms/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_library(ExperimentalAlpTransforms extract_kernel_pass.cpp modulo_scheduling_pass.cpp legalize_vector.cpp + for_to_dowhile_loop.cpp LINK_LIBS PRIVATE MLIRLinalg diff --git a/experimental/alp/lib/Transforms/extract_kernel_pass.cpp b/experimental/alp/lib/Transforms/extract_kernel_pass.cpp index da8265b64063..b04fa6b500e5 100644 --- a/experimental/alp/lib/Transforms/extract_kernel_pass.cpp +++ b/experimental/alp/lib/Transforms/extract_kernel_pass.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/Dominance.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" @@ -22,7 +23,6 @@ #include #define DEBUG_TYPE "extract-kernel" -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") using namespace mlir; @@ -32,14 +32,18 @@ void extract_function(StringRef func_name, Block *block, ModuleOp parentModule, // Create the function (callee site) with an empty block rewriter.setInsertionPointToStart(parentModule.getBody()); - auto func_op = rewriter.create(parentModule.getLoc(), func_name, FunctionType::get(parentModule.getContext(), Input, Output)); + auto func_op = rewriter.create( + parentModule.getLoc(), func_name, + FunctionType::get(parentModule.getContext(), Input, Output)); mlir::MLIRContext *ctx = parentModule.getContext(); llvm::SmallVector attrs; - attrs.push_back(mlir::ArrayAttr::get(ctx, - {mlir::StringAttr::get(ctx, "prefer-vector-width"), - mlir::StringAttr::get(ctx, "128")})); - attrs.push_back(mlir::ArrayAttr::get( ctx, {mlir::StringAttr::get(ctx, "target-cpu"), - mlir::StringAttr::get(ctx, "thunderx2t99")})); + attrs.push_back(mlir::ArrayAttr::get( + ctx, {mlir::StringAttr::get(ctx, "prefer-vector-width"), + mlir::StringAttr::get(ctx, "128")})); + // attrs.push_back(mlir::ArrayAttr::get( ctx, {mlir::StringAttr::get(ctx, + // "target-cpu"), + // mlir::StringAttr::get(ctx, + // "thunderx2t99")})); func_op->setAttr("passthrough", mlir::ArrayAttr::get(ctx, attrs)); auto entry_block = func_op.addEntryBlock(); @@ -52,6 +56,7 @@ void extract_function(StringRef func_name, Block *block, ModuleOp parentModule, // std::set consts; llvm::SmallVector vals; llvm::SmallVector consts; + llvm::SmallVector broadcasts; bool add_yield = false; // Walk the block and find out all the variables that were defined outside @@ -60,9 +65,11 @@ void extract_function(StringRef func_name, Block *block, ModuleOp parentModule, // entry block are constants. For all other variables, we will add them as // inputs to the function block->walk([&](Operation *inst) { - for (Value val: inst->getOperands()){ - if (dom_info.properlyDominates(val, parent_op)){ + for (Value val : inst->getOperands()) { + if (dom_info.properlyDominates(val, parent_op)) { arith::ConstantOp const_op = val.getDefiningOp(); + vector::BroadcastOp broadcast_op = + val.getDefiningOp(); if (const_op) { // It's useless to add many times the same index if (std::find(consts.begin(), consts.end(), const_op) == @@ -73,9 +80,20 @@ void extract_function(StringRef func_name, Block *block, ModuleOp parentModule, rewriter.replaceOpWithinBlock(const_op, new_const->getResult(0), block); } + } else if (broadcast_op) { + func_op.insertArgument(vals.size(), broadcast_op.source().getType(), + {}, loc); + + rewriter.setInsertionPointToStart(entry_block); + vector::BroadcastOp new_broadcast = + rewriter.create( + loc, broadcast_op.getType(), func_op.getArguments().back()); + rewriter.replaceOpWithinBlock(broadcast_op, + new_broadcast->getResult(0), block); + vals.push_back(broadcast_op.source()); } else { if (std::find(vals.begin(), vals.end(), val) == vals.end()) { - func_op.insertArgument(vals.size(), val.getType(), {}); + func_op.insertArgument(vals.size(), val.getType(), {}, loc); vals.push_back(val); } } @@ -101,11 +119,13 @@ void extract_function(StringRef func_name, Block *block, ModuleOp parentModule, // Step 1: get all the block arguments, add them as function arguments and // replece their use inside the block + int arg_pos = vals.size(); for (auto block_arg : block->getArguments()) { - func_op.insertArgument(vals.size(), block_arg.getType(), {}); - auto arg = func_op.getArgument(vals.size()); + func_op.insertArgument(arg_pos, block_arg.getType(), {}, loc); + auto arg = func_op.getArgument(arg_pos); block_arg.replaceAllUsesWith(arg); newtypes.push_back(block_arg.getType()); + arg_pos++; } // Step 2: replace all the values that are pointing outside the block and @@ -125,9 +145,7 @@ void extract_function(StringRef func_name, Block *block, ModuleOp parentModule, Block *succ = (has_no_successor ? nullptr : block->getSuccessor(0)); // Remove all arguments from the block signature - for (unsigned i = 0; i < block->getNumArguments(); i++) { - block->eraseArgument(i); - } + block->eraseArguments([](auto b) { return true; }); // Merge block into entry_block (this destroys block) rewriter.mergeBlocks(block, entry_block); @@ -139,11 +157,12 @@ void extract_function(StringRef func_name, Block *block, ModuleOp parentModule, // We are done with the callee. Now we have to work on the caller. The overall // idea is to insert a new_block right before the successor of the old block. // If the old block has no successors, then add it at the end of the region + llvm::SmallVector locs(newtypes.size(), loc); Block *new_block = nullptr; if (has_no_successor) { - new_block = rewriter.createBlock(region, region->end(), newtypes); + new_block = rewriter.createBlock(region, region->end(), newtypes, locs); } else { - new_block = rewriter.createBlock(succ, newtypes); + new_block = rewriter.createBlock(succ, newtypes, locs); } // Remember to add the block arguments as inputs to the function @@ -171,9 +190,14 @@ struct ExtractKernelPass : public ExtractKernelPassBase { void runOnOperation() override { // Get the current FuncOp operation being operated on. auto module = getOperation(); - scf::ForOp loop; + scf::ForOp loop = {}; for (FuncOp func : module.getOps()) { + Region &r = func.getRegion(); + // TODO(@Joey): this basically stops the pass to work because many + // functions (like gemm) are only composed of a single block + // if (r.hasOneBlock() || r.empty()) continue; + // Walk the operations within the function. func.walk([&](scf::ForOp forop) { if (forop.getNumIterOperands()) { @@ -182,6 +206,15 @@ struct ExtractKernelPass : public ExtractKernelPassBase { }); } + if (!loop) + return; + + // Do not extract return from current function. Split block to + // leave return in the next block. + Block *blockToExtract = loop->getBlock(); + if (dyn_cast(blockToExtract->back())) { + blockToExtract->splitBlock(&blockToExtract->back()); + } IRRewriter rewriter(module.getContext()); extract_function("kernel", loop->getBlock(), module, rewriter, module.getLoc()); @@ -192,13 +225,16 @@ std::unique_ptr> mlir::createExtractKernelPass() { return std::make_unique(); } -void extract_function_2(StringRef func_name, Block *block, ModuleOp parentModule, RewriterBase &rewriter, Block* origin, Location loc) -{ +void extract_function_2(StringRef func_name, Block *block, + ModuleOp parentModule, RewriterBase &rewriter, + Block *origin, Location loc) { SmallVector Input, Output; // Create the function (callee site) with an empty block rewriter.setInsertionPointToStart(parentModule.getBody()); - auto func_op = rewriter.create(parentModule.getLoc(), func_name, FunctionType::get(parentModule.getContext(), Input, Output)); + auto func_op = rewriter.create( + parentModule.getLoc(), func_name, + FunctionType::get(parentModule.getContext(), Input, Output)); auto entry_block = func_op.addEntryBlock(); // Build the dominance tree of the parent op of the block @@ -206,86 +242,77 @@ void extract_function_2(StringRef func_name, Block *block, ModuleOp parentModule Operation *parent_op = region->getParentOp(); auto dom_info = mlir::DominanceInfo(parent_op); - //std::set consts; + // std::set consts; llvm::SmallVector vals; llvm::SmallVector consts; + llvm::SmallVector affine_apply; - // Walk the block and find out all the variables that were defined outside + // Walk the block and find out all the variables that were defined outside // this block and are used inside the block (i.e., all the variables x that // properly dominate the block). The only things we will redefine inside the - // entry block are constants. For all other variables, we will add them as + // entry block are constants. For all other variables, we will add them as // inputs to the function block->walk([&](Operation *inst) { - for (Value val: inst->getOperands()){ - if (dom_info.properlyDominates(val, &block->getOperations().front())){ - arith::ConstantOp const_op = val.getDefiningOp(); - if (const_op){ + for (Value val : inst->getOperands()) { + if (dom_info.properlyDominates(val, &block->getOperations().front())) { + if (auto const_op = val.getDefiningOp()) { // It's useless to add many times the same index - if (std::find(consts.begin(), consts.end(), const_op) == consts.end()){ + if (std::find(consts.begin(), consts.end(), const_op) == + consts.end()) { consts.push_back(const_op); } - } else { - if (std::find(vals.begin(), vals.end(), val) == vals.end()){ - func_op.insertArgument(vals.size(), val.getType(), {}); + } else if (auto apply_op = val.getDefiningOp()) { + if (std::find(affine_apply.begin(), affine_apply.end(), apply_op) == + affine_apply.end()) { + affine_apply.push_back(apply_op); + auto apply_val = apply_op.getOperand(0); + if (std::find(vals.begin(), vals.end(), apply_val) == vals.end()) { + func_op.insertArgument(vals.size(), apply_val.getType(), {}, loc); + vals.push_back(apply_val); + } + } + } else { + if (std::find(vals.begin(), vals.end(), val) == vals.end()) { + func_op.insertArgument(vals.size(), val.getType(), {}, loc); vals.push_back(val); } } - } + } } }); llvm::SmallVector newtypes; - // We are not done yet. We need to merge the block into the entry block. To do this: - // 1 If an operation in the block is using a value coming from the block - // argument, add the value as function argument and replace the value with it - // 2 If an operation in the block is using a value generated outside the - // block, simply replace its value with a funciton argument - - // Step 1: get all the block arguments, add them as function arguments and - // replece their use inside the block - //for (auto block_arg : block->getArguments()) - //{ - // func_op.insertArgument(vals.size(), block_arg.getType(), {}); - // auto arg = func_op.getArgument(vals.size()); - // block_arg.replaceAllUsesWith(arg); - // newtypes.push_back(block_arg.getType()); - //} - - // Step 2: replace all the values that are pointing outside the block and - // replace them with function arguments - auto args = func_op.getArguments(); - for (unsigned i = 0; i < vals.size(); i++) - { - auto val = vals[i]; - auto arg = args[i]; - val.replaceUsesWithIf(arg, [&](OpOperand &op) - { - Operation *target = op.getOwner(); - for(Operation &op : block->getOperations()){ - if (&op == target) return true; - } - return false; - }); - } - - // Remove all arguments from the block signature - //for (unsigned i = 0; igetNumArguments(); i++){ - // block->eraseArgument(i); - //} - - // Merge block into entry_block (this destroys block) // Add constants rewriter.mergeBlocks(block, entry_block); rewriter.setInsertionPointToStart(entry_block); - for (Operation * c : consts){ - Operation * new_const = rewriter.clone(*c); + for (Operation *c : consts) { + Operation *new_const = rewriter.clone(*c); rewriter.replaceOpWithinBlock(c, new_const->getResult(0), entry_block); } + for (Operation *c : affine_apply) { + Operation *new_apply = rewriter.clone(*c); + rewriter.replaceOpWithinBlock(c, new_apply->getResult(0), entry_block); + } + + auto args = func_op.getArguments(); + for (unsigned i = 0; i < vals.size(); i++) { + auto val = vals[i]; + auto arg = args[i]; + val.replaceUsesWithIf(arg, [&](OpOperand &op) { + Operation *target = op.getOwner(); + for (Operation &op : entry_block->getOperations()) { + if (&op == target) + return true; + } + return false; + }); + } + // Add a returnOp into the block to properly terminate it rewriter.setInsertionPointToEnd(entry_block); - //rewriter.create(loc); + // rewriter.create(loc); // We are done with the callee. Now we have to work on the caller. The overall // idea is to insert a new_block right before the successor of the old block. @@ -293,46 +320,45 @@ void extract_function_2(StringRef func_name, Block *block, ModuleOp parentModule rewriter.setInsertionPointToEnd(origin); - // Create the call + // Create the call rewriter.create(loc, func_op, vals); rewriter.create(loc); } -// Walk the for loops and find the one that as operands. In GEMM is the micro-kernel. -// TODO: we should have the linalg::split to signal the microkernel of the operation -// and use it to run the function extractor if needed -struct ExtractKernelTailPass : public ExtractKernelTailPassBase { +// Walk the for loops and find the one that as operands. In GEMM is the +// micro-kernel. +// TODO: we should have the linalg::split to signal the microkernel of the +// operation and use it to run the function extractor if needed +struct ExtractKernelTailPass + : public ExtractKernelTailPassBase { ExtractKernelTailPass() = default; - ExtractKernelTailPass(const ExtractKernelTailPass& pass) { } - void getDependentDialects(DialectRegistry ®istry) const override { - - } + ExtractKernelTailPass(const ExtractKernelTailPass &pass) {} + void getDependentDialects(DialectRegistry ®istry) const override {} void runOnOperation() override { // Get the current FuncOp operation being operated on. auto module = getOperation(); scf::ForOp loop; + LLVM_DEBUG(llvm::dbgs() << "extract_kernel_tail starts\n "); for (FuncOp func : module.getOps()) { // Walk the operations within the function. func.walk([&](scf::ForOp forop) { - if (forop.getNumIterOperands()){ + if (forop.getNumIterOperands()) { loop = forop; } }); } IRRewriter rewriter(module.getContext()); - Block *tail = rewriter.splitBlock(loop->getBlock(), Block::iterator(loop->getNextNode())); - //rewriter.setInsertionPointToEnd(loop->getBlock()); - // Create the call - //rewriter.create(loc, func_op, vals); - //rewriter.create(module.getLoc()); - extract_function_2("kernel_tail", tail, module, rewriter, loop->getBlock(), module.getLoc()); - + Block *tail = rewriter.splitBlock(loop->getBlock(), + Block::iterator(loop->getNextNode())); + extract_function_2("kernel_tail", tail, module, rewriter, loop->getBlock(), + module.getLoc()); } }; -std::unique_ptr> mlir::createExtractKernelTailPass() { +std::unique_ptr> +mlir::createExtractKernelTailPass() { return std::make_unique(); } diff --git a/experimental/alp/lib/Transforms/for_to_dowhile_loop.cpp b/experimental/alp/lib/Transforms/for_to_dowhile_loop.cpp new file mode 100644 index 000000000000..4e2fba3bd53a --- /dev/null +++ b/experimental/alp/lib/Transforms/for_to_dowhile_loop.cpp @@ -0,0 +1,148 @@ +//===-- for_to_dowhile.cpp - Implement for to do-while tranformation ------*- +// c++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#include "alp/Transforms/PassDetail.h" +#include "alp/Transforms/Passes.h" + +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" +#include + +#include "mlir/Dialect/SCF/Passes.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/SCF/Transforms.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "for-to-dowhile" +using namespace mlir; + +namespace { +using scf::ForOp; +using scf::WhileOp; + +struct ForLoopLoweringPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ForOp forOp, + PatternRewriter &rewriter) const override { + // Generate type signature for the loop-carried values. The induction + // variable is placed first, followed by the forOp.iterArgs. + SmallVector lcvTypes; + lcvTypes.push_back(forOp.getInductionVar().getType()); + llvm::transform(forOp.getInitArgs(), std::back_inserter(lcvTypes), + [&](auto v) { return v.getType(); }); + + // Build scf.WhileOp + SmallVector initArgs; + initArgs.push_back(forOp.getLowerBound()); + llvm::append_range(initArgs, forOp.getInitArgs()); + + // We need to add an if-else condition to avoid executing the first + // iteration + auto shouldWhileExecute = rewriter.create( + forOp.getLoc(), arith::CmpIPredicate::sgt, forOp.getUpperBound(), + forOp.getLowerBound()); + + auto if_while_should_execute = rewriter.create( + forOp.getLoc(), forOp.getResultTypes(), shouldWhileExecute, + forOp.getNumIterOperands() > 0); + rewriter.setInsertionPointToStart( + &if_while_should_execute.getThenRegion().front()); + + // The while-loop should be contained within the then region + auto whileOp = rewriter.create(forOp.getLoc(), lcvTypes, initArgs, + forOp->getAttrs()); + + llvm::SmallVector locs(lcvTypes.size(), forOp.getLoc()); + auto *beforeBlock = rewriter.createBlock( + &whileOp.getBefore(), whileOp.getBefore().begin(), lcvTypes, locs); + + auto *afterBlock = rewriter.createBlock( + &whileOp.getAfter(), whileOp.getAfter().begin(), lcvTypes, locs); + + // Rewrite uses of the for-loop block arguments to the new while-loop + // "after" arguments + for (auto barg : enumerate(forOp.getBody(0)->getArguments())) + barg.value().replaceAllUsesWith(beforeBlock->getArgument(barg.index())); + // Inline for-loop body operations into 'before' region (except Yield). + llvm::SmallVector nextIterArgs; + for (auto &arg : llvm::make_early_inc_range(*forOp.getBody())) { + if (auto yieldOp = dyn_cast(&arg)) { + nextIterArgs = yieldOp.getOperands(); + } else { + arg.moveBefore(beforeBlock, beforeBlock->end()); + } + } + + // 'before' region contains the loop condition and forwarding of iteration + // arguments to the 'after' region. + rewriter.setInsertionPointToEnd(&whileOp.getBefore().front()); + + // Add induction variable incrementation + auto ivIncOp = rewriter.create( + whileOp.getLoc(), beforeBlock->getArgument(0), forOp.getStep()); + auto cmpOp = rewriter.create(whileOp.getLoc(), + arith::CmpIPredicate::slt, + ivIncOp, forOp.getUpperBound()); + + nextIterArgs.insert(nextIterArgs.begin(), ivIncOp.getResult()); + rewriter.create(whileOp.getLoc(), cmpOp.getResult(), + nextIterArgs); + + // Inline for-loop body into an executeRegion operation in the "after" + // region. The return type of the execRegionOp does not contain the + // iv - yields in the source for-loop contain only iterArgs. + + // SmallVector yieldOperands; + rewriter.setInsertionPointToEnd(afterBlock); + rewriter.create(whileOp.getLoc(), afterBlock->getArguments()); + + llvm::SmallVector if_values; + for (auto arg : llvm::enumerate(forOp.getResults())) { + if_values.push_back(whileOp.getResult(arg.index() + 1)); + } + + if (if_values.size() > 0) { + rewriter.setInsertionPointAfter(whileOp); + rewriter.create(whileOp.getLoc(), if_values); + } + + if (forOp.getNumIterOperands() > 0) { + rewriter.setInsertionPointToStart( + &if_while_should_execute.getElseRegion().front()); + rewriter.create(whileOp.getLoc(), forOp.getInitArgs()); + } + + rewriter.replaceOp(forOp, if_while_should_execute.getResults()); + + return success(); + } +}; +struct ForToDoWhileLoop : public ForToDoWhileLoopBase { + void runOnOperation() override { + // Apply on the given function name + + FuncOp funcOp = getOperation(); + if (anchorFuncOpName != funcOp.getName()) { + return; + } + LLVM_DEBUG(llvm::dbgs() << "loop conversion starts\n "); + + MLIRContext *ctx = funcOp.getContext(); + RewritePatternSet patterns(ctx); + patterns.add(ctx); + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + } +}; +} // namespace + +std::unique_ptr> +mlir::createForToDoWhileLoopPass() { + return std::make_unique(); +} \ No newline at end of file diff --git a/experimental/alp/lib/Transforms/legalize_vector.cpp b/experimental/alp/lib/Transforms/legalize_vector.cpp index 08111dbcf138..917999d330c1 100644 --- a/experimental/alp/lib/Transforms/legalize_vector.cpp +++ b/experimental/alp/lib/Transforms/legalize_vector.cpp @@ -5,14 +5,14 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// -#include "alp/Transforms/Passes.h" #include "alp/Transforms/PassDetail.h" +#include "alp/Transforms/Passes.h" #include "mlir/Dialect/SCF/SCF.h" -#include "mlir/Dialect/Vector/VectorOps.h" -#include "mlir/Dialect/Vector/VectorTransforms.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" #include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/Transforms/LoopUtils.h" +//#include "mlir/Transforms/LoopUtils.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Pass/Pass.h" @@ -22,7 +22,6 @@ #include #define DEBUG_TYPE "legalize-vector" -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") using namespace mlir; using namespace mlir::vector; @@ -33,18 +32,20 @@ struct ForOpVectorProgates : public OpRewritePattern { LogicalResult matchAndRewrite(scf::ForOp forOp, PatternRewriter &rewriter) const final { - InsertStridedSliceOp insertOp = forOp.getIterOperands()[0].getDefiningOp(); + InsertStridedSliceOp insertOp = + forOp.getIterOperands()[0].getDefiningOp(); - if (!insertOp){ + if (!insertOp) { return failure(); } SmallVector newIterOperands; newIterOperands.push_back(insertOp.dest()); + auto loc = forOp.getLoc(); - for(auto op : forOp.getIterOperands()){ - if (op == insertOp){ + for (auto op : forOp.getIterOperands()) { + if (op == insertOp) { continue; } newIterOperands.push_back(op); @@ -52,43 +53,51 @@ struct ForOpVectorProgates : public OpRewritePattern { newIterOperands.push_back(insertOp.source()); int64_t working_offset1 = extractFromI64ArrayAttr(insertOp.offsets())[0]; int64_t working_offset2 = extractFromI64ArrayAttr(insertOp.offsets())[1]; - - Block &oldBlock = forOp.region().front(); - oldBlock.addArgument(insertOp.source().getType()); + + Block &oldBlock = forOp.getRegion().front(); + oldBlock.addArgument(insertOp.source().getType(), loc); scf::ForOp newForOp = rewriter.create( - forOp.getLoc(), forOp.lowerBound(), forOp.upperBound(), forOp.step(), newIterOperands); + loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(), + newIterOperands); - Block &newBlock = newForOp.region().front(); + Block &newBlock = newForOp.getRegion().front(); SmallVector extraYieldOperands; - oldBlock.walk([&](Operation *instr){ - if (auto extractOp = dyn_cast(instr)){ - // if offsets are the same - int64_t offset1 = extractFromI64ArrayAttr(extractOp.offsets())[0]; - int64_t offset2 = extractFromI64ArrayAttr(extractOp.offsets())[1]; - if (offset1 == working_offset1 && offset2 == working_offset2){ - extractOp.getResult().replaceAllUsesWith(oldBlock.getArguments().back()); - rewriter.eraseOp(extractOp); - } - } - else if (auto insertOp = dyn_cast(instr)){ - int64_t offset1 = extractFromI64ArrayAttr(insertOp.offsets())[0]; - int64_t offset2 = extractFromI64ArrayAttr(insertOp.offsets())[1]; - if (offset1 == working_offset1 && offset2 == working_offset2){ - insertOp.getResult().replaceAllUsesWith(insertOp.dest()); - extraYieldOperands.push_back(insertOp.source()); - rewriter.eraseOp(insertOp); - } - } + oldBlock.walk([&](Operation *instr) { + if (auto extractOp = dyn_cast(instr)) { + if (std::find(oldBlock.getArguments().begin(), + oldBlock.getArguments().end(), + extractOp.vector()) != oldBlock.getArguments().end()) { + // if offsets are the same + int64_t offset1 = extractFromI64ArrayAttr(extractOp.offsets())[0]; + int64_t offset2 = extractFromI64ArrayAttr(extractOp.offsets())[1]; + if (offset1 == working_offset1 && offset2 == working_offset2) { + extractOp.getResult().replaceAllUsesWith( + oldBlock.getArguments().back()); + rewriter.eraseOp(extractOp); + } + } + } else if (auto insertOp = dyn_cast(instr)) { + // if (std::find(forOp.getResults().begin(), forOp.getResults().end(), + // insertOp.dest()) != forOp.getResults().end()){ + int64_t offset1 = extractFromI64ArrayAttr(insertOp.offsets())[0]; + int64_t offset2 = extractFromI64ArrayAttr(insertOp.offsets())[1]; + if (offset1 == working_offset1 && offset2 == working_offset2) { + insertOp.getResult().replaceAllUsesWith(insertOp.dest()); + extraYieldOperands.push_back(insertOp.source()); + rewriter.eraseOp(insertOp); + } + // } + } }); SmallVector newBlockTransferArgs(newBlock.getArguments().begin(), - newBlock.getArguments().end()); + newBlock.getArguments().end()); rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs); auto clonedYieldOp = cast(newBlock.getTerminator()); SmallVector newYieldOperands = clonedYieldOp.getOperands(); - for (auto val : extraYieldOperands){ + for (auto val : extraYieldOperands) { newYieldOperands.push_back(val); } rewriter.setInsertionPoint(clonedYieldOp); @@ -98,7 +107,7 @@ struct ForOpVectorProgates : public OpRewritePattern { forOp.getResult(0).replaceAllUsesWith(newForOp.getResult(0)); SmallVector newResults; - for (unsigned i =0; i { auto users = newForOp.getResult(0).getUsers(); // And now let's change this - for (Operation *user : users){ - if (auto extractOp = dyn_cast(user)){ - int64_t offset1 = extractFromI64ArrayAttr(extractOp.offsets())[0]; - int64_t offset2 = extractFromI64ArrayAttr(extractOp.offsets())[1]; - if (offset1 == working_offset1 && offset2 == working_offset2){ - extractOp.getResult().replaceAllUsesWith(newForOp.getResults().back()); - rewriter.eraseOp(extractOp); - } + for (Operation *user : users) { + if (auto extractOp = dyn_cast(user)) { + int64_t offset1 = extractFromI64ArrayAttr(extractOp.offsets())[0]; + int64_t offset2 = extractFromI64ArrayAttr(extractOp.offsets())[1]; + if (offset1 == working_offset1 && offset2 == working_offset2) { + extractOp.getResult().replaceAllUsesWith( + newForOp.getResults().back()); + rewriter.eraseOp(extractOp); + } } } return success(); } }; -struct LegalizePass : public LegalizePassBase{ +struct LegalizePass : public LegalizePassBase { LegalizePass() = default; - LegalizePass(const LegalizePass&pass) {} + LegalizePass(const LegalizePass &pass) {} void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } - void runOnFunction() override { + void runOnOperation() override { auto *ctx = &getContext(); - if (getFunction().getName() != "kernel"){ - return; + if (getOperation().getName() != "kernel") { + return; } + LLVM_DEBUG(llvm::dbgs() << "legalize starts\n "); + RewritePatternSet patterns(ctx); populateVectorUnrollPatterns( patterns, UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint( filter)); - (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); // Now there will be a lot of: // - transfer_read/insert_strided_slices // - transfer_write/extract_strided_slices - // Next step is to remove as many {insert,extract}_strided_slices as we can, especially across loop blocks + // Next step is to remove as many {insert,extract}_strided_slices as we can, + // especially across loop blocks RewritePatternSet extra_patterns(ctx); extra_patterns.add(ctx); - (void)applyPatternsAndFoldGreedily(getFunction(), std::move(extra_patterns)); + (void)applyPatternsAndFoldGreedily(getOperation(), + std::move(extra_patterns)); } private: // Return the target shape based on op type. static Optional> getShape(Operation *op) { - if (isa(op)) + if (isa(op)) return SmallVector(2, 4); - if (isa(op)){ + if (isa(op)) { return SmallVector{4, 4, 1}; } // For transfer ops, just propagate the shape coming from // InsertStridedSlices/ExtractStridedSlices. if (auto readOp = dyn_cast(op)) { - VectorType dstVec; - for (Operation *users : readOp->getUsers()) { - auto extract = dyn_cast(users); - scf::ForOp loop = dyn_cast(users); - if (loop){ - if (loop.getIterOperands()[0] == readOp){ - OpOperand &val_arg = loop.getIterOpOperands()[0]; - auto val = loop.getRegionIterArgForOpOperand(val_arg); - for (Operation *arg_users : val.getUsers()){ - extract = dyn_cast(arg_users); - } - } - } - if (!extract){ - return llvm::None; - } - auto vecType = extract.getResult().getType().cast(); - if (dstVec && dstVec != vecType) - return llvm::None; - dstVec = vecType; - } - return SmallVector(dstVec.getShape().begin(), - dstVec.getShape().end()); + VectorType dstVec; + for (Operation *users : readOp->getUsers()) { + auto extract = dyn_cast(users); + scf::ForOp loop = dyn_cast(users); + if (loop) { + if (loop.getIterOperands()[0] == readOp) { + OpOperand &val_arg = loop.getIterOpOperands()[0]; + auto val = loop.getRegionIterArgForOpOperand(val_arg); + for (Operation *arg_users : val.getUsers()) { + extract = dyn_cast(arg_users); + } + } + } + if (!extract) { + return llvm::None; + } + auto vecType = extract.getResult().getType().cast(); + if (dstVec && dstVec != vecType) + return llvm::None; + dstVec = vecType; + } + return SmallVector(dstVec.getShape().begin(), + dstVec.getShape().end()); } if (auto writeOp = dyn_cast(op)) { - auto insert = writeOp.vector().getDefiningOp(); - auto loop = writeOp.vector().getDefiningOp(); - if (loop){ - auto yieldOp = cast(&loop.region().front().back()); + auto insert = writeOp.vector().getDefiningOp(); + auto loop = writeOp.vector().getDefiningOp(); + if (loop) { + auto yieldOp = cast(&loop.getRegion().front().back()); insert = yieldOp.getOperand(0).getDefiningOp(); - } - if (!insert) - return llvm::None; - ArrayRef shape = insert.getSourceVectorType().getShape(); - return SmallVector(shape.begin(), shape.end()); + } + if (!insert) + return llvm::None; + ArrayRef shape = insert.getSourceVectorType().getShape(); + return SmallVector(shape.begin(), shape.end()); } return llvm::None; } @@ -205,7 +219,6 @@ struct LegalizePass : public LegalizePassBase{ } }; - -std::unique_ptr mlir::createLegalizePass() { +std::unique_ptr> mlir::createLegalizePass() { return std::make_unique(); -} \ No newline at end of file +} diff --git a/experimental/alp/lib/Transforms/modulo_scheduling_pass.cpp b/experimental/alp/lib/Transforms/modulo_scheduling_pass.cpp index 2b6d84dc8be7..9eb81c8c30e1 100644 --- a/experimental/alp/lib/Transforms/modulo_scheduling_pass.cpp +++ b/experimental/alp/lib/Transforms/modulo_scheduling_pass.cpp @@ -1,4 +1,5 @@ -//===-- modulo_scheduling_pass.cpp - Implement modulo scheduling ------*- c++ -*-===// +//===-- modulo_scheduling_pass.cpp - Implement modulo scheduling ------*- c++ +//-*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -8,160 +9,161 @@ #include "alp/Transforms/PassDetail.h" #include "alp/Transforms/Passes.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" -#include +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/SCF/Passes.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SCF/Transforms.h" -#include "mlir/Dialect/SCF/Utils.h" -#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "mlir/Transforms/LoopUtils.h" +#include + +#define DEBUG_TYPE "modulo-scheduling" using namespace mlir; -namespace{ - enum StageType{ - Compute, - Load - }; - - struct ModuloSchedulingPass : public ModuloSchedulingPassBase - { - //ModuloScheduling(int unrollFactor):unrollFactor_(unrollFactor){} - ModuloSchedulingPass() = default; - ModuloSchedulingPass(const ModuloSchedulingPass &pass) {} - void getDependentDialects(DialectRegistry ®istry) const override - { - registry.insert(); - } - void runOnFunction() override - { - // Get the current FuncOp operation being operated on. - auto f = getFunction(); - - scf::ForOp loop; - - // Unroll the kernel - f.walk([&](scf::ForOp forop) - { - if (forop.getNumIterOperands()) - { - loop = forop; - } - }); - - if (loop) - { - // Unroll - auto annotateFn = [this](unsigned i, Operation *op, OpBuilder b) - { - op->setAttr("unrolled_iteration", b.getUI32IntegerAttr(i)); - }; - - (void)loopUnrollByFactor(loop, unrolling, annotateFn); - - // Order/stage the instruction within the for loop. We are looking for a pattern like - // %x0 = load -> (stage0, pos3) - // %y0 = load -> (stage0, pos4) - // %z0 = outerprod(%x0, %y0) -> (stage1, pos2) - // %x1 = load -> (stage1, pos0) - // %y1 = load -> (stage1, pos1) - // %z1 = outerprod(%x1, %y1) -> (stage1, pos5) - - std::unordered_map stage_map; - std::vector> compute_queue(unrolling); - std::vector> load_queue(unrolling); - - int stage = 0; - - StageType stage_type = Load; - int current_compute_queue = 0; - int current_load_queue = 0; - - // Take care of the stages - for (Operation &operation : loop.getBody()->getOperations()) - { - Operation *op = &operation; - if (dyn_cast(op)) - { - continue; - } - // This is a state machine with two states - if (stage_type == Compute){ - if (auto compute_op = dyn_cast(op)){ - compute_queue[current_compute_queue].push_back(op); - } else { - stage_type = Load; - current_compute_queue++; - load_queue[current_load_queue].push_back(op); - } - } else {// if stage_type == Load - if (auto compute_op = dyn_cast(op)){ - stage_type = Compute; - if (current_compute_queue == 0){ - stage = 1; - } - current_load_queue++; - compute_queue[current_compute_queue].push_back(op); - } else { - load_queue[current_load_queue].push_back(op); - } +namespace { +enum StageType { Compute, Load }; + +struct ModuloSchedulingPass + : public ModuloSchedulingPassBase { + // ModuloScheduling(int unrollFactor):unrollFactor_(unrollFactor){} + ModuloSchedulingPass() = default; + ModuloSchedulingPass(const ModuloSchedulingPass &pass) {} + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnOperation() override { + LLVM_DEBUG(llvm::dbgs() << "modulo scheudling starts\n "); + // Get the current FuncOp operation being operated on. + FuncOp f = getOperation(); + + scf::ForOp loop; + + // Unroll the kernel + f.walk([&](scf::ForOp forop) { + if (forop.getNumIterOperands()) { + loop = forop; + } + }); + + if (loop) { + assert(distance <= unrolling && + "pipeline distance cannot be bigger than unrolling"); + + // Unroll + (void)loopUnrollByFactor(loop, unrolling); + + /* Order/stage the instructions within the for loop. + + Initial loop (unroll==4) looks like: + for (...){ L0 C0 L1 C1 L2 C2 } + + New loop should look like(distance==2): + L0 L1 for (...){ C0 L2 C1 L3 C2 L0 C3 L1 } C0 C1 + + */ + std::unordered_map stage_map; + std::vector> compute_queue(unrolling); + std::vector> load_queue(unrolling); + + StageType stage_type = Load; + int current_compute_queue = 0; + int current_load_queue = 0; + // Take care of the stages + for (Operation &operation : loop.getBody()->getOperations()) { + Operation *op = &operation; + if (dyn_cast(op)) { + continue; + } + // This is a state machine with two states + if (stage_type == Compute) { + if (auto compute_op = dyn_cast(op)) { + compute_queue[current_compute_queue].push_back(op); + } else { + stage_type = Load; + current_compute_queue++; + load_queue[current_load_queue].push_back(op); + } + } else { // if stage_type == Load + if (auto compute_op = dyn_cast(op)) { + stage_type = Compute; + current_load_queue++; + compute_queue[current_compute_queue].push_back(op); + } else { + load_queue[current_load_queue].push_back(op); } - stage_map[op] = stage; } + } + + for (int i = 0; i < unrolling; i++) { + for (Operation *op : load_queue[i]) { + stage_map[op] = std::min(int(distance), i); + } + for (Operation *op : compute_queue[i]) { + stage_map[op] = distance; + } + } mlir::scf::PipeliningOption options; - options.getScheduleFn = [&](scf::ForOp forOp, std::vector> &schedule) { - schedule.resize(forOp.getBody()->getOperations().size() - 1); - int pos = 0; - for (int i =0; i < unrolling; i++){ - auto current_comp_block = compute_queue[i]; - // Loads from the next block - auto current_load_block = load_queue[(i+1)%unrolling]; - - // Get to the first load op (while scheduling the rest) - unsigned load_start= 0; - for (;load_start < current_load_block.size(); load_start++){ - Operation *op = current_load_block[load_start]; - if (!dyn_cast(op)){ - schedule[pos++] = {op, stage_map[op]}; - } else { - break; - } - } - size_t real_load_size = current_load_block.size() - load_start; - unsigned min_size = 0; - if (interleave){ - min_size = std::min(current_comp_block.size(), real_load_size); - } - for (unsigned j = 0; j> &schedule) { + schedule.resize(forOp.getBody()->getOperations().size() - 1); + + int pos = 0; + for (int i = 0; i < unrolling; i++) { + auto current_comp_block = compute_queue[i]; + // Loads from the next block + auto current_load_block = load_queue[(i + distance) % unrolling]; + + // Get to the first load op (while scheduling the rest) + unsigned load_start = 0; + for (; load_start < current_load_block.size(); load_start++) { + Operation *op = current_load_block[load_start]; + if (!dyn_cast(op)) { + schedule[pos++] = {op, stage_map[op]}; + } else { + break; + } + } + size_t real_load_size = current_load_block.size() - load_start; + unsigned min_size = 0; + if (interleave) { + min_size = std::min(current_comp_block.size(), real_load_size); + } + for (unsigned j = 0; j < min_size; j++) { + Operation *load_op = current_load_block[j + load_start]; + Operation *compute_op = current_comp_block[j]; + schedule[pos++] = {compute_op, stage_map[compute_op]}; + schedule[pos++] = {load_op, stage_map[load_op]}; + } - if (real_load_size > min_size){ - for (unsigned j = min_size; j min_size) { + for (unsigned j = min_size; j < real_load_size; j++) { + Operation *load_op = current_load_block[j + load_start]; + schedule[pos++] = {load_op, stage_map[load_op]}; + } + } - if (current_comp_block.size() > min_size){ - for (unsigned j = min_size; j min_size) { + for (unsigned j = min_size; j < current_comp_block.size(); + j++) { + Operation *compute_op = current_comp_block[j]; + schedule[pos++] = {compute_op, stage_map[compute_op]}; + } + } } - } - - } - }; + }; RewritePatternSet patterns(&getContext()); + // TODO: we should upstream this + // options.dynamic_loops = + // scf::PipelineDynamicLoops::EnableLoopVersioning; scf::populateSCFLoopPipeliningPatterns(patterns, options); (void)applyOpPatternsAndFold(loop, std::move(patterns)); } @@ -170,6 +172,7 @@ namespace{ }; } // namespace -std::unique_ptr mlir::createModuloSchedulingPass() { +std::unique_ptr> +mlir::createModuloSchedulingPass() { return std::make_unique(); } diff --git a/experimental/alp/python/alp/__init__.py b/experimental/alp/python/alp/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/experimental/alp/python/alp/backend/__init__.py b/experimental/alp/python/alp/backend/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/experimental/alp/python/alp/backend/codegen.py b/experimental/alp/python/alp/backend/codegen.py new file mode 100644 index 000000000000..d50e9edb58dd --- /dev/null +++ b/experimental/alp/python/alp/backend/codegen.py @@ -0,0 +1,25 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from pathlib import Path +from .utils import print_command, run_command + + +def codegen(source, dest, scheduler="ilpmax"): + cmd = ["$IREE_LLVM_SANDBOX_BUILD_DIR/bin/llc"] + cmd.append(source) + cmd.append("-filetype=asm") + cmd.append("-O3") + + # Scheduling options + if scheduler == "ilpmax": + cmd.append("--misched=ilpmax") + elif scheduler == "shuffle": + cmd.append("--misched=shuffle") + else: + raise (ValueError("Invalid scheduler algorithm")) + + # Register allocator options + cmd.append(f"-o {dest}") + run_command(cmd) diff --git a/experimental/alp/python/alp/backend/mlirc.py b/experimental/alp/python/alp/backend/mlirc.py new file mode 100644 index 000000000000..5d3f18ec18e9 --- /dev/null +++ b/experimental/alp/python/alp/backend/mlirc.py @@ -0,0 +1,198 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import os +import argparse +from pathlib import Path + +# ALP specifics imports +from .utils import run_and_save, run_command, add_extension +from .codegen import codegen +from .transforms import Pipeline, ExtractKernel, ConvertLoops + +# Sandbox imports +import mlir.iree_sandbox as sandbox +import mlir.ir as ir +import mlir.dialects.linalg_transform as transform +from mlir.iree_sandbox import register_sandbox_passes_and_dialects +from examples.core.transforms import * +from examples.core.transform import ( + TransformListFactory, + TransformationList, + PrintIR, + SaveIR, +) + +# Standalone +from mlir.ir import * +from mlir.dialects import arith, builtin, linalg, tensor, scf, std, memref +from mlir.dialects.linalg.opdsl.lang import * + +Tiling = Tile.then(Tile).then(Vectorize).then(Bufferize) + + +def apply(transform, source, dest, verbosity_level=0): + sourcef = open(source) + destf = open(dest, "w") + mlir_asm = sourcef.read() + + with Context() as ctx: + register_sandbox_passes_and_dialects(ctx) + module = Module.parse(asm=mlir_asm) + out = transform("gemm", module) + module_transformed = str(module) + destf.write(module_transformed) + + +def translate_to_llvm(source, dest): + out = run_command([ + "$IREE_LLVM_SANDBOX_BUILD_DIR/bin/mlir-translate --mlir-to-llvmir " + + source + ]) + f = open(dest, "w") + f.write(out) + f.close() + + +def generate_interchange_2d(transpose_flags): + + def gen_map(flag): + id_map = [0, 1] + if flag: + return id_map[::-1] + return id_map + + return map(gen_map, transpose_flags) + + +def generate_transform_pipeline(options): + + # Transformation options + tile_sizes = options["tile_sizes"] + reorder_tile_sizes = options["reorder_tile_sizes"] + register_tile_sizes = options["register_tile_sizes"] + reorder_register_tile_sizes = options["reorder_register_tile_sizes"] + hoist_packing = options["hoist_packing"] + split_vector_transfer = options["split_vector_transfers_to"] + extract_micro_kernel = options["extract_micro_kernel"] + modulo_scheduling = options["modulo_scheduling"] + ms_unroll = options["ms_unroll"] if options["ms_unroll"] else 2 + ms_distance = options["ms_distance"] if options["ms_distance"] else 1 + transpose_packing = options["transpose_packing"] + + tile = Tiling( + "gemm", + "linalg.generic", + tile_sizes1=tile_sizes, + tile_interchange1=reorder_tile_sizes, + tile_sizes2=register_tile_sizes, + tile_interchange2=reorder_register_tile_sizes, + pad2=True, + pack_paddings2=[1, 1, 0], + hoist_paddings2=hoist_packing, + transpose_paddings2=generate_interchange_2d(transpose_packing), + vectorize_padding=True, + ) + + # Compose the MLIR pipeline + transf = tile + if extract_micro_kernel: + transf = transf.then(ExtractKernel("gemm", "linalg.generic")) + transf = transf.then(LowerVectors(split_transfers=split_vector_transfer)) + if modulo_scheduling: + transf = transf.then( + Pipeline("gemm", + "linalg.generic", + unroll=ms_unroll, + distance=ms_distance)) + + transf = (transf.then(ConvertLoops("gemm", "linalg.generic")).then( + ConvertLoops("kernel", "linalg.generic")).then(LowerToLLVM())) + + return transf + + +def compile(mlir_program, option_list): + """The compiler program receives an mlir_program (.mlir) and generates + assembly (.s) + """ + program_base = os.path.splitext(mlir_program)[0] + transformed_program = f"{program_base}.llvm.mlir" + llvm_program = f"{program_base}.ll" + asm_program = f"{program_base}.s" + + ## Transform the MLIR program + # Generate a pipeline to transform the program + pipeline = generate_transform_pipeline(option_list) + + # Add SaveIR transforms after each transformation + if option_list["verbosity_level"] > 1: + pipeline = pipeline.save_ir(file_name=mlir_program, after_all=True) + + # Apply the pipeline + apply(pipeline, mlir_program, transformed_program) + + ## Translate MLIR LLVM to LLVMIR + translate_to_llvm(transformed_program, llvm_program) + + ## MLIR part is over. Let's pass the ball to the code generator + scheduler = option_list["scheduler"] + codegen(llvm_program, asm_program, scheduler) + + return asm_program + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser("mlirc") + + # Input MLIR to compile + parser.add_argument("--input-file") + + # Outer tiling + parser.add_argument("--tile-sizes", nargs="+", type=int) + parser.add_argument("--reorder-tile-sizes", nargs="+", type=int) + + # Inner tiling + parser.add_argument("--register-tile-sizes", nargs="+", type=int) + parser.add_argument("--reorder-register-tile-sizes", nargs="+", type=int) + parser.add_argument("--hoist-packing", nargs="+", type=int) + parser.add_argument("--transpose-packing", nargs="+", type=int) + + # Vector lowering + parser.add_argument("--unroll-vector-transfers", action="store_true") + parser.add_argument("--split-vector-transfers-to") + + # micro-kernel transforms + parser.add_argument("--extract-micro-kernel", action="store_true") + parser.add_argument("--modulo-scheduling", action="store_true") + parser.add_argument("--ms-interleave", action="store_true") + parser.add_argument("--ms-unroll", type=int) + parser.add_argument("--ms-distance", type=int) + + # scheduling algorithm + parser.add_argument("--scheduler", default="ilpmax") + + # verbosity + parser.add_argument("--verbosity-level", type=int, default=4) + args = parser.parse_args() + + options = { + "tile_sizes": args.tile_sizes, + "register_tile_sizes": args.register_tile_sizes, + "split_vector_transfers_to": args.split_vector_transfers_to, + "unroll_vector_transfers": args.unroll_vector_transfers, + "reorder_tile_sizes": args.reorder_tile_sizes, + "reorder_register_tile_sizes": args.reorder_register_tile_sizes, + "hoist_packing": args.hoist_packing, + "transpose_packing": args.transpose_packing, + "extract_micro_kernel": args.extract_micro_kernel, + "modulo_scheduling": args.modulo_scheduling, + "ms_interleave": args.ms_interleave, + "ms_unroll": args.ms_unroll, + "ms_distance": args.ms_distance, + "scheduler": args.scheduler, + "verbosity_level": args.verbosity_level, + } + + compile(args.input_file, options) diff --git a/experimental/alp/python/alp/backend/transforms.py b/experimental/alp/python/alp/backend/transforms.py new file mode 100644 index 000000000000..adead9ccf211 --- /dev/null +++ b/experimental/alp/python/alp/backend/transforms.py @@ -0,0 +1,61 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import os +from examples.core.transforms import * +from examples.core.transform import TransformListFactory, TransformationList + + +class Pipeline(Transform): + """Tile a linalg op with `tile_sizes`. + + This transform can be configured as follows: + * `ms_unroll`: Level of unrolling of the given loop + * `ms_distance`: Distance between a load and a compute operation + """ + + variables = { + "unroll": (IntVariable, []), + "distance": (IntVariable, []), + } + + def __init__(self, fun_name: str, op_name: str, **kwargs): + self._parse_variables_in_kwargs(kwargs) + unrolling_str = f"unrolling={self.unroll}" + distance_str = f"distance={self.distance}" + pipeline = (f"alp-modulo-scheduling{{" + f" {unrolling_str} " + f" {distance_str}}}," + f"canonicalize," + f"cse") + self.pipeline = f"builtin.func({pipeline})" + + +class ExtractKernel(Transform): + """Tile a linalg op with `tile_sizes`. + + This transform can be configured as follows: + * `ms_unroll`: Level of unrolling of the given loop + * `ms_distance`: Distance between a load and a compute operation + """ + + def __init__(self, fun_name: str, op_name: str, **kwargs): + self.pipeline = f"alp-extract-kernel," f"canonicalize," f"cse" + + +class ConvertLoops(Transform): + """Tile a linalg op with `tile_sizes`. + + This transform can be configured as follows: + * `ms_unroll`: Level of unrolling of the given loop + * `ms_distance`: Distance between a load and a compute operation + """ + + def __init__(self, fun_name: str, op_name: str, **kwargs): + self._parse_variables_in_kwargs(kwargs) + pipeline = (f"alp-for-to-dowhile{{" + f" anchor-func={fun_name}}}," + f"canonicalize," + f"cse") + self.pipeline = f"builtin.func({pipeline})" diff --git a/experimental/alp/python/alp/backend/tuner.py b/experimental/alp/python/alp/backend/tuner.py new file mode 100644 index 000000000000..12dcffd8623c --- /dev/null +++ b/experimental/alp/python/alp/backend/tuner.py @@ -0,0 +1,146 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import os +import sys +from pathlib import Path + +import opentuner +from opentuner import ConfigurationManipulator +from opentuner.search.manipulator import ( + IntegerParameter, + PowerOfTwoParameter, + EnumParameter, + BooleanParameter, +) +from opentuner import MeasurementInterface +from opentuner import Result + +from . import mlirc +from .utils import parse +from ..benchmark import infra + +max_flops = 0 + + +class MLIRFlagsTuner(MeasurementInterface): + + def manipulator(self): + """ + Define the search space by creating a + ConfigurationManipulator + """ + manipulator = ConfigurationManipulator() + + manipulator.add_parameter(PowerOfTwoParameter("mr", 4, 4)) + + manipulator.add_parameter(PowerOfTwoParameter("nr", 16, 16)) + + manipulator.add_parameter(PowerOfTwoParameter("kr", 16, 64)) + + manipulator.add_parameter(PowerOfTwoParameter("kc", 64, 128)) + + manipulator.add_parameter(PowerOfTwoParameter("mc", 256, 2048)) + + manipulator.add_parameter(PowerOfTwoParameter("nc", 64, 2048)) + + manipulator.add_parameter(IntegerParameter("ha", 4, 4)) + + manipulator.add_parameter(IntegerParameter("hb", 3, 3)) + + return manipulator + + def run(self, desired_result, input, limit): + """ + Compile and run a given configuration then + return performance + """ + global max_flops + cfg = desired_result.configuration.data + + mr = cfg["mr"] + nr = cfg["nr"] + kr = cfg["kr"] + kc = cfg["kc"] + mc = cfg["mc"] + nc = cfg["nc"] + ha = cfg["ha"] + hb = cfg["hb"] + + reordering = "Afirst" + + if reordering == "Afirst": + reorder_inner = [0, 1, 2] + reorder_outer = [0, 2, 1] + else: + reorder_inner = [1, 0, 2] + reorder_outer = [1, 2, 0] + + options = { + f"tile_sizes": [mc, nc, kc], + "register_tile_sizes": [mr, nr, 1], + "split_vector_transfers_to": "vector-transfers", + "unroll_vector_transfers": True, + "reorder_tile_sizes": reorder_outer, + "reorder_register_tile_sizes": reorder_inner, + "hoist_packing": [ha, hb, 0], + "transpose_packing": [0, 0, 0], + "extract_micro_kernel": True, + "modulo_scheduling": True, + "ms_unroll": 1, + "ms_distance": 1, + "scheduler": "ilpmax", + "verbosity_level": 0, + } + + # Try to compile the program + try: + asm_program = mlirc.compile(self.args.input_file, options) + except: + return Result(time=sys.maxsize) + + # TODO: can we store obj_benchmark as an attribute of the class? + mlir_benchmark = args.benchmark + bench_base = os.path.splitext(mlir_benchmark)[0] + obj_benchmark = bench_base + ".o" + + # Link and run + exe = infra.link(asm_program, obj_benchmark) + run_result = self.call_program(f"./{exe}") + + if run_result["returncode"] != 0: + return Result(time=sys.maxsize) + + assert run_result["returncode"] == 0 + + secs, flops = parse(run_result["stdout"]) + + if flops > max_flops: + max_flops = flops + + return Result(time=1 / flops) + + def save_final_config(self, configuration): + """called at the end of tuning""" + print("Optimal block size written to mmm_final_config.json:", + configuration.data) + self.manipulator().save_to_file(configuration.data, f"final_config.json") + + +# TODO: create an API to call tune() from the library packager +if __name__ == "__main__": + argparser = opentuner.default_argparser() + argparser.add_argument("--input-file", required=True) + + # TODO: is it possible to understand properties of the MLIR input file and + # generating directly the benchmark program? Also, we should infer from the program + # structure what transformations to run instead of having them hardcoded + argparser.add_argument("--benchmark", required=True) + + args = argparser.parse_args() + + # Build the MLIR benchmark + benchmark_obj = infra.compile(args.benchmark) + + MLIRFlagsTuner.main(args) diff --git a/experimental/alp/python/alp/backend/utils.py b/experimental/alp/python/alp/backend/utils.py new file mode 100644 index 000000000000..066566aa112a --- /dev/null +++ b/experimental/alp/python/alp/backend/utils.py @@ -0,0 +1,114 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import subprocess +import os +import numpy as np +from subprocess import PIPE, Popen + + +def run_command(cmd): + output = subprocess.check_output(" ".join(cmd), shell=True) + return output.decode("ascii") + + +def print_command(cmd): + print(" ".join(cmd)) + + +def run_and_save(cmd, original_ir, new_ir): + out = run_command(cmd + [original_ir]) + f = open(f"{new_ir}", "w") + # Save the command that generated the IR + f.write("//" + " ".join(cmd + [original_ir]) + "\n") + # Save the IR + f.write(out) + f.close() + + +def add_extension(fname, ext): + orig_ext = os.path.splitext(fname)[1] + newfilename = os.path.splitext(fname)[0] + "." + ext + orig_ext + return newfilename + + +def parse(out): + if isinstance(out, bytes): + out = out.decode("utf-8") + + secs = 0 + flops = 0 + lines = out.split("\n") + for l in lines: + if not l: + continue + [a, b] = l.split() + if b == "secs": + secs = float(a) + if b == "GFLOPS": + flops = float(a) + return (secs, flops) + + +def analytical_model(hw, Sdata): + # Analyitical model for GEMM + # https://www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf + + # Vector unit properties + Nvec = hw["Nvec"] + Lvfma = hw["Lvfma"] + Nvfma = hw["Nvfma"] + + # Determine mr/nr + K = Nvec * Nvfma * Lvfma + mr = np.ceil((np.sqrt(K) / Nvec)) * Nvec + nr = np.ceil(K / mr) + + # L1 properties + SL1 = hw["SL"][0] * 1024 + WL1 = hw["WL"][0] + + # L2 properties + SL2 = hw["SL"][1] * 1024 + WL2 = hw["WL"][1] + + if "CL" in hw: + CL1 = hw["CL"][0] + CL2 = hw["CL"][1] + NL1 = SL1 / (WL1 * CL1) + NL2 = SL2 / (WL2 * CL2) + elif "NL" in hw: + NL1 = hw["NL"][0] + NL2 = hw["NL"][1] + CL1 = SL1 / (WL1 * NL1) + CL2 = SL2 / (WL2 * NL2) + + # if L3 properties are specified, then determine nc + if hw["num_caches"] == 3: + SL3 = hw["SL"][2] * 1024 + WL3 = hw["WL"][2] + + if "CL" in hw: + CL3 = hw["CL"][2] + NL3 = SL3 / (WL3 * CL3) + elif "NL" in hw: + NL3 = hw["NL"][2] + CL3 = SL3 / (WL3 * NL3) + + # Determine kc + CAr = np.floor((WL1 - 1) / (1 + nr / mr)) + kc = (CAr * NL1 * CL1) / (mr * Sdata) + + # Determine mc + CBr2 = np.ceil(nr * kc * Sdata / (NL2 * CL2)) + mc = (WL2 - 1 - CBr2) * NL2 * CL2 / (kc * Sdata) + + # Determine nc + if hw["num_caches"] == 3: + CAc3 = np.ceil(mc * kc * Sdata / (NL3 * CL3)) + nc = ((WL3 - CAc3 - 1) * NL3 * CL3) / (kc * Sdata) + else: + nc = -1 + + return (mc, nc, kc, mr, nr) diff --git a/experimental/alp/python/alp/benchmark/__init__.py b/experimental/alp/python/alp/benchmark/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/experimental/alp/python/alp/benchmark/blas/__init__.py b/experimental/alp/python/alp/benchmark/blas/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/experimental/alp/python/alp/benchmark/blas/gemm.py b/experimental/alp/python/alp/benchmark/blas/gemm.py new file mode 100644 index 000000000000..71dc878346e3 --- /dev/null +++ b/experimental/alp/python/alp/benchmark/blas/gemm.py @@ -0,0 +1,141 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import sys, time, os + +from typing import Any, List, Mapping, Optional, Sequence +import numpy as np +import argparse + +from mlir.ir import * +from mlir.dialects import arith, builtin, linalg, tensor, scf, std, memref +from mlir.dialects.linalg.opdsl.lang import * + +from ...transition.blas.gemm import GEMM +from ..infra import * + +from examples.core.problem_definition import * + + +def save_mlir(mlir_txt, dest): + f = open(dest, "w") + f.write(mlir_txt) + f.close() + + +def emit_benchmarking_function(trA, sizes, niter, + func: builtin.FuncOp) -> builtin.FuncOp: + """Produces the benchmarking function. + + This function calls the given function `func` as many times as requested by + its last argument. + """ + f64 = F64Type.get() + + print_flops = builtin.FuncOp("print_flops", ([f64], []), visibility="private") + + print_time = builtin.FuncOp("print_time", ([f64], []), visibility="private") + + print_pid = builtin.FuncOp("print_pid", ([], []), visibility="private") + + rtclock = builtin.FuncOp("rtclock", ([], [f64]), visibility="private") + + wrapper = builtin.FuncOp( + # Same signature and an extra buffer of indices to save timings. + "main", + ([], [IntegerType.get_signless(32)]), + visibility="public", + ) + + M = sizes[0] + N = sizes[1] + K = sizes[2] + + with InsertionPoint(wrapper.add_entry_block()): + zero = arith.ConstantOp.create_index(0) + one = arith.ConstantOp.create_index(1) + elem = arith.ConstantOp(F32Type.get(), 1.0) + nops = arith.ConstantOp(F64Type.get(), 1.0 * (2 * M * N * K)) + + if trA: + A0 = linalg.InitTensorOp([K, M], F32Type.get()) + else: + A0 = linalg.InitTensorOp([M, K], F32Type.get()) + + B0 = linalg.InitTensorOp([K, N], F32Type.get()) + C = linalg.InitTensorOp([M, N], F32Type.get()) + + A = linalg.FillOp(output=A0.results[0], value=elem) + B = linalg.FillOp(output=B0.results[0], value=elem) + std.CallOp(print_pid, []) + + call = std.CallOp(func, [A.results[0], B.results[0], C.results[0]]) + + n_iterations = arith.ConstantOp.create_index(niter) + start = std.CallOp(rtclock, []) + loop = scf.ForOp(zero, n_iterations, one, []) + with InsertionPoint(loop.body): + call = std.CallOp(func, [A.results[0], B.results[0], C.results[0]]) + scf.YieldOp([]) + end = std.CallOp(rtclock, []) + treps = arith.SubFOp(end, start) + n_iterations_f = arith.ConstantOp(F64Type.get(), float(niter)) + t = arith.DivFOp(treps, n_iterations_f) + flops = arith.DivFOp(nops, t) + std.CallOp(print_time, [t.results[0]]) + std.CallOp(print_flops, [flops.results[0]]) + + ret = arith.ConstantOp(IntegerType.get_signless(32), 0) + std.ReturnOp(ret) + + return wrapper + + +def generate_benchmark_mlir(func_name, trA, size, reps, dest): + with Context() as ctx, Location.unknown() as loc: + f32 = F32Type.get() + problem_definition = GEMM(trA) + mlir_module = Module.create() + problem_sizes = {"M": size[0], "N": size[1], "K": size[2]} + types = problem_definition.types_mlir_builder( + problem_sizes, + [f32, f32, f32], + ) + with InsertionPoint(mlir_module.body): + gemm = builtin.FuncOp(func_name, (types, [types[-1]]), + visibility="private") + benchmark = emit_benchmarking_function(trA, size, reps, gemm) + save_mlir(str(mlir_module), dest) + + +def main(argv): + parser = argparse.ArgumentParser("gemm") + + # Files and paths + parser.add_argument("--asm-program", default="") + parser.add_argument("--function-name", default="gemm") + parser.add_argument("--output", default="") + + # GEMM specific parameters for the problem to benchmark + parser.add_argument("--M", type=int) + parser.add_argument("--N", type=int) + parser.add_argument("--K", type=int) + parser.add_argument("--trA", action="store_true") + + # Benchmark specs + parser.add_argument("--reps", type=int, default=1) + args = parser.parse_args(argv) + + # Generate the benchmark, build and run + size = [args.M, args.N, args.K] + dest = args.output if args.output else args.function_name + ".bench.mlir" + generate_benchmark_mlir(args.function_name, args.trA, size, args.reps, dest) + + if args.asm_program: + exe = build(args.asm_program, dest) + print(run(exe)) + + +if __name__ == "__main__": + main(os.sys.argv[1:]) diff --git a/experimental/alp/python/alp/benchmark/infra.py b/experimental/alp/python/alp/benchmark/infra.py new file mode 100644 index 000000000000..6b3b9dc34e15 --- /dev/null +++ b/experimental/alp/python/alp/benchmark/infra.py @@ -0,0 +1,73 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import os +import tempfile +import shutil +import argparse +from pathlib import Path +from ..backend.utils import print_command, run_and_save, run_command, add_extension + + +def compile(mlir_benchmark): + mlir_benchmark_base = os.path.splitext(mlir_benchmark)[0] + mlir_llvm_benchmark = mlir_benchmark_base + ".llvm.mlir" + llvm_benchmark = mlir_benchmark_base + ".ll" + obj_benchmark = mlir_benchmark_base + ".o" + + # main program + cmd = ["$IREE_LLVM_SANDBOX_BUILD_DIR/bin/mlir-opt"] + cmd.append(mlir_benchmark) + cmd.append("--linalg-bufferize") + cmd.append("--std-bufferize") + cmd.append("--tensor-bufferize") + cmd.append("--func-bufferize") + cmd.append("-convert-linalg-to-affine-loops") + cmd.append("-lower-affine") + cmd.append("-convert-scf-to-std") + cmd.append("-convert-memref-to-llvm") + cmd.append("-convert-std-to-llvm") + cmd.append("-reconcile-unrealized-casts") + cmd.append(f"> {mlir_llvm_benchmark}") + run_command(cmd) + + cmd = ["$IREE_LLVM_SANDBOX_BUILD_DIR/bin/mlir-translate"] + cmd.append("--mlir-to-llvmir") + cmd.append(f"{mlir_llvm_benchmark}") + cmd.append(f"> {llvm_benchmark}") + run_command(cmd) + + cmd = ["$IREE_LLVM_SANDBOX_BUILD_DIR/bin/llc"] + cmd.append(f"{llvm_benchmark}") + cmd.append("-O3") + cmd.append("-filetype=obj") + cmd.append(f"-o {obj_benchmark}") + run_command(cmd) + return obj_benchmark + + +def link(asm_program, obj_benchmark): + mlir_benchmark_base = os.path.splitext(obj_benchmark)[0] + exe_benchmark = mlir_benchmark_base + ".out" + + runtime_src = ( + "$IREE_LLVM_SANDBOX_SOURCE_DIR/experimental/alp/lib/AlpRuntime/alp_runtime.cpp" + ) + cmd = ["$IREE_LLVM_SANDBOX_BUILD_DIR/bin/clang++"] + cmd.append(f"{obj_benchmark}") + cmd.append(f"{asm_program}") + cmd.append(runtime_src) + cmd.append(f"-o {exe_benchmark}") + run_command(cmd) + return exe_benchmark + + +def build(asm_program, mlir_benchmark): + # Compile and link + obj_benchmark = compile(mlir_benchmark) + return link(asm_program, obj_benchmark) + + +def run(executable): + return run_command([os.path.abspath(executable)]) diff --git a/experimental/alp/python/alp/test/__init__.py b/experimental/alp/python/alp/test/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/experimental/alp/python/alp/test/blas/__init__.py b/experimental/alp/python/alp/test/blas/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/experimental/alp/python/alp/test/blas/gemm.py b/experimental/alp/python/alp/test/blas/gemm.py new file mode 100644 index 000000000000..4be77b8dd72c --- /dev/null +++ b/experimental/alp/python/alp/test/blas/gemm.py @@ -0,0 +1,162 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import os +import argparse + +# MLIR imports +from mlir.ir import * +from mlir.dialects import arith, builtin, linalg, tensor, scf, std, memref +from mlir.dialects.linalg.opdsl.lang import * + +# Sandbox imports +from examples.core.problem_definition import * + +# Alp imports +from ...backend.utils import print_command, run_and_save, run_command, add_extension +from ...transition.blas.gemm import save_mlir +from ...transition.blas.gemm import GEMM, matmul_NN, matmul_TN +from ..infra import * + + +def emit_test_function(trA, sizes, func: builtin.FuncOp) -> builtin.FuncOp: + """Produces the test function.""" + f64 = F64Type.get() + f32 = F32Type.get() + + printF32 = builtin.FuncOp("printF32", ([f32], []), visibility="private") + + printNewline = builtin.FuncOp("printNewline", ([], []), visibility="private") + wrapper = builtin.FuncOp( + # Same signature and an extra buffer of indices to save timings. + "main", + ([], [IntegerType.get_signless(32)]), + visibility="public", + ) + + M = sizes[0] + N = sizes[1] + K = sizes[2] + + with InsertionPoint(wrapper.add_entry_block()): + zero = arith.ConstantOp.create_index(0) + one_index = arith.ConstantOp.create_index(1) + one_f64 = arith.ConstantOp(F64Type.get(), 1.0) + + cM = arith.ConstantOp.create_index(M) + cN = arith.ConstantOp.create_index(N) + cK = arith.ConstantOp.create_index(K) + + minus_one_f64 = arith.ConstantOp(F64Type.get(), -1.0) + c12345_i32 = arith.ConstantOp(IntegerType.get_signless(32), 12345) + c6789_i32 = arith.ConstantOp(IntegerType.get_signless(32), 6789) + + # Initialize tensors + if trA: + A0 = linalg.InitTensorOp([K, M], F32Type.get()) + else: + A0 = linalg.InitTensorOp([M, K], F32Type.get()) + + B0 = linalg.InitTensorOp([K, N], F32Type.get()) + C = linalg.InitTensorOp([M, N], F32Type.get()) + D = linalg.InitTensorOp([M, N], F32Type.get()) + + elem = arith.ConstantOp(F32Type.get(), 1.0) + # Fill the inputs + A = linalg.fill_rng_2d(minus_one_f64, + one_f64, + c12345_i32, + outs=[A0.results[0]]) + B = linalg.fill_rng_2d(minus_one_f64, + one_f64, + c6789_i32, + outs=[B0.results[0]]) + # A = linalg.FillOp(output=A0.results[0], value=elem) + # B = linalg.FillOp(output=B0.results[0], value=elem) + + # Evaluate actual function and the reference + std.CallOp(func, [A, B, C.results[0]]) + if trA: + D1 = matmul_TN(A, B, outs=[D]) + else: + D1 = matmul_NN(A, B, outs=[D]) + + # Verify correctness loop + loopM = scf.ForOp(zero, cM, one_index, []) + with InsertionPoint(loopM.body): + loopN = scf.ForOp(zero, cN, one_index, []) + with InsertionPoint(loopN.body): + x = loopM.induction_variable + y = loopN.induction_variable + res1 = tensor.ExtractOp(F32Type.get(), C, [x, y]) + res2 = tensor.ExtractOp(F32Type.get(), D1, [x, y]) + diff = arith.SubFOp(res1, res2) + # TODO Add support for scf.If op to verify directly from here + std.CallOp(printF32, [diff.results[0]]) + std.CallOp(printNewline, []) + scf.YieldOp([]) + scf.YieldOp([]) + + ret = arith.ConstantOp(IntegerType.get_signless(32), 0) + std.ReturnOp(ret) + + return wrapper + + +def generate_test_mlir(func_name, trA, size, dest): + with Context() as ctx, Location.unknown() as loc: + f32 = F32Type.get() + problem_definition = GEMM(trA) + mlir_module = Module.create() + problem_sizes = {"M": size[0], "N": size[1], "K": size[2]} + types = problem_definition.types_mlir_builder( + problem_sizes, + [f32, f32, f32], + ) + + with InsertionPoint(mlir_module.body): + gemm = builtin.FuncOp(func_name, (types, [types[-1]]), + visibility="private") + test = emit_test_function(trA, size, gemm) + save_mlir(str(mlir_module), dest) + + +def main(argv): + parser = argparse.ArgumentParser("benchmark") + + # Information about the program + parser.add_argument("--asm-program") + parser.add_argument("--function-name", default="gemm") + parser.add_argument("--output", default="") + + # Details of the problem to test + parser.add_argument("--trA", action="store_true") + parser.add_argument("--M", type=int) + parser.add_argument("--N", type=int) + parser.add_argument("--K", type=int) + + # Test specs + parser.add_argument("--threshold", type=float, default=0.1) + args = parser.parse_args(argv) + + # Generate the test, build and run + size = [args.M, args.N, args.K] + dest = args.output if args.output else args.function_name + ".test.mlir" + generate_test_mlir(args.function_name, args.trA, size, dest) + exe = build(args.asm_program, dest) + + out = run_command([os.path.abspath(exe), "> out.log"]) + out = out.split("\n") + for l in out: + if not l: + continue + f = float(l.strip()) + if f > args.threshold: + print("ERROR") + exit() + print("PASSED") + + +if __name__ == "__main__": + main(os.sys.argv[1:]) diff --git a/experimental/alp/python/alp/test/infra.py b/experimental/alp/python/alp/test/infra.py new file mode 100644 index 000000000000..0fe935c320e8 --- /dev/null +++ b/experimental/alp/python/alp/test/infra.py @@ -0,0 +1,77 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import os +import tempfile +import shutil +import argparse +from pathlib import Path + +# Alp imports +from ..backend.utils import print_command, run_and_save, run_command, add_extension +from ..transition.blas.gemm import save_mlir +from ..transition.blas.gemm import GEMM, matmul_NN, matmul_TN + + +def compile(mlir_test): + mlir_test_base = os.path.splitext(mlir_test)[0] + mlir_llvm_test = mlir_test_base + ".llvm.mlir" + llvm_test = mlir_test_base + ".ll" + obj_test = mlir_test_base + ".o" + + # main program + cmd = ["$IREE_LLVM_SANDBOX_BUILD_DIR/bin/mlir-opt"] + cmd.append(mlir_test) + cmd.append("--linalg-bufferize") + cmd.append("--std-bufferize") + cmd.append("--tensor-bufferize") + cmd.append("--func-bufferize") + cmd.append("-convert-linalg-to-affine-loops") + cmd.append("-lower-affine") + cmd.append("-convert-scf-to-std") + cmd.append("-convert-memref-to-llvm") + cmd.append("-convert-std-to-llvm") + cmd.append("-reconcile-unrealized-casts") + cmd.append(f"> {mlir_llvm_test}") + run_command(cmd) + + cmd = ["$IREE_LLVM_SANDBOX_BUILD_DIR/bin/mlir-translate"] + cmd.append("--mlir-to-llvmir") + cmd.append(f"{mlir_llvm_test}") + cmd.append(f"> {llvm_test}") + run_command(cmd) + + cmd = ["$IREE_LLVM_SANDBOX_BUILD_DIR/bin/llc"] + cmd.append(f"{llvm_test}") + cmd.append("-O3") + cmd.append("-filetype=obj") + cmd.append(f"-o {obj_test}") + run_command(cmd) + return obj_test + + +def link(asm_program, obj_test): + mlir_test_base = os.path.splitext(obj_test)[0] + exe_test = mlir_test_base + ".out" + + runtime_src = ( + "$IREE_LLVM_SANDBOX_SOURCE_DIR/experimental/alp/lib/AlpRuntime/alp_runtime.cpp" + ) + cmd = ["$IREE_LLVM_SANDBOX_BUILD_DIR/bin/clang++"] + cmd.append(f"{obj_test}") + cmd.append(f"{asm_program}") + cmd.append(runtime_src) + cmd.append(f"-o {exe_test}") + run_command(cmd) + return exe_test + + +def build(asm_program, mlir_test): + # Compile and link + obj_test = compile(mlir_test) + return link(asm_program, obj_test) + + +def run(executable): + run_command([os.path.abspath(executable)]) diff --git a/experimental/alp/python/alp/transition/__init__.py b/experimental/alp/python/alp/transition/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/experimental/alp/python/alp/transition/blas/__init__.py b/experimental/alp/python/alp/transition/blas/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/experimental/alp/python/alp/transition/blas/gemm.py b/experimental/alp/python/alp/transition/blas/gemm.py new file mode 100644 index 000000000000..0d0074e99c1e --- /dev/null +++ b/experimental/alp/python/alp/transition/blas/gemm.py @@ -0,0 +1,177 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import sys, time, os + +from typing import Any, List, Mapping, Optional, Sequence +import numpy as np +import argparse + +from mlir.ir import * +from mlir.dialects import arith, builtin, linalg, tensor, scf, std, memref +from mlir.dialects.linalg.opdsl.lang import * + +from examples.core.problem_definition import * + + +def save_mlir(mlir_txt, dest): + f = open(dest, "w") + f.write(mlir_txt) + f.close() + + +def attach_inplaceable_attributes(func: builtin.FuncOp, + inplaceable: Sequence[Optional[bool]]): + + # Create the following affine_map + # (d0, d1)[s0, s1] -> (s0 + d0*s1 + d1) + d0 = AffineDimExpr.get(0) + d1 = AffineDimExpr.get(1) + s0 = AffineSymbolExpr.get(0) + s1 = AffineSymbolExpr.get(1) + mul = AffineMulExpr.get(d0, s1) + add = AffineAddExpr.get(s0, mul) + add = AffineAddExpr.get(add, d1) + map0 = AffineMap.get(2, 2, [add]) + + # Add the attributes to the inputs + attrs = [] + for flag in inplaceable: + if flag is None: + attrs.append(DictAttr.get({})) + continue + attrs.append( + DictAttr.get({ + "linalg.inplaceable": BoolAttr.get(flag), + "linalg.buffer_layout": AffineMapAttr.get(map0), + })) + func.arg_attrs = attrs + + +# C = tr(A)*B +@linalg_structured_op +def matmul_TN( + A=TensorDef(TV.T1, S.K, S.M), + B=TensorDef(TV.T2, S.K, S.N), + C=TensorDef(U, S.M, S.N, output=True), +): + + domain(D.m, D.n, D.k) + implements(ContractionOpInterface) + C[D.m, D.n] += TypeFn.cast(U, A[D.k, D.m]) * TypeFn.cast(U, B[D.k, D.n]) + + +# C = A*B +@linalg_structured_op +def matmul_NN( + A=TensorDef(TV.T1, S.M, S.K), + B=TensorDef(TV.T2, S.K, S.N), + C=TensorDef(U, S.M, S.N, output=True), +): + + domain(D.m, D.n, D.k) + implements(ContractionOpInterface) + C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n]) + + +class GEMM(ProblemDefinition): + """Problem definition for a fill + matmul + generic op.""" + + def __init__(self, trA): + self.trA = trA + + def shapes_builder(self, sizes: Mapping[str, Any]) -> List[List[int]]: + """Shape builder function. + + Given a mapping between dimension names / op attributes and their numeric + values, return the list of lists of shapes of the FuncOp operands. The + FuncOp is responsible for distinguishing between input operands and results. + """ + M, N, K = sizes["M"], sizes["N"], sizes["K"] + if self.trA: + return [ + [K, M], + [K, N], + [M, N], + ] + else: + return [ + [M, K], + [K, N], + [M, N], + ] + + def types_mlir_builder(self, sizes: Mapping[str, Any], + types: Sequence[Type]) -> List[Type]: + shapes = self.shapes_builder(sizes) + return [ + RankedTensorType.get(s, t) for s, t in zip(shapes, + list(types) + [types[-1]]) + ] + + def build_problem_under_context_manager( + self, name: str, types: Sequence[Type]) -> builtin.FuncOp: + + # Actual benchmarked function called under entry_point_name. + func = builtin.FuncOp(name, (types, [types[-1]])) + + attach_inplaceable_attributes(func, inplaceable=[False, False, True]) + + with InsertionPoint(func.add_entry_block()): + if self.trA: + matmul = matmul_TN(func.arguments[0], + func.arguments[1], + outs=[func.arguments[2]]) + else: + matmul = matmul_NN(func.arguments[0], + func.arguments[1], + outs=[func.arguments[2]]) + std.ReturnOp([matmul]) + + return func + + +def generate_mlir(func_name, trA, size, dest): + # Build MLIR GEMM + with Context() as ctx, Location.unknown() as loc: + f32 = F32Type.get() + problem_definition = GEMM(trA) + mlir_module = Module.create() + problem_sizes = {"M": size[0], "N": size[1], "K": size[2]} + types = problem_definition.types_mlir_builder( + problem_sizes, + [f32, f32, f32], + ) + with InsertionPoint(mlir_module.body): + func = problem_definition.build_problem_under_context_manager( + func_name, types) + save_mlir(str(mlir_module), dest) + + +def main(argv): + parser = argparse.ArgumentParser("gemm") + # Paths and naming + parser.add_argument("--func-name", default="gemm") + parser.add_argument("--output", default="") + + # Problem description + parser.add_argument("--M", type=int) + parser.add_argument("--N", type=int) + parser.add_argument("--K", type=int) + parser.add_argument("--trA", action="store_true") + parser.add_argument("--reps", type=int, default=1) + args = parser.parse_args() + + # Generate the problem + M = args.M if args.M else -1 + N = args.N if args.N else -1 + K = args.K if args.K else -1 + + size = [M, N, K] + dest = args.output if args.output else args.func_name + ".mlir" + generate_mlir(args.func_name, args.trA, size, dest) + + +if __name__ == "__main__": + main(os.sys.argv[1:]) diff --git a/experimental/alp/python/test/gemm.py b/experimental/alp/python/test/gemm.py new file mode 100644 index 000000000000..259e2e65d426 --- /dev/null +++ b/experimental/alp/python/test/gemm.py @@ -0,0 +1,137 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from math import exp +import os +from pathlib import Path + +import alp.backend.mlirc +import alp.benchmark.infra +import alp.benchmark.blas.gemm +import alp.test.blas.gemm +import alp.test.infra +import alp.backend.utils +import alp.transition.blas.gemm + + +def get_default_compiler_options(): + options = { + "tile_sizes": [256, 256, 256], + "register_tile_sizes": [4, 16, 1], + "split_vector_transfers_to": "none", + "unroll_vector_transfers": True, + "reorder_tile_sizes": [0, 2, 1], + "reorder_register_tile_sizes": [0, 1, 2], + "hoist_packing": [4, 3, 0], + "transpose_packing": [0, 0, 0], + "extract_micro_kernel": False, + "modulo_scheduling": False, + "ms_unroll": 2, + "ms_distance": 1, + "scheduler": "ilpmax", + "verbosity_level": 4, + } + return options + + +fun_name = "gemm" +mlir_name = f"{fun_name}.mlir" +bench_name = f"{fun_name}.bench.mlir" +test_name = f"{fun_name}.test.mlir" +dest_path = os.path.join(os.path.curdir, fun_name) +Path(dest_path).mkdir(exist_ok=True) + + +def gen_sgemm_trA(): + trA = True + mlir_file_name = os.path.join(dest_path, mlir_name) + # Generate a fully dynamic gemm + alp.transition.blas.gemm.generate_mlir(fun_name, trA, [-1, -1, -1], + mlir_file_name) + asm_file = alp.backend.mlirc.compile(mlir_file_name, + get_default_compiler_options()) + return asm_file + + +def perf(asm_file, sizes, trA, compare=False, expected_flops=[]): + + # Build GEMMTN for f32 with default compiler options + + # Benchmark for gemm 2048^3 + mlir_flops = [] + for size in sizes: + bench_file_name = os.path.join(dest_path, bench_name) + alp.benchmark.blas.gemm.generate_benchmark_mlir(fun_name, trA, size, 1, + bench_file_name) + bench = alp.benchmark.infra.build(asm_file, bench_file_name) + s, f = alp.backend.utils.parse(alp.benchmark.infra.run(bench)) + mlir_flops.append(f) + + if compare: + for i in range(0, len(mlir_flops)): + print(sizes[i], mlir_flops[i]) + + slow_down = False + for i in range(0, len(mlir_flops)): + if expected_flops and mlir_flops[i] < expected_flops[i]: + print(mlir_flops[i], expected_flops[i]) + slow_down = True + break + return slow_down + + +def verify(asm_file, sizes, trA): + + def verify(out): + for l in out: + if not l: + continue + f = float(l.strip()) + if f > 0.1: + return False + return True + + # Verify for very weird gemm dimensions (all prime numbers) + for size in sizes: + test_file_name = os.path.join(dest_path, test_name) + alp.test.blas.gemm.generate_test_mlir(fun_name, trA, size, test_file_name) + test = alp.test.infra.build(asm_file, test_file_name) + out = alp.backend.utils.run_command([os.path.abspath(test)]) + out = out.split("\n") + return True if verify(out) else False + + +def main(): + asm = gen_sgemm_trA() + + # Test performance against different sizes + perf_sizes = [ + [2048, 2048, 2048], + [1024, 1024, 1024], + [512, 512, 512], + [128, 128, 128], + [64, 64, 64], + ] + slow_down = perf(asm, + perf_sizes, + compare=False, + trA=True, + expected_flops=[-1, -1, -1, -1, -1]) + if slow_down: + print("PERF regression!!!") + else: + print("PERF OK!") + + ## Verify for weird sizes + verify_sizes = [[513, 431, 23], [128, 10, 11]] + correct = verify(asm, verify_sizes, trA=True) + + if correct: + print("GEMM is correct!") + else: + print("Something is wrong!!!") + + +if __name__ == "__main__": + main() diff --git a/experimental/alp/test/test-alp-legalize.mlir b/experimental/alp/test/test-alp-legalize.mlir new file mode 100644 index 000000000000..9b4232d0cff4 --- /dev/null +++ b/experimental/alp/test/test-alp-legalize.mlir @@ -0,0 +1,35 @@ +// RUN: mlir-proto-opt -alp-legalize %s | FileCheck %s +#map0 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> +#map1 = affine_map<()[s0] -> (s0 ceildiv 8)> +#map2 = affine_map<(d0, d1, d2) -> (d2, d0)> +#map3 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map4 = affine_map<(d0, d1, d2) -> (d0, d1)> +#map5 = affine_map<(d0) -> (d0 ceildiv 8)> +#map6 = affine_map<(d0, d1) -> (d0 + d1)> +// CHECK-LABEL: @kernel +func @kernel(%arg0: memref<2048x256xf32, #map0>, %arg1: index, %arg2: memref<256x256x1x8xf32>, %arg3: index, %arg4: memref<32x256x1x8xf32>, %arg5: index) attributes {passthrough = [["prefer-vector-width", "128"]]} { + %cst = arith.constant 0.000000e+00 : f32 + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c1 = arith.constant 1 : index + %0 = affine.apply #map1()[%arg5] + %1 = vector.transfer_read %arg0[%arg1, %arg5], %cst {in_bounds = [true, true]} : memref<2048x256xf32, #map0>, vector<8x8xf32> + // CHECK: vector.transfer_read %arg0 + // CHECK: vector.transfer_read %arg0 + // CHECK: vector.transfer_read %arg0 + // CHECK: vector.transfer_read %arg0 + // CHECK: scf.for %[[it0:.*]] = %[[lb:.*]] to %[[ub:.*]] step %[[s:.*]] iter_args( + // CHECK-SAME: %arg7 = %cst_0, %arg8 = %8, %arg9 = %5, %arg10 = %3, %arg11 = %1) -> (vector<8x8xf32>, vector<4x4xf32>, vector<4x4xf32>, vector<4x4xf32>, vector<4x4xf32>) { + %2 = scf.for %arg6 = %c0 to %c256 step %c1 iter_args(%arg7 = %1) -> (vector<8x8xf32>) { + // CHECK: %[[V0:.*]] = vector.transfer_read %arg2 + // CHECK-NEXT: %[[V1:.*]] = vector.transfer_read %arg2 + // CHECK-NEXT: %[[V2:.*]] = vector.transfer_read %arg4 + // CHECK-NEXT: %[[V3:.*]] = vector.transfer_read %arg4 + %3 = vector.transfer_read %arg2[%arg3, %arg6, %c0, %c0], %cst {in_bounds = [true, true]} : memref<256x256x1x8xf32>, vector<1x8xf32> + %4 = vector.transfer_read %arg4[%0, %arg6, %c0, %c0], %cst {in_bounds = [true, true]} : memref<32x256x1x8xf32>, vector<1x8xf32> + %5 = vector.contract {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %3, %4, %arg7 : vector<1x8xf32>, vector<1x8xf32> into vector<8x8xf32> + scf.yield %5 : vector<8x8xf32> + } + vector.transfer_write %2, %arg0[%arg1, %arg5] {in_bounds = [true, true]} : vector<8x8xf32>, memref<2048x256xf32, #map0> + return +} diff --git a/experimental/alp/test/test-alp-modulo-scheduling.mlir b/experimental/alp/test/test-alp-modulo-scheduling.mlir new file mode 100644 index 000000000000..82372d09d353 --- /dev/null +++ b/experimental/alp/test/test-alp-modulo-scheduling.mlir @@ -0,0 +1,53 @@ +// RUN: mlir-proto-opt -alp-modulo-scheduling %s | FileCheck %s +#map1 = affine_map<()[s0] -> (s0 ceildiv 8)> + +func @kernel(%arg0: memref<2048x256xf32>, %arg1: index, %arg2: memref<256x256x1x8xf32>, %arg3: index, %arg4: memref<32x256x1x8xf32>, %arg5: index){ + %0 = affine.apply #map1()[%arg5] + %cst = arith.constant dense<0.000000e+00> : vector<4x4xf32> + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %v0 = arith.constant dense<1.0> : vector<4x4xf32> + %v1 = arith.constant dense<1.0> : vector<4x4xf32> + %v2 = arith.constant dense<1.0> : vector<4x4xf32> + %v3 = arith.constant dense<1.0> : vector<4x4xf32> + + // CHECK: vector.load %arg2 + // CHECK-NEXT: vector.load %arg2 + // CHECK-NEXT: vector.load %arg4 + // CHECK-NEXT: vector.load %arg4 + %41:4 = scf.for %arg6 = %c0 to %c256 step %c1 iter_args(%arg7 = %v0, %arg8 = %v1, %arg9 = %v2, %arg10 = %v3) -> (vector<4x4xf32>, vector<4x4xf32>, vector<4x4xf32>, vector<4x4xf32>) { + // CHECK: scf.for + // CHECK: %[[V0:.*]] = vector.load %arg2 + // CHECK-NEXT: %[[V1:.*]] = vector.load %arg2 + // CHECK-NEXT: %[[V2:.*]] = vector.load %arg4 + // CHECK-NEXT: %[[V3:.*]] = vector.load %arg4 + // CHECK-NEXT: vector.outerproduct %arg11, %arg12 + // CHECK-NEXT: vector.outerproduct %arg11, %arg13 + // CHECK-NEXT: vector.outerproduct %arg14, %arg12 + // CHECK-NEXT: vector.outerproduct %arg14, %arg13 + // CHECK: %[[Y0:.*]] = vector.load %arg2 + // CHECK: %[[Y1:.*]] = vector.load %arg2 + // CHECK: %[[Y2:.*]] = vector.load %arg4 + // CHECK: %[[Y3:.*]] = vector.load %arg4 + // CHECK: %[[Y4:.*]] = vector.outerproduct %[[V0]], %[[V2]] + // CHECK-NEXT: %[[Y5:.*]] = vector.outerproduct %[[V0]], %[[V3]] + // CHECK-NEXT: %[[Y6:.*]] = vector.outerproduct %[[V1]], %[[V2]] + // CHECK-NEXT: %[[Y7:.*]] = vector.outerproduct %[[V1]], %[[V3]] + // CHECK-NEXT: scf.yield %[[Y7]], %[[Y6]], %[[Y5]], %[[Y4]], %[[Y0]], %[[Y2]], %[[Y3]], %[[Y1]] + %58 = vector.load %arg2[%arg3, %arg6, %c0, %c0] : memref<256x256x1x8xf32>, vector<4xf32> + %59 = vector.load %arg2[%arg3, %arg6, %c0, %c4] : memref<256x256x1x8xf32>, vector<4xf32> + %60 = vector.load %arg4[%0, %arg6, %c0, %c0] : memref<32x256x1x8xf32>, vector<4xf32> + %61 = vector.load %arg4[%0, %arg6, %c0, %c4] : memref<32x256x1x8xf32>, vector<4xf32> + %62 = vector.outerproduct %58, %60, %arg10 {kind = #vector.kind} : vector<4xf32>, vector<4xf32> + %63 = vector.outerproduct %58, %61, %arg9 {kind = #vector.kind} : vector<4xf32>, vector<4xf32> + %64 = vector.outerproduct %59, %60, %arg8 {kind = #vector.kind} : vector<4xf32>, vector<4xf32> + %65 = vector.outerproduct %59, %61, %arg7 {kind = #vector.kind} : vector<4xf32>, vector<4xf32> + scf.yield %65, %64, %63, %62 : vector<4x4xf32>, vector<4x4xf32>, vector<4x4xf32>, vector<4x4xf32> + } + %s0 = vector.extract %41#3[0] : vector<4x4xf32> + vector.store %s0, %arg0[%c0, %c0] : memref<2048x256xf32>, vector<4xf32> + + return +}