Skip to content

Commit

Permalink
[sharktank] Add test for sharded rotary table (#274)
Browse files Browse the repository at this point in the history
We should be able to validate the sharded rotary table via comparison
with the unsharded version. This runs the sharded and unsharded
implementations, asserting near identical results.
  • Loading branch information
rsuderman authored and eagarvey-amd committed Oct 16, 2024
1 parent 0692511 commit 63ad399
Showing 1 changed file with 56 additions and 0 deletions.
56 changes: 56 additions & 0 deletions sharktank/tests/layers/sharded_rotary_embedding_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception


import torch

from sharktank.layers import RotaryEmbeddingLayer
from sharktank import ops
from sharktank.types import (
ShardedTensor,
SplitPrimitiveTensor,
unbox_tensor,
)

import unittest
from typing import List, Optional
import os


def test_sharded_rotary_table():
    """Validate the sharded rotary embedding against the unsharded reference.

    Feeds identical query/key tensors through an unsharded
    RotaryEmbeddingLayer and a tensor-parallel (sharded) one, gathers the
    sharded outputs back to dense tensors, and asserts the two paths agree
    to within torch.testing default tolerances.
    """
    batch = 4
    rope_dims = 16
    head_count = 8
    seq_len = 128
    freq_base = None
    shard_count = 4

    # Reference path: run dense inputs through the default (unsharded) layer.
    queries = torch.rand((batch, seq_len, head_count, rope_dims), dtype=torch.float)
    keys = torch.rand((batch, seq_len, head_count, rope_dims), dtype=torch.float)
    reference_layer = RotaryEmbeddingLayer(
        rope_dimension_count=rope_dims,
        max_seqlen=seq_len,
        rope_freq_base=freq_base,
    )
    ref_q, ref_k = reference_layer(xq=queries, xk=keys, start_index=0)

    # Sharded path: split the same inputs along the head dimension (dim 2)
    # and run them through a tensor-parallel layer.
    split_q = SplitPrimitiveTensor(ts=queries, shard_dim=2, shard_count=shard_count)
    split_k = SplitPrimitiveTensor(ts=keys, shard_dim=2, shard_count=shard_count)
    sharded_layer = RotaryEmbeddingLayer(
        rope_dimension_count=rope_dims,
        max_seqlen=seq_len,
        rope_freq_base=freq_base,
        tensor_parallelism_size=shard_count,
    )
    sharded_q, sharded_k = sharded_layer(xq=split_q, xk=split_k, start_index=0)

    # Gathering the shards back into dense tensors should reproduce the
    # reference results (near-)exactly.
    torch.testing.assert_close(ops.unshard(sharded_q), ref_q)
    torch.testing.assert_close(ops.unshard(sharded_k), ref_k)

0 comments on commit 63ad399

Please sign in to comment.