diff --git a/.github/workflows/ci-llama-large-tests.yaml b/.github/workflows/ci-llama-large-tests.yaml index 2f23ce812..07e3f88ed 100644 --- a/.github/workflows/ci-llama-large-tests.yaml +++ b/.github/workflows/ci-llama-large-tests.yaml @@ -68,6 +68,7 @@ jobs: pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \ -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" + # Test with nightly releases, not what iree-turbine uses. pip install -f https://iree.dev/pip-release-links.html --upgrade \ iree-base-compiler \ diff --git a/.github/workflows/ci-llama-quick-tests.yaml b/.github/workflows/ci-llama-quick-tests.yaml index 1bd952e61..25c4a6342 100644 --- a/.github/workflows/ci-llama-quick-tests.yaml +++ b/.github/workflows/ci-llama-quick-tests.yaml @@ -69,6 +69,7 @@ jobs: pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \ -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" + # Test with nightly releases, not what iree-turbine uses. pip install -f https://iree.dev/pip-release-links.html --upgrade \ iree-base-compiler \ diff --git a/sharktank/sharktank/utils/export_artifacts.py b/sharktank/sharktank/utils/export_artifacts.py index 8a6bf6b6d..7141b4921 100644 --- a/sharktank/sharktank/utils/export_artifacts.py +++ b/sharktank/sharktank/utils/export_artifacts.py @@ -195,6 +195,7 @@ def compile_to_vmfb( vmfb_path, cwd, hal_dump_path: Optional[Path] = None, + args: Optional[List[str]] = None, ): # TODO: Control flag to enable multiple backends compile_args = [ @@ -202,11 +203,6 @@ def compile_to_vmfb( f"{mlir_path}", f"--iree-hip-target={self.iree_hip_target}", f"--iree-hal-target-backends={self.iree_hal_target_backends}", - "--iree-dispatch-creation-enable-aggressive-fusion=true", - "--iree-global-opt-propagate-transposes=true", - "--iree-opt-aggressively-propagate-transposes=true", - "--iree-opt-data-tiling=false", - '--iree-preprocessing-pass-pipeline="builtin.module\\(util.func\\(iree-preprocessing-generalize-linalg-matmul-experimental\\)\\)"', f"-o={vmfb_path}", ] if self.tensor_parallelism_size > 1: @@ -214,17 +210,14 @@ def compile_to_vmfb( f"--iree-hal-target-device=hip[{i}]" for i in range(self.tensor_parallelism_size) ] - tp_flags = [ - "--iree-hal-force-indirect-command-buffers=true", - "--iree-stream-resource-memory-model=discrete", - "--iree-hip-legacy-sync=false", - ] compile_args += iree_hal_target_devices - compile_args += tp_flags if hal_dump_path: compile_args += [ f"--iree-hal-dump-executable-files-to={hal_dump_path}/files" ] + # Append optional arguments if provided + if args: + compile_args += args cmd = subprocess.list2cmdline(compile_args) logging.getLogger().info(f"Launching compile command:\n" f"cd {cwd} && {cmd}") diff --git a/sharktank/tests/models/llama/benchmark_amdgpu_test.py b/sharktank/tests/models/llama/benchmark_amdgpu_test.py index b8d7dbc34..0a0d85d49 100644 --- a/sharktank/tests/models/llama/benchmark_amdgpu_test.py +++ b/sharktank/tests/models/llama/benchmark_amdgpu_test.py @@ -49,6 +49,13 @@ def setUpClass(cls): def setUp(self): self.hip_device_id = os.getenv("HIP_DEVICE_ID", default="0") + self.compile_args = [ + "--iree-dispatch-creation-enable-aggressive-fusion=true", + "--iree-global-opt-propagate-transposes=true", + "--iree-opt-aggressively-propagate-transposes=true", + "--iree-opt-data-tiling=false", + '--iree-preprocessing-pass-pipeline="builtin.module\\(util.func\\(iree-preprocessing-generalize-linalg-matmul-experimental\\)\\)"', + ] @is_mi300x @@ -154,6 +161,7 @@ def testBenchmark8B_f16_Decomposed(self): vmfb_path=output_vmfb, hal_dump_path=output_file_name, cwd=self.repo_root, + args=self.compile_args, ) # benchmark prefill self.llama8b_f16_decomposed_artifacts.iree_benchmark_vmfb( @@ -195,6 +203,7 @@ def testBenchmark8B_f16_Non_Decomposed(self): vmfb_path=output_vmfb, hal_dump_path=output_file_name, cwd=self.repo_root, + args=self.compile_args, ) # benchmark prefill self.llama8b_f16_torch_sdpa_artifacts.iree_benchmark_vmfb( @@ -236,6 +245,7 @@ def testBenchmark8B_fp8_Decomposed(self): vmfb_path=output_vmfb, hal_dump_path=output_file_name, cwd=self.repo_root, + args=self.compile_args, ) # benchmark prefill self.llama8b_fp8_decomposed_artifacts.iree_benchmark_vmfb( @@ -277,6 +287,7 @@ def testBenchmark8B_fp8_Non_Decomposed(self): vmfb_path=output_vmfb, hal_dump_path=output_file_name, cwd=self.repo_root, + args=self.compile_args, ) # benchmark prefill self.llama8b_fp8_torch_sdpa_artifacts.iree_benchmark_vmfb( @@ -379,6 +390,11 @@ def setUp(self): f"--input=@{self.decode_args_fp8}/cache_state_f16.npy", "--benchmark_repetitions=3", ] + self.compile_args += [ + "--iree-hal-force-indirect-command-buffers=true", + "--iree-stream-resource-memory-model=discrete", + "--iree-hip-legacy-sync=false", + ] @pytest.mark.xfail( reason="Benchmarking Error", strict=True, raises=IreeBenchmarkException @@ -409,6 +425,7 @@ def testBenchmark70B_f16_TP8_Decomposed(self): vmfb_path=output_vmfb, hal_dump_path=output_file_name, cwd=self.repo_root, + args=self.compile_args, ) # benchmark prefill self.llama70b_f16_decomposed_artifacts.iree_benchmark_vmfb( @@ -454,6 +471,7 @@ def testBenchmark70B_f16_TP8_Non_Decomposed(self): vmfb_path=output_vmfb, hal_dump_path=output_file_name, cwd=self.repo_root, + args=self.compile_args, ) # benchmark prefill self.llama70b_f16_torch_sdpa_artifacts.iree_benchmark_vmfb( @@ -501,6 +519,7 @@ def testBenchmark70B_fp8_TP8_Decomposed(self): vmfb_path=output_vmfb, hal_dump_path=output_file_name, cwd=self.repo_root, + args=self.compile_args, ) # benchmark prefill self.llama70b_fp8_decomposed_artifacts.iree_benchmark_vmfb( @@ -548,6 +567,7 @@ def testBenchmark70B_fp8_TP8_Non_Decomposed(self): vmfb_path=output_vmfb, hal_dump_path=output_file_name, cwd=self.repo_root, + args=self.compile_args, ) # benchmark prefill self.llama70b_fp8_torch_sdpa_artifacts.iree_benchmark_vmfb( @@ -680,6 +700,7 @@ def testBenchmark405B_f16_TP8_Decomposed(self): vmfb_path=output_vmfb, hal_dump_path=output_file_name, cwd=self.repo_root, + args=self.compile_args, ) # benchmark prefill self.llama405b_f16_decomposed_artifacts.iree_benchmark_vmfb( @@ -725,6 +746,7 @@ def testBenchmark405B_f16_TP8_Non_Decomposed(self): vmfb_path=output_vmfb, hal_dump_path=output_file_name, cwd=self.repo_root, + args=self.compile_args, ) # benchmark prefill self.llama405b_f16_torch_sdpa_artifacts.iree_benchmark_vmfb( @@ -772,6 +794,7 @@ def testBenchmark405B_fp8_TP8_Decomposed(self): vmfb_path=output_vmfb, hal_dump_path=output_file_name, cwd=self.repo_root, + args=self.compile_args, ) # benchmark prefill self.llama405b_fp8_decomposed_artifacts.iree_benchmark_vmfb( @@ -819,6 +842,7 @@ def testBenchmark405B_fp8_TP8_Non_Decomposed(self): vmfb_path=output_vmfb, hal_dump_path=output_file_name, cwd=self.repo_root, + args=self.compile_args, ) # benchmark prefill self.llama405b_fp8_torch_sdpa_artifacts.iree_benchmark_vmfb(