Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tuner] Add single dispatch tuner example #253

Merged
merged 3 commits into from
Oct 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tuner/examples/dispatch/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Test files/dirs recommended by README.md.
dump/
benchmark.mlir
35 changes: 35 additions & 0 deletions tuner/examples/dispatch/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Dispatch Tuner

Allows tuning a single dispatch in isolation.

## Environment Setup
Follow the instructions in [`/tuner/README.md`](../README.md)

## Running the Dispatch Tuner

### Generate a benchmark file
Use the usual `iree-compile` command for your dispatch and add
`--iree-hal-dump-executable-files-to=dump`. For example:
```shell
iree-compile mmt.mlir --iree-hal-target-backends=rocm --iree-hip-target=gfx942 --iree-hal-dump-executable-files-to=dump -o /dev/null
```

Next, copy the `*_benchmark.mlir` file to some temporary directory of choice.
This will be the input to the dispatch tuner.

### Recommended Trial Run
For an initial trial to test the tuning loop, use:
```shell
python -m examples.dispatch benchmark.mlir --num-candidates=20
```

### Dry Run Test
To perform a dry run (no GPU required), use:
```shell
python -m examples.dispatch benchmark.mlir --num-candidates=64 --num-model-candidates=10 --dry-run
```

### Basic Usage
```shell
python -m examples.dispatch benchmark.mlir
```
5 changes: 5 additions & 0 deletions tuner/examples/dispatch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
9 changes: 9 additions & 0 deletions tuner/examples/dispatch/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# Module entry point: lets the tuner run as `python -m examples.dispatch ...`.
from . import dispatch_tuner

dispatch_tuner.main()
18 changes: 18 additions & 0 deletions tuner/examples/dispatch/compile_dispatch.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#! /usr/bin/env bash
# Compile one candidate dispatch .mlir into a .vmfb and, when the candidate
# ships a config file, assemble a transform-dialect spec by concatenating the
# shared prolog, the candidate config, and the shared epilog.
#
# Usage: compile_dispatch.sh <candidate.mlir>
# On any failure the input .mlir is moved to the sibling `failed/` directory
# so later tuning phases skip it.

set -eou pipefail

readonly INPUT="$1"
readonly DIR="$(dirname "$INPUT")"
readonly BASENAME="$(basename "$INPUT" .mlir)"
readonly OUT="${DIR}/compiled/${BASENAME}.vmfb"

# Compile from dumped executable sources; compiler diagnostics are discarded
# because failures are expected and tracked via the failed/ directory instead.
iree-compile "$INPUT" -o "$OUT" \
    --compile-from=executable-sources 2>/dev/null || (mv "$INPUT" "$DIR/failed" && exit 1)

# A module without a 'rocm-hsaco-fb' section produced no GPU binary; treat it
# as a failed candidate and drop the useless .vmfb.
iree-dump-module "$OUT" | grep -q 'rocm-hsaco-fb' || (mv "$INPUT" "$DIR/failed" && rm -f "$OUT" && exit 1)
if [ -f "${DIR}/${BASENAME}_config.mlir" ]; then
    cat "${DIR}/../config_prolog.mlir" "${DIR}/${BASENAME}_config.mlir" "${DIR}/../config_epilog.mlir" > "${DIR}/specs/${BASENAME}_spec.mlir"
fi

echo "Compiling ${INPUT}: success"
12 changes: 12 additions & 0 deletions tuner/examples/dispatch/config_epilog.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@

//===----------------------------------------------------------------------===//
// Entry point
//===----------------------------------------------------------------------===//

transform.named_sequence @__kernel_config(%variant_op: !transform.any_op {transform.consumed}) {
transform.foreach_match in %variant_op
, @match_op -> @apply_op_config
: (!transform.any_op) -> (!transform.any_op)
transform.yield
}
} //// module
32 changes: 32 additions & 0 deletions tuner/examples/dispatch/config_prolog.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Transform dialect specification for attention on MI300 with MFMA.
module attributes { transform.with_named_sequence } {
//===----------------------------------------------------------------------===//
// Matmul tuning
//===----------------------------------------------------------------------===//

transform.named_sequence @match_mmt_f16_f16_f32(%root: !transform.any_op {transform.readonly}) -> (!transform.any_op) {
transform.match.operation_name %root ["linalg.generic"] : !transform.any_op
// transform.print %root {name = "Generic"} : !transform.any_op
%ins, %outs = transform.iree.match.cast_compatible_dag_from_root %root {
^bb0(%lhs: tensor<?x?xf16>, %rhs: tensor<?x?xf16>, %out: tensor<?x?xf32>):
%7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
affine_map<(d0, d1, d2) -> (d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"]}
ins(%lhs, %rhs : tensor<?x?xf16>, tensor<?x?xf16>) outs(%out : tensor<?x?xf32>) {
^bb0(%in: f16, %in_0: f16, %acc: f32):
%8 = arith.extf %in : f16 to f32
%9 = arith.extf %in_0 : f16 to f32
%10 = arith.mulf %8, %9 : f32
%11 = arith.addf %acc, %10 : f32
linalg.yield %11 : f32
} -> tensor<?x?xf32>
} : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
transform.yield %root : !transform.any_op
}

transform.named_sequence @apply_op_config(%op: !transform.any_op {transform.readonly}, %config: !transform.any_param {transform.readonly}) {
transform.annotate %op "compilation_info" = %config : !transform.any_op, !transform.any_param
// transform.print %op {name = "Applied"} : !transform.any_op
transform.yield
}
137 changes: 137 additions & 0 deletions tuner/examples/dispatch/dispatch_tuner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""
Sample Usage:

python -m examples.dispatch benchmark.mlir --lhs-dims=bmk --rhs-dims=bkn --tile-dims=*mnk --devices=hip://0,hip://1 --num-candidates=64


Recommended Trial Run:

python -m examples.dispatch benchmark.mlir --num-candidates=10


Dry Run Test (no GPU required):

python -m examples.dispatch benchmark.mlir --num-candidates=64 --dry-run

"""

from tuner import libtuner
from pathlib import Path, PurePath
import os


class DispatchTuner(libtuner.TuningClient):
    """Tuning client for compiling and benchmarking one dispatch in isolation.

    Only the dispatch-level hooks do real work; the model-level hooks are
    deliberate no-ops (zero timeout, empty command) because this example
    never compiles or benchmarks a whole model.
    """

    def get_dispatch_compile_timeout_s(self) -> int:
        # Per-candidate compilation budget, in seconds.
        return 10

    def get_dispatch_compile_command(
        self, candidate_tracker: libtuner.CandidateTracker
    ) -> list[str]:
        """Build the command that compiles one candidate dispatch.

        Delegates to the compile_dispatch.sh helper located next to this file.
        """
        assert candidate_tracker.dispatch_mlir_path is not None
        helper_script = Path(__file__).resolve().parent / "compile_dispatch.sh"
        return [
            helper_script.as_posix(),
            candidate_tracker.dispatch_mlir_path.as_posix(),
        ]

    def get_dispatch_benchmark_timeout_s(self) -> int:
        # Per-candidate benchmark budget, in seconds.
        return 15

    def get_dispatch_benchmark_command(
        self,
        candidate_tracker: libtuner.CandidateTracker,
    ) -> list[str]:
        """Build the iree-benchmark-module invocation for a compiled candidate.

        The device flag uses libtuner's placeholder so the runner can
        substitute the concrete device id per worker.
        """
        vmfb_path = candidate_tracker.compiled_dispatch_path
        assert vmfb_path is not None
        return [
            "iree-benchmark-module",
            f"--device={libtuner.DEVICE_ID_PLACEHOLDER}",
            f"--module={vmfb_path.resolve()}",
            "--batch_size=1000",
            "--benchmark_repetitions=3",
        ]

    def get_model_compile_timeout_s(self) -> int:
        # Model compilation is not part of this example.
        return 0

    def get_model_compile_command(
        self, candidate_tracker: libtuner.CandidateTracker
    ) -> list[str]:
        # No-op: model compilation is skipped for dispatch-only tuning.
        return []

    def get_model_benchmark_timeout_s(self) -> int:
        # Model benchmarking is not part of this example.
        return 0

    def get_model_benchmark_command(
        self, candidate_tracker: libtuner.CandidateTracker
    ) -> list[str]:
        # No-op: model benchmarking is skipped for dispatch-only tuning.
        return []


def main():
    """Run the single-dispatch tuning loop.

    Phases, each of which may be the last via --stop-after:
    candidate generation -> dispatch compilation -> dispatch benchmarking.
    Results and trackers are written under the libtuner-managed base dir.
    """
    args = libtuner.parse_arguments()
    path_config = libtuner.PathConfig()
    # These will not be used, so always default to the empty config in the script dir.
    script_dir = Path(__file__).resolve().parent
    path_config.global_config_prolog_mlir = (
        script_dir / path_config.global_config_prolog_mlir
    )
    path_config.global_config_epilog_mlir = (
        script_dir / path_config.global_config_epilog_mlir
    )
    path_config.base_dir.mkdir(parents=True, exist_ok=True)
    path_config.output_unilog.touch()
    candidate_trackers: list[libtuner.CandidateTracker] = []
    dispatch_tuner = DispatchTuner()
    stop_after_phase: str = args.stop_after

    print("Setup logging")
    libtuner.setup_logging(args, path_config)
    print(path_config.run_log, end="\n\n")

    # Device validation needs real hardware, so it is skipped on dry runs.
    if not args.dry_run:
        print("Validating devices")
        libtuner.validate_devices(args.devices)
        print("Validation successful!\n")

    print("Generating candidates...")
    candidates = libtuner.generate_candidates(args, path_config, candidate_trackers)
    print(f"Stored candidates in {path_config.candidates_dir}\n")
    if stop_after_phase == libtuner.ExecutionPhases.generate_candidates:
        return

    print("Compiling candidates...")
    compiled_candidates = libtuner.compile_dispatches(
        args, path_config, candidates, candidate_trackers, dispatch_tuner
    )
    print(f"Compiled files are stored in {path_config.compiled_dir}\n")
    if stop_after_phase == libtuner.ExecutionPhases.compile_dispatches:
        return

    print("Benchmarking compiled candidates...")
    # NOTE(review): top_candidates is currently unused here; results are read
    # from the unilog file instead — confirm whether it should be reported.
    top_candidates = libtuner.benchmark_dispatches(
        args, path_config, compiled_candidates, candidate_trackers, dispatch_tuner
    )
    print(f"\nStored results in {path_config.output_unilog.resolve()}\n")
    if stop_after_phase == libtuner.ExecutionPhases.benchmark_dispatches:
        return

    # Persist trackers so a later run (or offline analysis) can inspect them.
    libtuner.save_pickle(path_config.candidate_trackers_pkl, candidate_trackers)
    print(f"Candidate trackers are saved in {path_config.candidate_trackers_pkl}\n")

    print("Check the detailed execution logs in:")
    print(path_config.run_log.resolve())

    for candidate in candidate_trackers:
        libtuner.logging.debug(candidate)
11 changes: 11 additions & 0 deletions tuner/examples/dispatch/mmt.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// Example tuning input: matmul with a transposed B operand (mmt),
// C[2048x10240] (f32) = A[2048x1280] (f16) * B[10240x1280]^T (f16).
!matA_0 = tensor<2048x1280xf16>
!matB_0 = tensor<10240x1280xf16>
!matC_0 = tensor<2048x10240xf32>

func.func @main_0(%arg0: !matA_0, %arg1: !matB_0) -> !matC_0 {
  // Zero-fill the accumulator tensor before the matmul.
  %cst = arith.constant 0.000000e+00 : f16
  %5 = tensor.empty() : !matC_0
  %6 = linalg.fill ins(%cst : f16) outs(%5 : !matC_0) -> !matC_0
  %8 = linalg.matmul_transpose_b ins(%arg0, %arg1 : !matA_0, !matB_0) outs(%6 : !matC_0) -> !matC_0
  return %8 : !matC_0
}
15 changes: 8 additions & 7 deletions tuner/tuner/libtuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,12 @@ class CandidateTracker:
calibrated_benchmark_diff: Optional[float] = None


@dataclass(frozen=True)
@dataclass()
class PathConfig:
# Preset constants
global_config_prolog_mlir: Path = Path("./config_prolog.mlir")
global_config_epilog_mlir: Path = Path("./config_epilog.mlir")
model_baseline_vmfb: Path = Path("./baseline.vmfb")
global_config_prolog_mlir: Path = Path("config_prolog.mlir")
global_config_epilog_mlir: Path = Path("config_epilog.mlir")
model_baseline_vmfb: Path = Path("baseline.vmfb")

# Dynamic paths
base_dir: Path = field(init=False)
Expand Down Expand Up @@ -523,7 +523,7 @@ def create_worker_context_queue(device_ids: list[int]) -> queue.Queue[tuple[int,
def run_command(run_pack: RunPack) -> TaskResult:
command = run_pack.command
check = run_pack.check
timeout_seconds = run_pack.timeout
timeout_seconds = run_pack.timeout_seconds

result = None
is_timeout = False
Expand Down Expand Up @@ -828,7 +828,7 @@ def compile_dispatches(
num_worker=num_worker, task_list=task_list, function=run_command_wrapper
)

# Note: failed/incompleted candidates can also be detected by checking if subprocess.res is None
# Note: failed/incomplete candidates can also be detected by checking if subprocess.res is None
compiled_files = sorted(
path_config.compiled_dir.glob("*.vmfb"), key=numerical_sort_key
)
Expand Down Expand Up @@ -860,7 +860,8 @@ def compile_dispatches(
compiled_candidates_hash_list.append((index, hash_val))

handle_error(
condition=(good == 0), msg="Failed to compile all candidate .mlir files"
condition=(good == 0),
msg="All candidate dispatches .mlir files failed to compile",
)
handle_error(
condition=(compiling_rate < 10),
Expand Down
Loading