Merge branch 'shortfin-system-selection' of https://github.com/renxida/SHARK-Platform into shortfin-system-selection
renxida committed Oct 15, 2024
2 parents 5a8dbd7 + beb3505 commit 6697077
Showing 7 changed files with 330 additions and 48 deletions.
40 changes: 26 additions & 14 deletions sharktank/sharktank/kernels/mmt_block_scaled_offset_q4.py
@@ -37,38 +37,43 @@ def select(self, ksel: KernelSelection):
m_desc = ksel.arg_tensor(3) # Shape [N, K // BLOCK_SIZE, 1]

# a arg
*batch_dims, a_m, a_k = a_desc.t.shape
*a_batch_dims, a_m, a_k = a_desc.t.shape
torch._check(
a_desc.t.dtype.is_floating_point,
lambda: f"mmt_block_scaled_offset_q4_unsigned arg 'a': Expected floating point (got {a_desc.t.dtype})",
)
torch._check(
len(batch_dims) == 1,
len(a_batch_dims) == 1,
lambda: f"mmt_block_scaled_offset_q4_unsigned arg 'a': Expected 3d tensor (got {a_desc.t.shape})",
)

# qs arg
qs_n, qs_group0, qs_bs_div_2, *rest = qs_desc.t.shape
*qs_batch_dims, qs_n, qs_group0, qs_bs_div_2 = qs_desc.t.shape
torch._check(
len(rest) == 0 and (qs_group0 * qs_bs_div_2 * 2) == a_k,
(
len(qs_batch_dims) == 0
or len(qs_batch_dims) == 1
and qs_batch_dims == a_batch_dims
)
and (qs_group0 * qs_bs_div_2 * 2) == a_k,
lambda: f"mmt_block_scaled_offset_q4_unsigned arg 'qs': Incorrect shape (got {qs_desc.t.shape})",
)
block_size = qs_bs_div_2 * 2

# d arg
d_n, d_group0, d_one, *rest = d_desc.t.shape
*d_batch_dims, d_n, d_group0, d_one = d_desc.t.shape
torch._check(
len(rest) == 0
d_batch_dims == qs_batch_dims
and (d_group0 * block_size) == a_k
and d_one == 1
and d_n == qs_n,
lambda: f"mmt_block_scaled_offset_q4_unsigned arg 'd': Incorrect shape (got {d_desc.t.shape})",
)

# m arg
m_n, m_group0, m_one, *rest = m_desc.t.shape
*m_batch_dims, m_n, m_group0, m_one = m_desc.t.shape
torch._check(
len(rest) == 0
m_batch_dims == qs_batch_dims
and (m_group0 * block_size) == a_k
and m_one == 1
and m_n == qs_n,
@@ -81,12 +86,17 @@ def select(self, ksel: KernelSelection):

# Specialize on K, N, BS
a_desc.specialize_dims(-1)
qs_desc.specialize_all_dims()
d_desc.specialize_all_dims()
m_desc.specialize_all_dims()
if len(qs_batch_dims) == 0:
qs_desc.specialize_all_dims()
d_desc.specialize_all_dims()
m_desc.specialize_all_dims()
else:
qs_desc.specialize_dims(1, 2, 3)
d_desc.specialize_dims(1, 2, 3)
m_desc.specialize_dims(1, 2, 3)

# Shape batch..., m, n
c_desc = ksel.return_new_tensor(batch_dims + [a_m, d_n], dtype=a_desc.t.dtype)
c_desc = ksel.return_new_tensor(a_batch_dims + [a_m, d_n], dtype=a_desc.t.dtype)
c_desc.specialize_dims(-1)

def generate(self, ksel: KernelSelection, kb: KernelBuilder):
@@ -99,13 +109,14 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder):

rank = a_tensor_type.rank
k = a_tensor_type.get_dim_size(rank - 1)
n, group0, bs_i8 = qs_tensor_type.shape
*qs_batch_dims, n, group0, bs_i8 = qs_tensor_type.shape
batched_rhs = len(qs_batch_dims) == 1
bs = bs_i8 * 2 # 2 nibbles per byte.
a_type_str = str(a_tensor_type.element_type)
scale_type_str = str(d_tensor_type.element_type)

template_file = "mmt_block_scaled_offset_q4_unsigned.mlir"
target_function_name = f"sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{n}_{k}_{bs}_{a_type_str}"
target_function_name = f"sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{n}_{k}_{bs}_{a_type_str}_{batched_rhs}"

target_function = inline_template_function(
kb,
@@ -118,5 +129,6 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
group0=group0,
a_type=a_type_str,
scale_type=scale_type_str,
batched_rhs=batched_rhs,
)
kb.yield_results(*call_function(target_function, *kb.arg_bindings))
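For orientation (an illustrative aside, not part of the commit): the change above relaxes the RHS shape contract so that `qs`, `d`, and `m` may carry an optional leading batch dimension that must match `a`'s batch dimension, and specialization keeps that batch dimension dynamic. A rough, torch-free sketch of the contract `select()` now enforces, with hypothetical helper names:

```python
# Illustrative only: a torch-free restatement of the shape checks in select().
# The name check_shapes is hypothetical; the real kernel uses torch._check.
def check_shapes(a_shape, qs_shape, d_shape, m_shape):
    *a_batch, a_m, a_k = a_shape                      # a: [B, M, K]
    *qs_batch, qs_n, group0, bs_div_2 = qs_shape      # qs: [B?, N, K // BS, BS // 2]
    assert len(a_batch) == 1
    assert qs_batch == [] or qs_batch == a_batch      # RHS batch dim optional; must match a's
    assert group0 * bs_div_2 * 2 == a_k
    block_size = bs_div_2 * 2
    for shape in (d_shape, m_shape):                  # d, m: [B?, N, K // BS, 1]
        *batch, n, g0, one = shape
        assert batch == qs_batch and n == qs_n and one == 1 and g0 * block_size == a_k
    return a_batch + [a_m, qs_n]                      # result: [B, M, N]

# Unbatched (shared) RHS vs. per-batch RHS, both against a of shape [4, 32, 128]:
check_shapes([4, 32, 128], [11, 4, 16], [11, 4, 1], [11, 4, 1])
check_shapes([4, 32, 128], [4, 11, 4, 16], [4, 11, 4, 1], [4, 11, 4, 1])
```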
sharktank/sharktank/kernels/templates/mmt_block_scaled_offset_q4_unsigned.mlir
@@ -12,17 +12,25 @@
!accum_type = {{accum_type}}
!a_tensor_type = tensor<?x?x{{k}}x!a_type>
!aexp_tensor_type = tensor<?x?x{{group0}}x{{bs}}x!a_type>
{% if batched_rhs %}
!qs_raw_tensor_type = tensor<?x{{n}}x{{group0}}x{{bs_i8}}xi8>
!qs_tensor_type = tensor<?x{{n}}x{{group0}}x{{bs}}x!lowp_type>
!d_tensor_type = tensor<?x{{n}}x{{group0}}x1x!scale_type>
!m_tensor_type = tensor<?x{{n}}x{{group0}}x1x!scale_type>
!b_grouped_tensor_type = tensor<?x{{n}}x{{group0}}x{{bs}}x!a_type>
{% else %}
!qs_raw_tensor_type = tensor<{{n}}x{{group0}}x{{bs_i8}}xi8>
!qs_tensor_type = tensor<{{n}}x{{group0}}x{{bs}}x!lowp_type>
!d_tensor_type = tensor<{{n}}x{{group0}}x1x!scale_type>
!m_tensor_type = tensor<{{n}}x{{group0}}x1x!scale_type>
!b_grouped_tensor_type = tensor<{{n}}x{{group0}}x{{bs}}x!a_type>
{% endif %}
!accum_tensor_type = tensor<?x?x{{n}}x!accum_type>
!c_tensor_type = tensor<?x?x{{n}}x!a_type>
!b_grouped_tensor_type = tensor<{{n}}x{{group0}}x{{bs}}x!a_type>

module {

util.func private @sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{{n}}_{{k}}_{{bs}}_{{a_type}}(
util.func private @sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{{n}}_{{k}}_{{bs}}_{{a_type}}_{{batched_rhs}}(
%a: !a_tensor_type, %d: !d_tensor_type, %qs_raw: !qs_raw_tensor_type, %m: !m_tensor_type)
-> !c_tensor_type {
%zero = arith.constant 0.0: !accum_type
@@ -32,17 +40,31 @@ util.func private @sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{{n}}_{{k}}_
%m_dim = tensor.dim %a, %c1 : !a_tensor_type

// Cast qs_raw from i8 to lowp type.
{% if batched_rhs %}
%qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type{ %batch0_dim } -> !qs_tensor_type{ %batch0_dim }
%b_grouped = tensor.empty(%batch0_dim) : !b_grouped_tensor_type
{% else %}
%qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type -> !qs_tensor_type
%b_grouped = tensor.empty() : !b_grouped_tensor_type
{% endif %}

// Dequantize.
%b_grouped = tensor.empty() : !b_grouped_tensor_type
%b_grouped_dequant = linalg.generic {
{% if batched_rhs %}
indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
iterator_types = ["parallel", "parallel", "parallel", "parallel"] }
{% else %}
indexing_maps = [
affine_map<(d0, d1, d2) -> (d0, d1, 0)>,
affine_map<(d0, d1, d2) -> (d0, d1, 0)>,
affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
iterator_types = ["parallel", "parallel", "parallel"] }
{% endif %}
ins(%d, %m, %qs : !d_tensor_type, !m_tensor_type, !qs_tensor_type)
outs(%b_grouped : !b_grouped_tensor_type) {
^bb0(%d_element: !scale_type, %m_element: !scale_type, %q_element: !lowp_type, %out: !a_type):
@@ -70,7 +92,7 @@ util.func private @sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{{n}}_{{k}}_
indexing_maps = [
// d0 = b, d1 = m, d2 = n, d3 = group0 (r), d4 = block (r)
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> ({% if batched_rhs %}d0,{% endif %} d2, d3, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>],
iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"] }
ins(%aexp, %b_grouped_dequant : !aexp_tensor_type, !b_grouped_tensor_type)
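As a reading aid (a sketch under stated assumptions, not code from the diff), the indexing maps above amount to a block-grouped matmul in which the dequantized RHS is either shared across the batch or carries its own batch dimension; in einsum terms:

```python
import numpy as np

# Sketch of the contraction encoded by the affine maps above.
# aexp:              [B, M, GROUP0, BS]
# b_grouped_dequant: [N, GROUP0, BS] when shared, [B, N, GROUP0, BS] when batched_rhs
def mmt_block_scaled(aexp, b_grouped_dequant, batched_rhs=False):
    if batched_rhs:
        return np.einsum("bmgs,bngs->bmn", aexp, b_grouped_dequant)
    return np.einsum("bmgs,ngs->bmn", aexp, b_grouped_dequant)
```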
70 changes: 46 additions & 24 deletions sharktank/sharktank/layers/rotary_embedding.py
@@ -4,6 +4,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from collections import namedtuple
from typing import Optional, Union

import torch
@@ -21,17 +22,20 @@ def __init__(
*,
rope_dimension_count: int,
max_seqlen: int,
rope_freq_base: float,
rope_freq_base: Optional[float],
device: Optional[torch.device] = None,
use_hf: bool = False,
static_tables: bool = True,
static_tables: bool = False,
use_table: bool = True,
tensor_parallelism_size: int = 1,
):
super().__init__()
self.device = device
self.rope_dimension_count = rope_dimension_count
self.max_seqlen = max_seqlen
self.use_hf = use_hf
self.static_tables = static_tables
self.use_table = use_table

self.rope_freq_base = rope_freq_base if rope_freq_base is not None else 10000.0
self.tensor_parallelism_size = tensor_parallelism_size
@@ -44,10 +48,16 @@ def __init__(

@property
def rotary_embed_table(self):
if self.static_rotary_embed_table is None:
if self.use_table:
if self.static_tables:
return self.static_rotary_embed_table
return self._create_rotary_embed_table()
else:
return self.static_rotary_embed_table

if self.tensor_parallelism_size == 1:
return None

nt = namedtuple("replicated_tensor", ["shards"])
return nt([None] * self.tensor_parallelism_size)

def forward(
self,
@@ -96,7 +106,7 @@ def forward_unsharded(
xq: torch.Tensor,
xk: torch.Tensor,
start_index: int,
rotary_embed_table: torch.Tensor,
rotary_embed_table: Optional[torch.Tensor],
):
# xq_, xk_ shape: bs, sl, _, dim
# freqs_cis shape: max_sl, dim
@@ -142,12 +152,18 @@ def create_ordering_tensor(dim):
xq = xq[..., create_interleaved_tensor(xq.shape[-1])]
xk = xk[..., create_interleaved_tensor(xq.shape[-1])]

xq_ = torch.view_as_complex(xq.reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.reshape(*xk.shape[:-1], -1, 2))
xq_ = torch.view_as_complex(xq.unflatten(-1, (-1, 2)))
xk_ = torch.view_as_complex(xk.unflatten(-1, (-1, 2)))
_, sl, _, dim = xq_.shape

# Offset the table based on starting position.
freqs_cis = rotary_embed_table[start_index : start_index + sl, :]
if self.use_table:
freqs_cis = rotary_embed_table[start_index : start_index + sl, :]
else:
freqs_cis = torch.arange(start_index, start_index + sl, device=xq.device)
freqs_cis = self._compute_rotary_embed_table(freqs_cis)
freqs_cis = self._replicate(freqs_cis)

assert freqs_cis.shape[-1] == dim
assert (
freqs_cis.shape[0] >= sl
@@ -206,7 +222,13 @@ def compute_batch_mask(
) + start_positions.unsqueeze(1)
# Broadcast lookup to [b, ...].
self.trace_tensor("rope.positions_seq", positions_seq)
freqs_cis = self.rotary_embed_table[positions_seq]

if self.use_table:
freqs_cis = self.rotary_embed_table[positions_seq]
else:
shape = positions_seq.shape
freqs_cis = self._compute_rotary_embed_table(positions_seq.flatten())
freqs_cis = freqs_cis.unflatten(0, shape)

# Unsqueeze a unit dim for attention heads.
broadcast_freqs_cis = freqs_cis.unsqueeze(2)
@@ -225,10 +247,6 @@ def apply_batched_mask(
and xq.shard_count == xk.shard_count
and xk.shard_dim == xq.shard_dim
)
assert (
isinstance(self.rotary_embed_table, ReplicatedTensor)
and xq.shard_count == self.rotary_embed_table.shard_count
)
assert (
isinstance(mask, ReplicatedTensor)
and mask.shard_count == xq.shard_count
@@ -263,24 +281,20 @@ def apply_batched_mask_unsharded(
"""
# xq_, xk_ shape: bs, sl, _, dim
# freqs_cis shape: max_sl, dim
xq_ = torch.view_as_complex(xq.reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.reshape(*xk.shape[:-1], -1, 2))
xq_ = torch.view_as_complex(xq.unflatten(-1, (-1, 2)))
xk_ = torch.view_as_complex(xk.unflatten(-1, (-1, 2)))
_, sl, _, dim = xq_.shape

xq_out = torch.view_as_real(xq_ * mask).flatten(3)
xk_out = torch.view_as_real(xk_ * mask).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)

def _create_rotary_embed_table(
self,
):
def _compute_rotary_embed_table(self, t):
dim = self.rope_dimension_count
max_seqlen = self.max_seqlen
freqs = 1.0 / (
self.rope_freq_base
** (torch.arange(0, dim, 2, device=self.device)[: (dim // 2)].float() / dim)
** (torch.arange(0, dim, 2, device=t.device)[: (dim // 2)].float() / dim)
)
t = torch.arange(max_seqlen, device=freqs.device)
freqs = torch.outer(t, freqs).float()

freqs_cis = (
Expand All @@ -289,8 +303,16 @@ def _create_rotary_embed_table(
else torch.polar(torch.ones_like(freqs), freqs)
)

return freqs_cis

def _create_rotary_embed_table(self):
t = torch.arange(self.max_seqlen, device=self.device)
freqs_cis = self._compute_rotary_embed_table(t)
return self._replicate(freqs_cis)

def _replicate(self, t):
if self.tensor_parallelism_size > 1:
# Replicate across all devices, the data is not a lot and the computation is cheap.
freqs_cis = ops.replicate(freqs_cis, self.tensor_parallelism_size)
t = ops.replicate(t, self.tensor_parallelism_size)

return freqs_cis
return t
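To summarize the refactor in this file (a minimal sketch assuming the non-complex cache branch and the default `use_hf=False`; not code from the commit): the frequency table math is now factored into `_compute_rotary_embed_table`, which takes an arbitrary position tensor, so the table can be built once for positions `[0, max_seqlen)` or computed on the fly per call:

```python
import torch

# Minimal sketch of the table math factored out into _compute_rotary_embed_table.
def compute_rotary_table(positions: torch.Tensor, dim: int, base: float = 10000.0):
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2, device=positions.device).float() / dim))
    angles = torch.outer(positions.float(), freqs)        # [len(positions), dim // 2]
    return torch.polar(torch.ones_like(angles), angles)   # complex values e^{i * angle}

static_table = compute_rotary_table(torch.arange(0, 4096), dim=128)   # static_tables path
window = compute_rotary_table(torch.arange(10, 10 + 7), dim=128)      # on-the-fly path
```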
2 changes: 1 addition & 1 deletion sharktank/sharktank/models/llama/sharding.py
@@ -4,7 +4,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""Specifications describing how blocks/layers of llama are sharded."""
"""Specifications describing how the Llama model is sharded."""

from ...types.sharding import *
from ...types import Theta