Merge branch 'main' into einsum_matmul
KyleHerndon authored Oct 14, 2024
2 parents 623e3e7 + 355761b commit 4a2dbef
Showing 14 changed files with 445 additions and 93 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci-tuner.yml
@@ -20,10 +20,10 @@ jobs:

steps:
- name: Checkout code
uses: actions/[email protected]
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7

- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # v5.1.1
with:
python-version: '3.10.12'

11 changes: 4 additions & 7 deletions .github/workflows/ci_linux_x64-libshortfin.yml
@@ -41,7 +41,6 @@ jobs:
run: |
sudo apt update
sudo apt install clang lld cmake ninja-build
sudo apt install libspdlog-dev libxtensor-dev
- name: Checkout repository
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7
@@ -89,9 +88,8 @@ jobs:
-DCMAKE_CXX_COMPILER=clang++-18 \
-DCMAKE_LINKER_TYPE=LLD \
-DSHORTFIN_BUNDLE_DEPS=ON \
-DSHORTFIN_IREE_SOURCE_DIR=${{ env.IREE_REPO_DIR }} \
-DSHORTFIN_BUILD_PYTHON_BINDINGS=ON \
..
-DSHORTFIN_IREE_SOURCE_DIR="${{ env.IREE_REPO_DIR }}" \
-DSHORTFIN_BUILD_PYTHON_BINDINGS=ON
cmake --build build --target all
pip install -v -e build/
@@ -113,10 +111,9 @@ jobs:
-DCMAKE_C_COMPILER=clang-18 \
-DCMAKE_CXX_COMPILER=clang++-18 \
-DCMAKE_LINKER_TYPE=LLD \
-DSHORTFIN_IREE_SOURCE_DIR=${{ env.IREE_REPO_DIR }} \
-DSHORTFIN_IREE_SOURCE_DIR="${{ env.IREE_REPO_DIR }}" \
-DSHORTFIN_BUILD_PYTHON_BINDINGS=ON \
-DSHORTFIN_HAVE_AMDGPU=OFF \
-DSHORTFIN_BUILD_STATIC=ON \
-DSHORTFIN_BUILD_DYNAMIC=ON \
..
-DSHORTFIN_BUILD_DYNAMIC=ON
cmake --build build-host-only --target all
5 changes: 2 additions & 3 deletions .github/workflows/ci_linux_x64_nogil-libshortfin.yml
@@ -86,9 +86,8 @@ jobs:
-DCMAKE_CXX_COMPILER=clang++-18 \
-DCMAKE_LINKER_TYPE=LLD \
-DSHORTFIN_BUNDLE_DEPS=ON \
-DSHORTFIN_IREE_SOURCE_DIR=${{ env.IREE_REPO_DIR }} \
-DSHORTFIN_BUILD_PYTHON_BINDINGS=ON \
..
-DSHORTFIN_IREE_SOURCE_DIR="${{ env.IREE_REPO_DIR }}" \
-DSHORTFIN_BUILD_PYTHON_BINDINGS=ON
cmake --build build --target all
pip install -v -e build/
2 changes: 1 addition & 1 deletion sharktank/sharktank/layers/rotary_embedding.py
@@ -21,7 +21,7 @@ def __init__(
*,
rope_dimension_count: int,
max_seqlen: int,
rope_freq_base: float,
rope_freq_base: Optional[float],
device: Optional[torch.device] = None,
use_hf: bool = False,
static_tables: bool = True,
2 changes: 1 addition & 1 deletion sharktank/sharktank/models/llama/sharding.py
@@ -4,7 +4,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""Specifications describing how blocks/layers of llama are sharded."""
"""Specifications describing how the Llama model is sharded."""

from ...types.sharding import *
from ...types import Theta
14 changes: 6 additions & 8 deletions sharktank/tests/layers/sharded_conv2d_with_iree_test.py
@@ -173,14 +173,12 @@ def run_test_sharded_conv2d_with_iree(
)
assert len(actual_result.shards) == len(expected_result.shards)
assert actual_result.shard_dim == expected_result.shard_dim
# TODO: reenable this check once numerical issues are resolved.
# See https://github.com/iree-org/iree/issues/18283
# for actual_shard, expected_shard in zip(
# actual_result.shards, expected_result.shards
# ):
# torch.testing.assert_close(
# unbox_tensor(actual_shard), unbox_tensor(expected_shard)
# )
for actual_shard, expected_shard in zip(
actual_result.shards, expected_result.shards
):
torch.testing.assert_close(
unbox_tensor(actual_shard), unbox_tensor(expected_shard)
)


def test_sharded_conv2d_with_iree(
163 changes: 163 additions & 0 deletions sharktank/tests/layers/sharded_paged_llama_attention_block.py
@@ -0,0 +1,163 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import unittest
from sharktank.layers import (
PagedLlamaAttentionBlock,
PagedKVCache,
RotaryEmbeddingLayer,
)
from sharktank.layers.testing import make_llama_attention_block_theta, make_rand_torch
from sharktank.models.llama.sharding import PagedLlamaAttentionBlockSharding
from sharktank.types import SplitPrimitiveTensor, unbox_tensor
import torch
from sharktank import ops
from copy import deepcopy
import pytest


class ShardedPagedLlamaAttentionBlockTest(unittest.TestCase):
"""Verify that the sharded Llama paged attention block behaves in PyTorch as the
unsharded variant."""

def setUp(self):
torch.manual_seed(12345)
self.transformer_block_count = 13
self.block_index = 1
self.shard_count = 3
self.head_count_kv = 2 * self.shard_count
self.attention_head_count = 5 * self.head_count_kv
self.attention_head_dim = 11 * 2
self.rms_epsilon = 0.01
self.block_seq_stride = 17
self.cache_partition_count = 2
self.page_count = 23
self.embedding_length = self.attention_head_count * self.attention_head_dim
self.rope_dimension_count = self.attention_head_dim
self.block_seqlen = 7
self.max_seqlen = self.block_seq_stride * self.block_seqlen
self.rope_freq_base = None
self.batch_size = 3
self.start_index = 0

def testSmallSizedLayerFp64(self):
self.runTestSmallSizedLayer(dtype=torch.float64)

@pytest.mark.xfail(
reason="The accuracy seems low (atol=0.0018, rtol=0.5065)",
strict=True,
raises=AssertionError,
)
def testSmallSizedLayerFp32(self):
self.runTestSmallSizedLayer(dtype=torch.float32)

def runTestSmallSizedLayer(self, dtype: torch.dtype):
torch.set_default_dtype(dtype)

def make_paged_kv_cache(shard_count: int) -> PagedKVCache:
return PagedKVCache(
transformer_block_count=self.transformer_block_count,
attn_head_count=self.head_count_kv,
attn_head_dim=self.attention_head_dim,
cache_partition_count=self.cache_partition_count,
block_seq_stride=self.block_seq_stride,
dtype=dtype,
shard_count=shard_count,
)

cache = make_paged_kv_cache(shard_count=1)
sharded_cache = make_paged_kv_cache(shard_count=self.shard_count)

def make_unsharded_and_sharded_equal_cache_states() -> tuple[
list[torch.Tensor], list[SplitPrimitiveTensor]
]:
cache_state = cache.allocate(self.page_count)
cache_state[0] = make_rand_torch(cache_state[0].shape, dtype=dtype)
sharded_cache_state = sharded_cache.shard_state(deepcopy(cache_state))
return cache_state, sharded_cache_state

(
cache_state,
sharded_cache_state,
) = make_unsharded_and_sharded_equal_cache_states()

input_tensor = make_rand_torch(
(
self.batch_size,
self.max_seqlen,
self.attention_head_count * self.attention_head_dim,
),
dtype=dtype,
)
seq_block_ids = torch.arange(self.batch_size * self.block_seqlen).view(
self.batch_size, -1
)
embedding_module = RotaryEmbeddingLayer(
rope_dimension_count=self.rope_dimension_count,
max_seqlen=self.max_seqlen,
rope_freq_base=self.rope_freq_base,
)

theta = make_llama_attention_block_theta(
head_count=self.attention_head_count,
head_count_kv=self.head_count_kv,
head_dim=self.attention_head_dim,
embedding_length=self.embedding_length,
)
attention_block = PagedLlamaAttentionBlock(
theta=theta,
block_index=self.block_index,
cache=cache,
head_count=self.attention_head_count,
head_dim=self.attention_head_dim,
head_count_kv=self.head_count_kv,
rms_epsilon=self.rms_epsilon,
)
expected_result = attention_block(
input_tensor,
embedding=embedding_module,
seq_block_ids=seq_block_ids,
start_index=self.start_index,
cache_state=cache_state,
)

sharded_input_tensor = ops.replicate(input_tensor, count=self.shard_count)
sharded_seq_block_ids = ops.replicate(seq_block_ids, count=self.shard_count)
sharded_embedding_module = RotaryEmbeddingLayer(
rope_dimension_count=self.rope_dimension_count,
max_seqlen=self.max_seqlen,
rope_freq_base=self.rope_freq_base,
tensor_parallelism_size=self.shard_count,
)

theta_sharding = PagedLlamaAttentionBlockSharding(shard_count=self.shard_count)
sharded_theta = ops.reshard(theta, theta_sharding)
sharded_attention_block = PagedLlamaAttentionBlock(
theta=sharded_theta,
block_index=self.block_index,
cache=sharded_cache,
head_count=self.attention_head_count,
head_dim=self.attention_head_dim,
head_count_kv=self.head_count_kv,
rms_epsilon=self.rms_epsilon,
)
sharded_result = sharded_attention_block(
sharded_input_tensor,
embedding=sharded_embedding_module,
seq_block_ids=sharded_seq_block_ids,
start_index=self.start_index,
cache_state=sharded_cache_state,
)

actual_result = unbox_tensor(ops.unshard(sharded_result))
actual_cache_state = unbox_tensor(
ops.unshard(
sharded_cache.unflatten_page_table(sharded_cache_state)
).flatten(start_dim=1)
)

torch.testing.assert_close(actual_result, expected_result)
torch.testing.assert_close(actual_cache_state, cache_state[0])
56 changes: 56 additions & 0 deletions sharktank/tests/layers/sharded_rotary_embedding_test.py
@@ -0,0 +1,56 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception


import torch

from sharktank.layers import RotaryEmbeddingLayer
from sharktank import ops
from sharktank.types import (
ShardedTensor,
SplitPrimitiveTensor,
unbox_tensor,
)

import unittest
from typing import List, Optional
import os


def test_sharded_rotary_table():
bs = 4
rope_dims = 16
heads = 8
max_seqlen = 128
rope_freq_base = None

# First we set up the default rotary embedding layer
xq = torch.rand((bs, max_seqlen, heads, rope_dims), dtype=torch.float)
xk = torch.rand((bs, max_seqlen, heads, rope_dims), dtype=torch.float)
default_layer = RotaryEmbeddingLayer(
rope_dimension_count=rope_dims,
max_seqlen=max_seqlen,
rope_freq_base=rope_freq_base,
)
oq, ok = default_layer(xq=xq, xk=xk, start_index=0)

# Then we can shard the same inputs and layer
xq = SplitPrimitiveTensor(ts=xq, shard_dim=2, shard_count=4)
xk = SplitPrimitiveTensor(ts=xk, shard_dim=2, shard_count=4)
shard_layer = RotaryEmbeddingLayer(
rope_dimension_count=rope_dims,
max_seqlen=max_seqlen,
rope_freq_base=rope_freq_base,
tensor_parallelism_size=4,
)
sq, sk = shard_layer(xq=xq, xk=xk, start_index=0)

# Gathering and unboxing should yield the same results
sq = ops.unshard(sq)
sk = ops.unshard(sk)

torch.testing.assert_close(sq, oq)
torch.testing.assert_close(sk, ok)
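
Both new test files follow the same pattern: run the layer unsharded, re-run it on sharded tensors, unshard the result, and assert the outputs match. A minimal sketch of invoking just these two files with pytest (an assumed local workflow, not part of this commit; it presumes the repository root as the working directory, and the fp32 attention case is expected to xfail per the marker above):

import pytest

# Hypothetical local invocation: collect only the two test files added in
# this commit and report verbose results.
retcode = pytest.main(
    [
        "sharktank/tests/layers/sharded_paged_llama_attention_block.py",
        "sharktank/tests/layers/sharded_rotary_embedding_test.py",
        "-v",
    ]
)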
@@ -19,6 +19,7 @@
import iree.runtime
from typing import List, Optional
import os
import pytest

vm_context: iree.runtime.VmContext = None

@@ -207,19 +208,26 @@ def run_test_sharded_resnet_block_with_iree(
parameters_path=parameters_path,
)
assert len(actual_result.shards) == len(expected_result.shards)
# TODO: reenable this check once numerical issues are resolved.
# See https://github.com/iree-org/iree/issues/18283
# for actual_shard, expected_shard in zip(
# actual_result.shards, expected_result.shards
# ):
# torch.testing.assert_close(
# unbox_tensor(actual_shard), unbox_tensor(expected_shard)
# )
# TODO: reenable this test once numerical issues are resolved.
# The absolute error is > 0.00042. Is this good enough?
# Maybe add a test with fp64; high accuracy there would give us more
# confidence that fp32 is also OK.
for actual_shard, expected_shard in zip(
actual_result.shards, expected_result.shards
):
torch.testing.assert_close(
unbox_tensor(actual_shard), unbox_tensor(expected_shard)
)

global vm_context
del vm_context


@pytest.mark.xfail(
reason="Maybe numerical issues with low accuracy.",
strict=True,
raises=AssertionError,
)
def test_sharded_resnet_block_with_iree(
mlir_path: Optional[Path],
module_path: Optional[Path],
1 change: 1 addition & 0 deletions tuner/examples/dispatch/dispatch_tuner.py
@@ -58,6 +58,7 @@ def get_dispatch_benchmark_command(
f"--module={compiled_vmfb_path.resolve()}",
"--batch_size=1000",
"--benchmark_repetitions=3",
"--benchmark_format=json",
]

return command
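
The benchmark commands now rely on --benchmark_format=json, which prints the report to stdout, instead of writing it to a file via --benchmark_out/--benchmark_out_format; the same switch is applied to the punet commands below. A minimal sketch of consuming that output, assuming the report follows the standard Google Benchmark JSON layout (a top-level "benchmarks" list with per-entry "real_time" values) and that the caller captures stdout; the helper name is hypothetical and not part of this commit:

import json
import subprocess


def run_and_parse_benchmark(command: list[str]) -> float:
    # Run the assembled benchmark command and capture the JSON report that
    # --benchmark_format=json prints to stdout.
    result = subprocess.run(command, capture_output=True, text=True, check=True)
    report = json.loads(result.stdout)
    # Skip aggregate rows (mean/median/stddev) that appear alongside the
    # per-repetition entries when --benchmark_repetitions is set.
    times = [
        entry["real_time"]
        for entry in report["benchmarks"]
        if entry.get("aggregate_name") is None
    ]
    return sum(times) / len(times)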
6 changes: 2 additions & 4 deletions tuner/examples/punet/punet_autotune.py
@@ -58,8 +58,7 @@ def get_dispatch_benchmark_command(
"--hip_allow_inline_execution=true",
"--batch_size=1000",
"--benchmark_repetitions=3",
f"--benchmark_out=dispatch_{candidate_tracker.candidate_id}_bm.json",
"--benchmark_out_format=json",
"--benchmark_format=json",
]

return command
@@ -110,8 +109,7 @@ def get_model_benchmark_command(
"--input=2x6xf16",
"--input=1xf16",
"--benchmark_repetitions=5",
f"--benchmark_out=model_{candidate_tracker.candidate_id}_bm.json",
"--benchmark_out_format=json",
"--benchmark_format=json",
]
return command

