Skip to content

Commit

Permalink
Enable benchmarking in performance ci
Browse files Browse the repository at this point in the history
This PR enables benchmarking only for the gemm test.
After the test, json files are dumped which are
then used for visualization along with information
about the github sha and ref name.

Signed-off-by: Harsh Menon <[email protected]>
  • Loading branch information
harsh-nod committed Sep 24, 2024
1 parent 0327398 commit 2177106
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 5 deletions.
10 changes: 9 additions & 1 deletion .github/workflows/perf.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,15 @@ jobs:
path: ${{ env.PIP_CACHE_DIR }}
key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }}

- name: Set current date as env variable
run: echo "NOW=$(date +'%Y-%m-%dT%H:%M:%S')" >> "$GITHUB_ENV"

- name: Install pip deps
run: |
python -m pip install --no-compile --upgrade pip
# Note: We install in three steps in order to satisfy requirements
# from non default locations first. Installing the PyTorch CPU
# wheels saves multiple minutes and a lot of bandwidth on runner setup.
# wheels saves signifcant time during setup.
pip install --no-compile -r pytorch-cpu-requirements.txt
pip install --no-cache-dir -r iree-requirements-ci.txt
pip install -r requirements.txt -e .
Expand All @@ -62,7 +65,12 @@ jobs:
run: |
export WAVE_RUN_E2E_TESTS=1
export TEST_PARAMS_PATH="tests/kernel/wave/test_param.json"
export WAVE_BENCHMARK_E2E_TESTS=1
pytest -n 1 ./tests/kernel/wave/
mkdir /artifacts/${{ github.sha }}
echo $'${{github.ref_name}}\n${{github.event.pull_request.title}}\n${{env.NOW}}' > /artifacts/${{ github.sha }}/ghinfo.txt
mv perf*.json /artifacts/${{ github.sha }}
npm run --prefix /infra/tkw/ build
- name: Run LIT tests
if: ${{ !cancelled() }}
run: |
Expand Down
8 changes: 7 additions & 1 deletion shark_turbine/kernel/wave/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,13 @@ def compile_and_invoke(
_invoke(ctx.vm_context, device, func, kernel_inputs, kernel_outputs)

if run_bench:
inputs = [inp.numpy() for inp in kernel_inputs]
inputs = []
# TODO: This is a workaround until benchmark_module uses the TemporaryFile API.
for i, inp in enumerate(kernel_inputs):
input_file_name = f"input_{i}.npy"
with open(input_file_name, "wb") as f:
numpy.save(f, inp.numpy())
inputs.append("@" + input_file_name)
benchmark_results = bench.benchmark_module(
mod,
entry_function=func_name,
Expand Down
20 changes: 17 additions & 3 deletions tests/kernel/wave/wave_gemm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import json

_run_e2e = int(os.environ.get("WAVE_RUN_E2E_TESTS", 0))
_bench_e2e = int(os.environ.get("WAVE_BENCH_E2E_TESTS", 0))
enable_benchmarking = True if _bench_e2e else False
require_e2e = pytest.mark.skipif(not _run_e2e, reason="e2e tests are disabled")
# Whether to dump the generated MLIR module.
test_dump_generated_mlir = int(os.environ.get("WAVE_DUMP_MLIR", 0))
Expand Down Expand Up @@ -107,17 +109,29 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
N: shape[1],
K: shape[2],
}
config = {"backend": "rocm", "device": "hip", "target": "gfx942"}
prefix = f"wave_gemm_{'x'.join(map(str, shape))}"
config = {
"backend": "rocm",
"device": "hip",
"target": "gfx942",
"benchmark_results_file": "perf_" + prefix + ".json",
"benchmark_batch_size": 1000,
"benchmark_repetitions": 3,
}
with tk.gen.TestLaunchContext(
hyperparams, canonicalize=True, run=True, run_config=config
hyperparams,
canonicalize=True,
run=True,
run_config=config,
run_bench=enable_benchmarking,
):
a = torch.randn(shape[0], shape[2], dtype=torch.float16)
b = torch.randn(shape[1], shape[2], dtype=torch.float16)
c = torch.zeros(shape[0], shape[1], dtype=torch.float32)
mb = gemm(a, b, c)

if test_dump_generated_mlir:
filename = f"wave_gemm_{'x'.join(map(str, shape))}.mlir"
filename = prefix + ".mlir"
with open(filename, "w") as f:
f.write(mb.module_op.get_asm())

Expand Down

0 comments on commit 2177106

Please sign in to comment.