From b55065a2b747ed4a4b755a9f34d04e6bcabdfa4d Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Wed, 9 Oct 2024 12:20:55 -0400 Subject: [PATCH 1/7] Enable check for sharded Conv2D test (#263) The fix https://github.com/iree-org/iree-turbine/pull/205 solves the issue with this test. Xfail the Unet Resnet block test with maybe low accuracy. --- .../layers/sharded_conv2d_with_iree_test.py | 14 +++++------ .../sharded_resnet_block_with_iree_test.py | 24 ++++++++++++------- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/sharktank/tests/layers/sharded_conv2d_with_iree_test.py b/sharktank/tests/layers/sharded_conv2d_with_iree_test.py index 2a6ecace2..9b29e5761 100644 --- a/sharktank/tests/layers/sharded_conv2d_with_iree_test.py +++ b/sharktank/tests/layers/sharded_conv2d_with_iree_test.py @@ -173,14 +173,12 @@ def run_test_sharded_conv2d_with_iree( ) assert len(actual_result.shards) == len(expected_result.shards) assert actual_result.shard_dim == expected_result.shard_dim - # TODO: reenable this check once numerical issues are resolved. 
- # See https://github.com/iree-org/iree/issues/18283 - # for actual_shard, expected_shard in zip( - # actual_result.shards, expected_result.shards - # ): - # torch.testing.assert_close( - # unbox_tensor(actual_shard), unbox_tensor(expected_shard) - # ) + for actual_shard, expected_shard in zip( + actual_result.shards, expected_result.shards + ): + torch.testing.assert_close( + unbox_tensor(actual_shard), unbox_tensor(expected_shard) + ) def test_sharded_conv2d_with_iree( diff --git a/sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py b/sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py index 86bb41c71..581584369 100644 --- a/sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py +++ b/sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py @@ -19,6 +19,7 @@ import iree.runtime from typing import List, Optional import os +import pytest vm_context: iree.runtime.VmContext = None @@ -207,19 +208,26 @@ def run_test_sharded_resnet_block_with_iree( parameters_path=parameters_path, ) assert len(actual_result.shards) == len(expected_result.shards) - # TODO: reenable this check once numerical issues are resolved. - # See https://github.com/iree-org/iree/issues/18283 - # for actual_shard, expected_shard in zip( - # actual_result.shards, expected_result.shards - # ): - # torch.testing.assert_close( - # unbox_tensor(actual_shard), unbox_tensor(expected_shard) - # ) + # TODO: reenable this test once numerical issues are resolved. + # The absolute accuracy is > 0.00042. Is this good enough? + # Maybe add a test with fp64, where if the accuracy is high would give us more + # confidence that fp32 is also OK. 
+ for actual_shard, expected_shard in zip( + actual_result.shards, expected_result.shards + ): + torch.testing.assert_close( + unbox_tensor(actual_shard), unbox_tensor(expected_shard) + ) global vm_context del vm_context +@pytest.mark.xfail( + reason="Maybe numerical issues with low accuracy.", + strict=True, + raises=AssertionError, +) def test_sharded_resnet_block_with_iree( mlir_path: Optional[Path], module_path: Optional[Path], From a0d5d10542a698db9ed96d1d73da663f87b84ded Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Fri, 11 Oct 2024 13:40:07 +0200 Subject: [PATCH 2/7] Cleanup workflow files (#270) * Removes extra path from command line * Adds quotes to make call compatible with Windows CI * Removes no longer required deps --- .github/workflows/ci_linux_x64-libshortfin.yml | 11 ++++------- .github/workflows/ci_linux_x64_nogil-libshortfin.yml | 5 ++--- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/.github/workflows/ci_linux_x64-libshortfin.yml b/.github/workflows/ci_linux_x64-libshortfin.yml index bdb4620be..d7450cbe7 100644 --- a/.github/workflows/ci_linux_x64-libshortfin.yml +++ b/.github/workflows/ci_linux_x64-libshortfin.yml @@ -41,7 +41,6 @@ jobs: run: | sudo apt update sudo apt install clang lld cmake ninja-build - sudo apt install libspdlog-dev libxtensor-dev - name: Checkout repository uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 @@ -89,9 +88,8 @@ jobs: -DCMAKE_CXX_COMPILER=clang++-18 \ -DCMAKE_LINKER_TYPE=LLD \ -DSHORTFIN_BUNDLE_DEPS=ON \ - -DSHORTFIN_IREE_SOURCE_DIR=${{ env.IREE_REPO_DIR }} \ - -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON \ - .. 
+ -DSHORTFIN_IREE_SOURCE_DIR="${{ env.IREE_REPO_DIR }}" \ + -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON cmake --build build --target all pip install -v -e build/ @@ -113,10 +111,9 @@ jobs: -DCMAKE_C_COMPILER=clang-18 \ -DCMAKE_CXX_COMPILER=clang++-18 \ -DCMAKE_LINKER_TYPE=LLD \ - -DSHORTFIN_IREE_SOURCE_DIR=${{ env.IREE_REPO_DIR }} \ + -DSHORTFIN_IREE_SOURCE_DIR="${{ env.IREE_REPO_DIR }}" \ -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON \ -DSHORTFIN_HAVE_AMDGPU=OFF \ -DSHORTFIN_BUILD_STATIC=ON \ - -DSHORTFIN_BUILD_DYNAMIC=ON \ - .. + -DSHORTFIN_BUILD_DYNAMIC=ON cmake --build build-host-only --target all diff --git a/.github/workflows/ci_linux_x64_nogil-libshortfin.yml b/.github/workflows/ci_linux_x64_nogil-libshortfin.yml index 08f5e62da..12efdadda 100644 --- a/.github/workflows/ci_linux_x64_nogil-libshortfin.yml +++ b/.github/workflows/ci_linux_x64_nogil-libshortfin.yml @@ -86,9 +86,8 @@ jobs: -DCMAKE_CXX_COMPILER=clang++-18 \ -DCMAKE_LINKER_TYPE=LLD \ -DSHORTFIN_BUNDLE_DEPS=ON \ - -DSHORTFIN_IREE_SOURCE_DIR=${{ env.IREE_REPO_DIR }} \ - -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON \ - .. 
+ -DSHORTFIN_IREE_SOURCE_DIR="${{ env.IREE_REPO_DIR }}" \ + -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON cmake --build build --target all pip install -v -e build/ From e4bcf99bb2387a2280da3dc26993a73878bad8b8 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Fri, 11 Oct 2024 13:01:53 -0400 Subject: [PATCH 3/7] [tuner] Fix mfma constructor arguments (#266) --- tuner/tuner/candidate_gen.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index 8faf70f85..16f0cf724 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -452,11 +452,11 @@ def generate_solutions(problem_size: ProblemSize, num_subgrups: int): lookup(subgroup_size), [lookup(wg_x), lookup(wg_y), lookup(wg_z)], MfmaIntrinsic( - problem_size.lhs_type.element_type, + problem_size.res_type.element_type, lookup(intrinsic_mn), lookup(intrinsic_mn), lookup(intrinsic_k), - problem_size.res_type.element_type, + problem_size.lhs_type.element_type, ), [lookup(m), lookup(n), lookup(k)], lookup(sg_m_cnt), From 468fb29ee6999473845d59af2ad17cbd834a2409 Mon Sep 17 00:00:00 2001 From: Mihaescu Vlad <52869843+mihaescuvlad@users.noreply.github.com> Date: Fri, 11 Oct 2024 20:26:30 +0300 Subject: [PATCH 4/7] [tuner] Use JSON for benchmark output (#256) ### Notes - Adds `extract_benchmark_from_run_result` method to help in fetching the "benchmarks" data - Updates `IREEBenchmarkResult` model to reflect that the result is no longer stored as string but as a list of benchmarks - Updates the parsing of dispatches and models to read from Json ### Testing - Updates tests to verify that `get_mean_time` functions as expected - Updates tests to verify that Json data is properly parsed and processed --- tuner/examples/dispatch/dispatch_tuner.py | 1 + tuner/examples/punet/punet_autotune.py | 6 +- tuner/tuner/libtuner.py | 85 +++++++++--- tuner/tuner/libtuner_test.py | 161 ++++++++++++++++------ 4 files changed, 192 insertions(+), 61 
deletions(-) diff --git a/tuner/examples/dispatch/dispatch_tuner.py b/tuner/examples/dispatch/dispatch_tuner.py index 98086fbbb..3c2d77f64 100644 --- a/tuner/examples/dispatch/dispatch_tuner.py +++ b/tuner/examples/dispatch/dispatch_tuner.py @@ -58,6 +58,7 @@ def get_dispatch_benchmark_command( f"--module={compiled_vmfb_path.resolve()}", "--batch_size=1000", "--benchmark_repetitions=3", + "--benchmark_format=json", ] return command diff --git a/tuner/examples/punet/punet_autotune.py b/tuner/examples/punet/punet_autotune.py index b78989991..3503c86df 100644 --- a/tuner/examples/punet/punet_autotune.py +++ b/tuner/examples/punet/punet_autotune.py @@ -58,8 +58,7 @@ def get_dispatch_benchmark_command( "--hip_allow_inline_execution=true", "--batch_size=1000", "--benchmark_repetitions=3", - f"--benchmark_out=dispatch_{candidate_tracker.candidate_id}_bm.json", - "--benchmark_out_format=json", + "--benchmark_format=json", ] return command @@ -110,8 +109,7 @@ def get_model_benchmark_command( "--input=2x6xf16", "--input=1xf16", "--benchmark_repetitions=5", - f"--benchmark_out=model_{candidate_tracker.candidate_id}_bm.json", - "--benchmark_out_format=json", + "--benchmark_format=json", ] return command diff --git a/tuner/tuner/libtuner.py b/tuner/tuner/libtuner.py index 30ce732bd..91c7b417a 100644 --- a/tuner/tuner/libtuner.py +++ b/tuner/tuner/libtuner.py @@ -36,6 +36,7 @@ from typing import Type, Optional, Callable, Iterable, Any import pickle import random +import json from abc import ABC, abstractmethod import iree.runtime as ireert from . 
import candidate_gen @@ -236,20 +237,48 @@ class ParsedDisptachBenchmarkResult: class IREEBenchmarkResult: # Default format follows output of iree-benchmark-module candidate_id: int - result_str: str - def get_mean_time(self) -> Optional[float]: - if not self.result_str: - return None - pattern = r"process_time/real_time_mean\s+([\d.]+)\s\w{2}" - match = re.search(pattern, self.result_str) - if not match: - return None - try: - return float(match.group(1)) - except ValueError: + # A list of dictionaries, each representing a benchmark result + # Each dictionary contains fields like: aggregate_name: string, real_time: float, cpu_time: float, time_unit: str, repetitions: int, etc. + result_json: list[dict[str, Any]] + + def get_mean_time_us(self) -> Optional[float]: + """Compute the mean time (in microseconds) for all of the benchmarks""" + if not self.result_json: return None + mean_benchmark = self.find_mean_benchmark(self.result_json) + + if mean_benchmark: + real_time = mean_benchmark.get("real_time") + time_unit = mean_benchmark.get("time_unit") + + if real_time is not None: + return self.unit_to_microseconds(real_time, time_unit) + + return None + + @staticmethod + def find_mean_benchmark(result_json: list[dict[str, Any]]) -> Optional[dict]: + for benchmark in result_json: + if benchmark.get("aggregate_name") == "mean": + return benchmark + + return None + + @staticmethod + def unit_to_microseconds(real_time: float, time_unit: str) -> float: + unit_conversions = { + "s": 1e6, + "ms": 1e3, + "us": 1, + "ns": 1e-3, + } + + assert time_unit in unit_conversions, f"Unsupported time unit: {time_unit}" + + return real_time * unit_conversions[time_unit] + def generate_display_DBR(candidate_id: int, mean_time: float) -> str: """Generate dispatch_benchmark_result string for displaying""" @@ -619,6 +648,26 @@ def multiprocess_progress_wrapper( return results +def extract_benchmark_from_run_result( + run_result: RunResult, +) -> Optional[list[dict[str, Any]]]: + """Extract 
the benchmark from the result JSON""" + if run_result.process_res and run_result.process_res.stdout: + try: + result_json = json.loads(run_result.process_res.stdout) + + return result_json.get("benchmarks", None) + except json.JSONDecodeError as e: + handle_error( + condition=True, + msg=f"Failed to parse JSON from stdout: {e}", + error_type=ValueError, + exit_program=True, + ) + + return None + + def numerical_sort_key(path: Path) -> tuple[int | float, str]: """ Define a sort key function that splits the filename into a numeric and a string part. @@ -896,9 +945,9 @@ def parse_dispatch_benchmark_results( incomplete_list.append(candidate_id) continue - res_str = process_res.stdout - res = IREEBenchmarkResult(candidate_id, res_str) - benchmark_time = res.get_mean_time() + res_json = extract_benchmark_from_run_result(benchmark_result.run_result) + res = IREEBenchmarkResult(candidate_id, res_json) + benchmark_time = res.get_mean_time_us() assert benchmark_time is not None candidate_trackers[candidate_id].first_benchmark_time = benchmark_time candidate_trackers[ @@ -1185,9 +1234,9 @@ def parse_model_benchmark_results( baseline_time = None continue - result_str = process_res.stdout - res = IREEBenchmarkResult(candidate_id, result_str) - benchmark_time = res.get_mean_time() + result_json = extract_benchmark_from_run_result(task_result.run_result) + res = IREEBenchmarkResult(candidate_id, result_json) + benchmark_time = res.get_mean_time_us() assert benchmark_time is not None # Record baseline benchmarking result and skip rest processes @@ -1328,7 +1377,7 @@ def benchmark_models( ) -def summerize_top_candidates( +def summarize_top_candidates( path_config: PathConfig, candidate_trackers: list[CandidateTracker] ): dump_list = [] diff --git a/tuner/tuner/libtuner_test.py b/tuner/tuner/libtuner_test.py index 3cbaa5ed0..36bda3bd5 100644 --- a/tuner/tuner/libtuner_test.py +++ b/tuner/tuner/libtuner_test.py @@ -6,11 +6,12 @@ import argparse import pytest +import json from 
unittest.mock import call, patch, MagicMock from . import libtuner """ -Usage: python -m pytest test_libtuner.py +Usage: python -m pytest libtuner_test.py """ @@ -57,34 +58,77 @@ def test_collision_handler(): def test_IREEBenchmarkResult_get(): - # Time is int - normal_str = r""" - ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ - Benchmark Time CPU Iterations UserCounters... - ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ - BM_main$async_dispatch_311_rocm_hsaco_fb_main$async_dispatch_311_matmul_like_2x1024x1280x5120_i8xi8xi32/process_time/real_time 271 us 275 us 3000 items_per_second=3.65611k/s - BM_main$async_dispatch_311_rocm_hsaco_fb_main$async_dispatch_311_matmul_like_2x1024x1280x5120_i8xi8xi32/process_time/real_time 274 us 275 us 3000 items_per_second=3.65481k/s - BM_main$async_dispatch_311_rocm_hsaco_fb_main$async_dispatch_311_matmul_like_2x1024x1280x5120_i8xi8xi32/process_time/real_time 273 us 275 us 3000 items_per_second=3.65671k/s - BM_main$async_dispatch_311_rocm_hsaco_fb_main$async_dispatch_311_matmul_like_2x1024x1280x5120_i8xi8xi32/process_time/real_time_mean 274 us 275 us 3 items_per_second=3.65587k/s - BM_main$async_dispatch_311_rocm_hsaco_fb_main$async_dispatch_311_matmul_like_2x1024x1280x5120_i8xi8xi32/process_time/real_time_mean 275 us 275 us 3 items_per_second=3.65611k/s - BM_main$async_dispatch_311_rocm_hsaco_fb_main$async_dispatch_311_matmul_like_2x1024x1280x5120_i8xi8xi32/process_time/real_time_stddev 0.073 us 0.179 us 3 items_per_second=0.971769/s - BM_main$async_dispatch_311_rocm_hsaco_fb_main$async_dispatch_311_matmul_like_2x1024x1280x5120_i8xi8xi32/process_time/real_time_cv 0.03 % 0.07 % 3 items_per_second=0.03% - """ - res = 
libtuner.IREEBenchmarkResult(candidate_id=1, result_str=normal_str) - assert res.get_mean_time() == float(274) - - # Time is float + # Time is int in us + int_json = [{"aggregate_name": "mean", "real_time": 1, "time_unit": "us"}] + + res = libtuner.IREEBenchmarkResult(candidate_id=1, result_json=int_json) + assert res.get_mean_time_us() == float(1) + + # Time is float in us + float_json = [{"aggregate_name": "mean", "real_time": 123.45, "time_unit": "us"}] + + res = libtuner.IREEBenchmarkResult(candidate_id=2, result_json=float_json) + assert res.get_mean_time_us() == 123.45 + + # Time is in seconds + seconds_json = [{"aggregate_name": "mean", "real_time": 1.0, "time_unit": "s"}] + + res = libtuner.IREEBenchmarkResult(candidate_id=3, result_json=seconds_json) + assert res.get_mean_time_us() == 1.0 * 1e6 + + # Time is in miliseconds + miliseconds_json = [{"aggregate_name": "mean", "real_time": 1.0, "time_unit": "ms"}] + + res = libtuner.IREEBenchmarkResult(candidate_id=4, result_json=miliseconds_json) + assert res.get_mean_time_us() == 1.0 * 1e3 + + # Time is in nanoseconds + nanoseconds_json = [{"aggregate_name": "mean", "real_time": 1.0, "time_unit": "ns"}] + + res = libtuner.IREEBenchmarkResult(candidate_id=5, result_json=nanoseconds_json) + assert res.get_mean_time_us() == 1.0 * 1e-3 + + small_number_json = [ + { + "aggregate_name": "mean", + "real_time": 3.4591828516259519e-02, + "time_unit": "ms", + } + ] + + res = libtuner.IREEBenchmarkResult(candidate_id=6, result_json=small_number_json) + assert res.get_mean_time_us() == 34.591828516259519 + + # Invalid json: missing real_time + invalid_real_time_json = [{"aggregate_name": "mean", "real_time": None}] + res = libtuner.IREEBenchmarkResult( - candidate_id=2, - result_str="process_time/real_time_mean 123.45 us, process_time/real_time_mean 246.78 us", + candidate_id=7, result_json=invalid_real_time_json ) - assert res.get_mean_time() == 123.45 + assert res.get_mean_time_us() == None - # Invalid str - res = 
libtuner.IREEBenchmarkResult(candidate_id=3, result_str="hello world") - assert res.get_mean_time() == None - res = libtuner.IREEBenchmarkResult(candidate_id=4, result_str="") - assert res.get_mean_time() == None + # Invalid json: empty dictionary + res = libtuner.IREEBenchmarkResult(candidate_id=8, result_json={}) + assert res.get_mean_time_us() is None + + # Invalid json: invalid time unit + invalid_time_unit_json = [ + {"aggregate_name": "mean", "real_time": 1.0, "time_unit": "invalid_unit"} + ] + + with pytest.raises(AssertionError, match="Unsupported time unit: invalid_unit"): + res = libtuner.IREEBenchmarkResult( + candidate_id=9, result_json=invalid_time_unit_json + ) + res.get_mean_time_us() + + # Invalid json: missing aggregate_name + invalid_aggregate_name_json = [{"real_time": 1.0, "time_unit": "us"}] + + res = libtuner.IREEBenchmarkResult( + candidate_id=10, result_json=invalid_aggregate_name_json + ) + assert res.get_mean_time_us() is None def test_generate_display_BR(): @@ -110,15 +154,37 @@ def test_parse_dispatch_benchmark_results(): object.__setattr__(path_config, "specs_dir", spec_dir) mock_result_1 = MagicMock() - mock_result_1.run_result.process_res.stdout = "process_time/real_time_mean 100.0 us" + mock_json_1 = { + "benchmarks": [ + {"aggregate_name": "mean", "real_time": 100.0, "time_unit": "us"} + ] + } + mock_result_1.run_result.process_res.stdout = json.dumps(mock_json_1) mock_result_1.candidate_id = 1 mock_result_2 = MagicMock() - mock_result_2.run_result.process_res.stdout = "process_time/real_time_mean 200.0 us" + mock_json_2 = { + "benchmarks": [ + {"aggregate_name": "mean", "real_time": 200.0, "time_unit": "us"} + ] + } + mock_result_2.run_result.process_res.stdout = json.dumps(mock_json_2) mock_result_2.candidate_id = 2 mock_result_3 = MagicMock() - mock_result_3.run_result.process_res = None # Incomplete result + mock_json_3 = { + "benchmarks": [ + { + "aggregate_name": "mean", + "real_time": 3.4591828516259519e-02, + "time_unit": 
"ms", + } + ] + } + mock_result_3.run_result.process_res.stdout = json.dumps(mock_json_3) mock_result_3.candidate_id = 3 - benchmark_results = [mock_result_1, mock_result_2, mock_result_3] + mock_result_4 = MagicMock() + mock_result_4.run_result.process_res = None # Incomplete result + mock_result_4.candidate_id = 4 + benchmark_results = [mock_result_1, mock_result_2, mock_result_3, mock_result_4] candidate_trackers = [] for i in range(4): @@ -139,11 +205,18 @@ def test_parse_dispatch_benchmark_results(): candidate_mlir=libtuner.Path("/mock/mlir/path/2.mlir"), candidate_spec_mlir=libtuner.Path("/mock/base/dir/specs/2_spec.mlir"), ), + libtuner.ParsedDisptachBenchmarkResult( + candidate_id=3, + benchmark_time_in_seconds=34.591828516259519, + candidate_mlir=libtuner.Path("/mock/mlir/path/3.mlir"), + candidate_spec_mlir=libtuner.Path("/mock/base/dir/specs/3_spec.mlir"), + ), ] expected_dump_list = [ "1\tMean Time: 100.0\n", "2\tMean Time: 200.0\n", - "Candidate 3 not completed", + "3\tMean Time: 34.6\n", + "Candidate 4 not completed", ] parsed_results, dump_list = libtuner.parse_dispatch_benchmark_results( @@ -160,6 +233,10 @@ def test_parse_dispatch_benchmark_results(): assert candidate_trackers[2].spec_path == libtuner.Path( "/mock/base/dir/specs/2_spec.mlir" ) + assert candidate_trackers[3].first_benchmark_time == 34.591828516259519 + assert candidate_trackers[3].spec_path == libtuner.Path( + "/mock/base/dir/specs/3_spec.mlir" + ) def test_parse_model_benchmark_results(): @@ -180,22 +257,26 @@ def test_parse_model_benchmark_results(): # Setup mock data for task results result1 = MagicMock() - result1.run_result.process_res.stdout = "1.23" + result_json_1 = {"benchmarks": [{"real_time": 1.23}]} + result1.run_result.process_res.stdout = json.dumps(result_json_1) result1.candidate_id = 1 result1.device_id = "device1" result2 = MagicMock() - result2.run_result.process_res.stdout = "4.56" + result_json_2 = {"benchmarks": [{"real_time": 4.56}]} + 
result2.run_result.process_res.stdout = json.dumps(result_json_2) result2.candidate_id = 2 result2.device_id = "device2" result3 = MagicMock() - result3.run_result.process_res.stdout = "0.98" + result_json_3 = {"benchmarks": [{"real_time": 0.98}]} + result3.run_result.process_res.stdout = json.dumps(result_json_3) result3.candidate_id = 0 result3.device_id = "device1" result4 = MagicMock() - result4.run_result.process_res.stdout = "4.13" + result_json_4 = {"benchmarks": [{"real_time": 4.13}]} + result4.run_result.process_res.stdout = json.dumps(result_json_4) result4.candidate_id = 0 result4.device_id = "device2" @@ -206,7 +287,8 @@ def test_parse_model_benchmark_results(): result5.device_id = "device3" result6 = MagicMock() - result6.run_result.process_res.stdout = "3.38" + result_json_6 = {"benchmarks": [{"real_time": 3.38}]} + result6.run_result.process_res.stdout = json.dumps(result_json_6) result6.candidate_id = 3 result6.device_id = "device3" @@ -214,12 +296,13 @@ def test_parse_model_benchmark_results(): baseline_results = [result3, result4, result5] # Skip real benchmark extraction, directly use given values from above - def mock_get_mean_time(self): - return float(self.result_str) if self.result_str else None + def mock_get_mean_time_us(self): + return float(self.result_json[0]["real_time"]) if self.result_json else None # Mock IREEBenchmarkResult to return wanted benchmark times with patch( - f"{libtuner.__name__}.IREEBenchmarkResult.get_mean_time", new=mock_get_mean_time + f"{libtuner.__name__}.IREEBenchmarkResult.get_mean_time_us", + new=mock_get_mean_time_us, ): # Mock handle_error to avoid actual logging during tests with patch(f"{libtuner.__name__}.handle_error") as mock_handle_error: From 459de98f4a74159d459b68e742516334e2013748 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Fri, 11 Oct 2024 19:59:40 +0200 Subject: [PATCH 5/7] Pin (and update) actions (#268) Updates the checkout action as this uses a deprecated Node.js version and the old 
version will therefore be forced to run on node20. Further pins actions as suggested by OpenSSF Scorecard, see https://github.com/ossf/scorecard/blob/main/docs/checks.md#pinned-dependencies. --- .github/workflows/ci-tuner.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci-tuner.yml b/.github/workflows/ci-tuner.yml index 5de7d4182..1944caa6a 100644 --- a/.github/workflows/ci-tuner.yml +++ b/.github/workflows/ci-tuner.yml @@ -20,10 +20,10 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v4.1.7 + uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # v5.1.1 with: python-version: '3.10.12' From 3015ec7c24052fbd1826cfb4190f7f7d7d8d7c90 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 11 Oct 2024 14:41:54 -0700 Subject: [PATCH 6/7] [sharktank] Add test for sharded rotary table (#274) We should be able to validate the sharded rotary table via comparison with the unsharded version. This runs the sharded and unsharded implementations, asserting near identical results. --- .../layers/sharded_rotary_embedding_test.py | 56 +++++++++++++++++++ 1 file changed, 56 insertions(+) create mode 100644 sharktank/tests/layers/sharded_rotary_embedding_test.py diff --git a/sharktank/tests/layers/sharded_rotary_embedding_test.py b/sharktank/tests/layers/sharded_rotary_embedding_test.py new file mode 100644 index 000000000..963b9b432 --- /dev/null +++ b/sharktank/tests/layers/sharded_rotary_embedding_test.py @@ -0,0 +1,56 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + + +import torch + +from sharktank.layers import RotaryEmbeddingLayer +from sharktank import ops +from sharktank.types import ( + ShardedTensor, + SplitPrimitiveTensor, + unbox_tensor, +) + +import unittest +from typing import List, Optional +import os + + +def test_sharded_rotary_table(): + bs = 4 + rope_dims = 16 + heads = 8 + max_seqlen = 128 + rope_freq_base = None + + # First we setup and get the default rotary embedding layer + xq = torch.rand((bs, max_seqlen, heads, rope_dims), dtype=torch.float) + xk = torch.rand((bs, max_seqlen, heads, rope_dims), dtype=torch.float) + default_layer = RotaryEmbeddingLayer( + rope_dimension_count=rope_dims, + max_seqlen=max_seqlen, + rope_freq_base=rope_freq_base, + ) + oq, ok = default_layer(xq=xq, xk=xk, start_index=0) + + # Then we can shard the same inputs and layer + xq = SplitPrimitiveTensor(ts=xq, shard_dim=2, shard_count=4) + xk = SplitPrimitiveTensor(ts=xk, shard_dim=2, shard_count=4) + shard_layer = RotaryEmbeddingLayer( + rope_dimension_count=rope_dims, + max_seqlen=max_seqlen, + rope_freq_base=rope_freq_base, + tensor_parallelism_size=4, + ) + sq, sk = shard_layer(xq=xq, xk=xk, start_index=0) + + # Gathering and unboxing should yield the same results + sq = ops.unshard(sq) + sk = ops.unshard(sk) + + torch.testing.assert_close(sq, oq) + torch.testing.assert_close(sk, ok) From 355761ba28a489bb33028f8f1403ed5e16302afa Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Mon, 14 Oct 2024 15:20:52 -0400 Subject: [PATCH 7/7] Add sharded paged attention test (#276) Verify that the sharded Llama paged attention block behaves in PyTorch as the unsharded variant. The fp32 accuracy seems low and this test is xfailed. The fp64 accuracy is fine. 
--- .../sharktank/layers/rotary_embedding.py | 2 +- sharktank/sharktank/models/llama/sharding.py | 2 +- .../sharded_paged_llama_attention_block.py | 163 ++++++++++++++++++ 3 files changed, 165 insertions(+), 2 deletions(-) create mode 100644 sharktank/tests/layers/sharded_paged_llama_attention_block.py diff --git a/sharktank/sharktank/layers/rotary_embedding.py b/sharktank/sharktank/layers/rotary_embedding.py index 39e8490d3..834ea349f 100644 --- a/sharktank/sharktank/layers/rotary_embedding.py +++ b/sharktank/sharktank/layers/rotary_embedding.py @@ -21,7 +21,7 @@ def __init__( *, rope_dimension_count: int, max_seqlen: int, - rope_freq_base: float, + rope_freq_base: Optional[float], device: Optional[torch.device] = None, use_hf: bool = False, static_tables: bool = True, diff --git a/sharktank/sharktank/models/llama/sharding.py b/sharktank/sharktank/models/llama/sharding.py index 1a98419e6..3715a3923 100644 --- a/sharktank/sharktank/models/llama/sharding.py +++ b/sharktank/sharktank/models/llama/sharding.py @@ -4,7 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -"""Specifications describing how blocks/layers of llama are sharded.""" +"""Specifications describing how the Llama model is sharded.""" from ...types.sharding import * from ...types import Theta diff --git a/sharktank/tests/layers/sharded_paged_llama_attention_block.py b/sharktank/tests/layers/sharded_paged_llama_attention_block.py new file mode 100644 index 000000000..c94fd44ab --- /dev/null +++ b/sharktank/tests/layers/sharded_paged_llama_attention_block.py @@ -0,0 +1,163 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import unittest +from sharktank.layers import ( + PagedLlamaAttentionBlock, + PagedKVCache, + RotaryEmbeddingLayer, +) +from sharktank.layers.testing import make_llama_attention_block_theta, make_rand_torch +from sharktank.models.llama.sharding import PagedLlamaAttentionBlockSharding +from sharktank.types import SplitPrimitiveTensor, unbox_tensor +import torch +from sharktank import ops +from copy import deepcopy +import pytest + + +class ShardedPagedLlamaAttentionBlockTest(unittest.TestCase): + """Verify that the sharded Llama paged attention block behaves in PyTorch as the + unsharded variant.""" + + def setUp(self): + torch.manual_seed(12345) + self.transformer_block_count = 13 + self.block_index = 1 + self.shard_count = 3 + self.head_count_kv = 2 * self.shard_count + self.attention_head_count = 5 * self.head_count_kv + self.attention_head_dim = 11 * 2 + self.rms_epsilon = 0.01 + self.block_seq_stride = 17 + self.cache_partition_count = 2 + self.page_count = 23 + self.embedding_length = self.attention_head_count * self.attention_head_dim + self.rope_dimension_count = self.attention_head_dim + self.block_seqlen = 7 + self.max_seqlen = self.block_seq_stride * self.block_seqlen + self.rope_freq_base = None + self.batch_size = 3 + self.start_index = 0 + + def testSmallSizedLayerFp64(self): + self.runTestSmallSizedLayer(dtype=torch.float64) + + @pytest.mark.xfail( + reason="The accuracy seems low (atol=0.0018, rtol=0.5065)", + strict=True, + raises=AssertionError, + ) + def testSmallSizedLayerFp32(self): + self.runTestSmallSizedLayer(dtype=torch.float32) + + def runTestSmallSizedLayer(self, dtype: torch.dtype): + torch.set_default_dtype(dtype) + + def make_paged_kv_cache(shard_count: int) -> PagedKVCache: + return PagedKVCache( + transformer_block_count=self.transformer_block_count, + attn_head_count=self.head_count_kv, + attn_head_dim=self.attention_head_dim, + 
cache_partition_count=self.cache_partition_count, + block_seq_stride=self.block_seq_stride, + dtype=dtype, + shard_count=shard_count, + ) + + cache = make_paged_kv_cache(shard_count=1) + sharded_cache = make_paged_kv_cache(shard_count=self.shard_count) + + def make_unsharded_and_sharded_equal_cache_states() -> tuple[ + list[torch.Tensor], list[SplitPrimitiveTensor] + ]: + cache_state = cache.allocate(self.page_count) + cache_state[0] = make_rand_torch(cache_state[0].shape, dtype=dtype) + sharded_cache_state = sharded_cache.shard_state(deepcopy(cache_state)) + return cache_state, sharded_cache_state + + ( + cache_state, + sharded_cache_state, + ) = make_unsharded_and_sharded_equal_cache_states() + + input_tensor = make_rand_torch( + ( + self.batch_size, + self.max_seqlen, + self.attention_head_count * self.attention_head_dim, + ), + dtype=dtype, + ) + seq_block_ids = torch.arange(self.batch_size * self.block_seqlen).view( + self.batch_size, -1 + ) + embedding_module = RotaryEmbeddingLayer( + rope_dimension_count=self.rope_dimension_count, + max_seqlen=self.max_seqlen, + rope_freq_base=self.rope_freq_base, + ) + + theta = make_llama_attention_block_theta( + head_count=self.attention_head_count, + head_count_kv=self.head_count_kv, + head_dim=self.attention_head_dim, + embedding_length=self.embedding_length, + ) + attention_block = PagedLlamaAttentionBlock( + theta=theta, + block_index=self.block_index, + cache=cache, + head_count=self.attention_head_count, + head_dim=self.attention_head_dim, + head_count_kv=self.head_count_kv, + rms_epsilon=self.rms_epsilon, + ) + expected_result = attention_block( + input_tensor, + embedding=embedding_module, + seq_block_ids=seq_block_ids, + start_index=self.start_index, + cache_state=cache_state, + ) + + sharded_input_tensor = ops.replicate(input_tensor, count=self.shard_count) + sharded_seq_block_ids = ops.replicate(seq_block_ids, count=self.shard_count) + sharded_embedding_module = RotaryEmbeddingLayer( + 
rope_dimension_count=self.rope_dimension_count, + max_seqlen=self.max_seqlen, + rope_freq_base=self.rope_freq_base, + tensor_parallelism_size=self.shard_count, + ) + + theta_sharding = PagedLlamaAttentionBlockSharding(shard_count=self.shard_count) + sharded_theta = ops.reshard(theta, theta_sharding) + sharded_attention_block = PagedLlamaAttentionBlock( + theta=sharded_theta, + block_index=self.block_index, + cache=sharded_cache, + head_count=self.attention_head_count, + head_dim=self.attention_head_dim, + head_count_kv=self.head_count_kv, + rms_epsilon=self.rms_epsilon, + ) + sharded_result = sharded_attention_block( + sharded_input_tensor, + embedding=sharded_embedding_module, + seq_block_ids=sharded_seq_block_ids, + start_index=self.start_index, + cache_state=sharded_cache_state, + ) + + actual_result = unbox_tensor(ops.unshard(sharded_result)) + actual_cache_state = unbox_tensor( + ops.unshard( + sharded_cache.unflatten_page_table(sharded_cache_state) + ).flatten(start_dim=1) + ) + + torch.testing.assert_close(actual_result, expected_result) + torch.testing.assert_close(actual_cache_state, cache_state[0])