Merge branch 'main' into debug_array_ops
stbaione authored Oct 30, 2024
2 parents 4153f86 + c6c7321 commit c1d077a
Showing 12 changed files with 400 additions and 372 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/ci-llama.yaml
@@ -8,6 +8,7 @@ name: Llama Benchmarking Tests
 
 on:
   workflow_dispatch:
+  pull_request:
   schedule:
     # Weekdays at 5:00 AM UTC = 10:00 PM PST.
     - cron: "0 5 * * 1-5"
@@ -75,7 +76,7 @@ jobs:
           "numpy<2.0"
       - name: Run llama test
-        run: pytest sharktank/tests/models/llama/benchmark_amdgpu_tests.py -v -s --longrun
+        run: pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py -v -s --longrun --iree-hip-target=gfx942
 
       - name: Upload llama executable files
         uses: actions/upload-artifact@v4
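The new `--iree-hip-target=gfx942` pytest flag implies a matching CLI option registered by the test suite. A hedged sketch of how such an option is typically wired up in a conftest.py; the option name comes from the workflow above, but this body is an assumption, not the repository's exact code:

import pytest


def pytest_addoption(parser):
    # Assumed registration for the --iree-hip-target flag passed by CI above.
    parser.addoption(
        "--iree-hip-target",
        action="store",
        default="gfx942",
        help="IREE HIP target architecture, e.g. gfx942 for MI300-class GPUs",
    )


@pytest.fixture(scope="class")
def iree_hip_target(request):
    # Tests read the target through a fixture rather than parsing argv.
    return request.config.getoption("--iree-hip-target")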
2 changes: 1 addition & 1 deletion .github/workflows/ci-sdxl.yaml
@@ -99,4 +99,4 @@ jobs:
         working-directory: ${{ env.LIBSHORTFIN_DIR }}
         run: |
           ctest --timeout 30 --output-on-failure --test-dir build
-          pytest tests/apps/sd/e2e_test.py -v -s --system=amdgpu
+          HIP_VISIBLE_DEVICES=0 pytest tests/apps/sd/e2e_test.py -v -s --system=amdgpu
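The `HIP_VISIBLE_DEVICES=0` prefix pins the e2e test to a single GPU on multi-GPU runners. A minimal sketch of the same pinning done from Python, assuming only that the variable must be set in the child's environment before the HIP runtime initializes:

import os
import subprocess

# Restrict the child process to HIP device 0; equivalent to the
# HIP_VISIBLE_DEVICES=0 prefix used in the workflow above.
env = {**os.environ, "HIP_VISIBLE_DEVICES": "0"}
subprocess.run(
    ["pytest", "tests/apps/sd/e2e_test.py", "-v", "-s", "--system=amdgpu"],
    env=env,
    check=True,
)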
7 changes: 0 additions & 7 deletions sharktank/conftest.py
@@ -201,13 +201,6 @@ def caching(request: FixtureRequest) -> Optional[bool]:
     return set_fixture_from_cli_option(request, "caching")
 
 
-@pytest.fixture(scope="class")
-def iree_hip_target_type(request: FixtureRequest) -> Optional[str]:
-    return set_fixture_from_cli_option(
-        request, "iree_hip_target", "iree_hip_target_type"
-    )
-
-
 @pytest.fixture(scope="class")
 def tensor_parallelism_size(request: FixtureRequest) -> Optional[str]:
     return set_fixture_from_cli_option(
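The deleted fixture delegated to `set_fixture_from_cli_option`, which this diff does not show. A hedged sketch of what such a helper usually looks like; the signature is inferred from the call sites above, and the body is an assumption:

from typing import Optional

from pytest import FixtureRequest


def set_fixture_from_cli_option(
    request: FixtureRequest,
    cli_option: str,
    class_attribute: Optional[str] = None,
) -> Optional[str]:
    # Read the value of a pytest CLI option such as --caching.
    value = request.config.getoption("--" + cli_option.replace("_", "-"))
    if class_attribute is not None and request.cls is not None:
        # For class-scoped fixtures, also expose the value on the test class.
        setattr(request.cls, class_attribute, value)
    return value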
12 changes: 7 additions & 5 deletions sharktank/setup.py
@@ -11,21 +11,23 @@
 
 from setuptools import find_namespace_packages, setup # type: ignore
 
-THIS_DIR = Path(__file__).resolve().parent
-REPO_DIR = THIS_DIR.parent
-VERSION_INFO_FILE = REPO_DIR / "version_info.json"
+SETUPPY_DIR = os.path.realpath(os.path.dirname(__file__))
 
 
 with open(
     os.path.join(
-        THIS_DIR,
+        SETUPPY_DIR,
         "README.md",
     ),
     "rt",
 ) as f:
     README = f.read()
 
 
+# Setup and get version information.
+VERSION_INFO_FILE = os.path.join(SETUPPY_DIR, "version_info.json")
+
+
 def load_version_info():
     with open(VERSION_INFO_FILE, "rt") as f:
         return json.load(f)
@@ -54,7 +56,7 @@ def load_requirement_pins(requirements_file: Path):
     requirement_pins.update(dict(pin_pairs))
 
 
-load_requirement_pins(REPO_DIR / "requirements.txt")
+load_requirement_pins("requirements.txt")
 
 
 def get_version_spec(dep: str):
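`get_version_spec` is truncated here. Based on the pins loaded by `load_requirement_pins` above, it plausibly maps a bare dependency name to a pinned requirement string; a self-contained sketch under that assumption (the body is a guess consistent with the surrounding code, not the file's actual implementation):

from typing import Dict

requirement_pins: Dict[str, str] = {"numpy": "<2.0"}  # normally filled by load_requirement_pins()


def get_version_spec(dep: str) -> str:
    # Reuse the pin from requirements.txt (e.g. "numpy<2.0") for
    # install_requires; fall back to an unpinned requirement.
    if dep in requirement_pins:
        return f"{dep}{requirement_pins[dep]}"
    return dep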
2 changes: 1 addition & 1 deletion sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -57,7 +57,7 @@ def main():
         "--attention-kernel",
         type=str,
         default="decomposed",
-        choices=["decomposed", "torch_sdpa"],
+        choices=["decomposed", "torch"],
     )
 
     args = cli.parse(parser)
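With the rename, callers now pass `--attention-kernel=torch` instead of `torch_sdpa`. An illustrative invocation; `--bs` and `--attention-kernel` are visible in this diff, while the remaining flags and all paths are assumptions about the exporter's CLI:

import subprocess

subprocess.run(
    [
        "python3", "-m", "sharktank.examples.export_paged_llm_v1",
        "--irpa-file=/path/to/model.irpa",   # placeholder path, assumed flag
        "--output-mlir=/tmp/model.mlir",     # placeholder path, assumed flag
        "--bs=4",
        "--attention-kernel=torch",          # "torch_sdpa" is no longer accepted
    ],
    check=True,
)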
6 changes: 2 additions & 4 deletions sharktank/sharktank/models/llama/tools/shard_llama.py
@@ -33,12 +33,10 @@ def main():
     dataset = cli.get_input_dataset(args)
 
     if args.output_file is None:
-        print(f"Need file destination for IRPA file")
-        return
+        raise RuntimeError(f"Need file destination for IRPA file")
 
     if args.shard_count < 2:
-        print(f"Expect sharding greater than 1 found {args.shard_count}")
-        return
+        raise RuntimeError(f"Expect sharding greater than 1 found {args.shard_count}")
 
     hp = configs.LlamaHParams.from_gguf_props(dataset.properties)
     llama_config = LlamaModelConfig(hp)
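Raising instead of printing means misuse now exits nonzero, which the subprocess-based caller added later in this diff relies on. A small sketch of the observable difference (the irpa path is a placeholder; `--output-file` is deliberately omitted to trigger the error):

import subprocess

# Before this change, a missing --output-file printed a message and exited 0,
# so wrappers could not detect the failure. Now the RuntimeError propagates
# and the process exits with a nonzero status.
proc = subprocess.run(
    [
        "python3", "-m", "sharktank.models.llama.tools.shard_llama",
        "--irpa-file=/path/to/model.irpa",  # placeholder
    ],
    capture_output=True,
    text=True,
)
assert proc.returncode != 0  # RuntimeError: Need file destination for IRPA file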
86 changes: 80 additions & 6 deletions sharktank/sharktank/utils/export_artifacts.py
@@ -5,11 +5,13 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 import os
+import sys
 import subprocess
 import logging
 import time
 from pathlib import Path
 from datetime import timedelta
+from typing import List
 
 import iree.compiler as ireec
 
@@ -25,6 +27,7 @@
 class ExportArtifacts:
     def __init__(
         self,
+        *,
         irpa_path: str,
         batch_size: int,
         iree_hip_target: str,
@@ -59,9 +62,44 @@ def wrapper(*args, **kwargs):
 
         return wrapper
 
+    @timeit
+    def shard_irpa_file(
+        self,
+        *,
+        output_file: str,
+    ):
+        shard_irpa_args = [
+            "python3",
+            "-m",
+            "sharktank.models.llama.tools.shard_llama",
+            "--irpa-file",
+            self.irpa_path,
+            "--output-file",
+            output_file,
+            "--shard_count",
+            str(self.tensor_parallelism_size),
+        ]
+
+        cwd = self.sharktank_dir
+        cmd = subprocess.list2cmdline(shard_irpa_args)
+
+        logger.info(f"Sharding irpa file:\n" f"cd {cwd} && {cmd}")
+
+        proc = subprocess.run(cmd, shell=True, capture_output=True, cwd=cwd)
+        if proc.returncode != 0:
+            logger.error(
+                f"Error sharding irpa file with shard_llama.py\n"
+                f"{proc.stdout+proc.stderr}"
+            )
+        else:
+            logger.info(f"Sharded irpa file successfully:\n" f"{proc.stdout}")
+
+        return proc.returncode
+
     @timeit
     def export_to_mlir(
         self,
+        *,
         mlir_path: str,
         json_path: str,
     ):
@@ -78,18 +116,16 @@ def export_to_mlir(
             "--bs",
             str(self.batch_size),
         ]
-        if self.attention_kernel == "decomposed":
+        if self.attention_kernel in ["decomposed", "torch"]:
             export_args.append("--attention-kernel")
             export_args.append(self.attention_kernel)
-        elif self.attention_kernel == "torch_sdpa":
-            raise NotImplementedError("attention_kernel torch_sdpa not implemented yet")
 
         cwd = self.sharktank_dir
         cmd = subprocess.list2cmdline(export_args)
 
         logger.info(f"Exporting mlir:\n" f"cd {cwd} && {cmd}")
 
-        proc = subprocess.run(export_args, capture_output=True, cwd=cwd, text=True)
+        proc = subprocess.run(cmd, shell=True, capture_output=True, cwd=cwd)
         if proc.returncode != 0:
             logger.error(
                 f"Error exporting mlir with export_paged_llm_v1.py\n"
@@ -103,12 +139,14 @@
     @timeit
     def compile_to_vmfb(
         self,
+        *,
         mlir_path,
         vmfb_path,
+        hal_dump_path,
     ):
         # TODO: Control flag to enable multiple backends
         compile_flags = ["--iree-hip-target=" + self.iree_hip_target]
-
+        compile_flags += [f"--iree-hal-dump-executable-files-to={hal_dump_path}/files"]
         try:
             ireec.compile_file(
                 input_file=mlir_path,
@@ -121,7 +159,43 @@
         else:
             logger.info(f"Compiled to vmfb successfully:\n" f"{vmfb_path}")
 
-    def create_file(self, suffix, prefix):
+    def iree_benchmark_vmfb(
+        self,
+        *,
+        hip_device_id: str,
+        vmfb_name: str,
+        irpa_path: str,
+        args: List[str],
+        cwd: str | Path,
+    ):
+        """Runs a compiled program with the given args using `iree-benchmark-module`.
+
+        This assumes that the `iree-benchmark-module` command is available (usually via PATH).
+
+        Args:
+            hip_device_id: HIP device id, exposed via ROCR_VISIBLE_DEVICES and `--device=hip://`.
+            vmfb_name: Name of the .vmfb file (relative to `cwd`).
+            irpa_path: Path to the parameter file passed as `--parameters=model=...`.
+            args: List of arguments to pass to `iree-benchmark-module`.
+            cwd: Working directory to run the command within. (either string or Path works)
+
+        Raises RuntimeError if the benchmark exits with a nonzero status.
+        """
+        benchmark_args = [
+            f"ROCR_VISIBLE_DEVICES={hip_device_id}",
+            "iree-benchmark-module",
+            f"--device=hip://{hip_device_id}",
+            "--hip_use_streams=true",
+            "--hip_allow_inline_execution=true",
+            "--device_allocator=caching",
+            f"--module={vmfb_name}",
+            f"--parameters=model={irpa_path}",
+        ]
+        benchmark_args += args
+        cmd = subprocess.list2cmdline(benchmark_args)
+        logging.getLogger().info(f"Launching run command:\n" f"cd {cwd} && {cmd}")
+        proc = subprocess.run(cmd, shell=True, stdout=sys.stdout, cwd=cwd)
+        return_code = proc.returncode
+        if return_code != 0:
+            raise RuntimeError(f"Error running benchmark {cmd} in cwd {cwd}")
+
+    def create_file(self, *, suffix, prefix):
         file_path = Path(prefix).with_suffix(suffix)
         f = open(file_path, "w")
         return file_path
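Taken together, the new and updated methods form a shard → export → compile → benchmark pipeline. A hedged usage sketch: the method signatures are all visible in the diff above, but constructor arguments beyond `irpa_path`, `batch_size`, and `iree_hip_target` are assumptions, and every path and device id is a placeholder:

from sharktank.utils.export_artifacts import ExportArtifacts

artifacts = ExportArtifacts(
    irpa_path="/data/llama3/model.irpa",
    batch_size=4,
    iree_hip_target="gfx942",
    attention_kernel="torch",          # assumed constructor parameter
    tensor_parallelism_size=8,         # assumed constructor parameter
)

# Shard weights for tensor parallelism, then export and compile.
artifacts.shard_irpa_file(output_file="/data/llama3/model_tp8.irpa")
artifacts.export_to_mlir(mlir_path="/tmp/model.mlir", json_path="/tmp/config.json")
artifacts.compile_to_vmfb(
    mlir_path="/tmp/model.mlir",
    vmfb_path="/tmp/model.vmfb",
    hal_dump_path="/tmp/hal",
)

# Benchmark the compiled module on one HIP device; extra flags are passed
# straight through to iree-benchmark-module.
artifacts.iree_benchmark_vmfb(
    hip_device_id="0",
    vmfb_name="model.vmfb",
    irpa_path="/data/llama3/model_tp8.irpa",
    args=["--function=prefill_bs4", "--benchmark_repetitions=3"],
    cwd="/tmp",
)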
(Diff for the remaining 5 changed files not shown.)