Merge branch 'main' into shortfin_numa
stellaraccident authored Sep 25, 2024
2 parents a814f4c + a9d3d41 commit d5da6e0
Showing 29 changed files with 2,444 additions and 163 deletions.
18 changes: 18 additions & 0 deletions libshortfin/.readthedocs.yaml
@@ -0,0 +1,18 @@
version: "2"

build:
os: "ubuntu-24.04"
tools:
python: "3.12"
apt_packages:
- clang
jobs:
pre_build:
- CC=clang CXX=clang++ python -m pip install -v libshortfin/

python:
install:
- requirements: libshortfin/docs/requirements.txt

sphinx:
configuration: libshortfin/docs/conf.py
12 changes: 10 additions & 2 deletions libshortfin/CMakeLists.txt
@@ -43,8 +43,16 @@ option(SHORTFIN_ENABLE_TRACING "Enable runtime tracing for iree and shortfin" OF
set(SHORTFIN_IREE_SOURCE_DIR "" CACHE FILEPATH "Path to IREE source")

# Options for building static or dynamic libraries.
option(SHORTFIN_BUILD_STATIC "Builds static libraries" OFF)
option(SHORTFIN_BUILD_DYNAMIC "Builds dynamic libraries" ON)
# Default to dynamic linking, unless on Windows.
# TODO(#211): Unify the defaults once Windows dynamic linking issues are fixed.
set(SHORTFIN_BUILD_STATIC_DEFAULT OFF)
set(SHORTFIN_BUILD_DYNAMIC_DEFAULT ON)
if(WIN32)
set(SHORTFIN_BUILD_STATIC_DEFAULT ON)
set(SHORTFIN_BUILD_DYNAMIC_DEFAULT OFF)
endif()
option(SHORTFIN_BUILD_STATIC "Builds static libraries" ${SHORTFIN_BUILD_STATIC_DEFAULT})
option(SHORTFIN_BUILD_DYNAMIC "Builds dynamic libraries" ${SHORTFIN_BUILD_DYNAMIC_DEFAULT})
cmake_dependent_option(SHORTFIN_LINK_DYNAMIC "Links internal binaries against static libshortfin.a" ON "SHORTFIN_BUILD_DYNAMIC" OFF)
if(NOT SHORTFIN_BUILD_STATIC AND NOT SHORTFIN_BUILD_DYNAMIC)
message(FATAL_ERROR "One of SHORTFIN_BUILD_STATIC or SHORTFIN_BUILD_DYNAMIC must be ON")
12 changes: 11 additions & 1 deletion sharktank/sharktank/examples/paged_llm_v1.py
@@ -21,6 +21,7 @@
# TODO: Should be using a base class with the protocol supported.
from ..models.mixtral.mixtral import *
from ..models.llama.llama import *
from ..models.llama.sharding import shard_theta
from ..utils.debugging import trace_tensor
from ..utils.tokenizer import InferenceTokenizer

@@ -40,9 +41,9 @@ def __init__(
self.tokenizer = tokenizer
if model.cache.is_paged:
self.shared_cache_state = model.cache.paged.allocate(page_cache_size)
self.free_pages = list(range(1, page_cache_size))
else:
self.shared_cache_state = None
self.free_pages = list(range(1, 128))
self.end_token = end_token

@property
@@ -231,6 +232,12 @@ def main():
action="store_true",
default=False,
)
parser.add_argument(
"--tensor-parallelism-size",
type=int,
default=1,
help="How many devices are involved in tensor parallel sharding.",
)
cli.add_input_dataset_options(parser)
cli.add_tokenizer_options(parser)
args = cli.parse(parser)
@@ -252,7 +259,10 @@ def main():
activation_dtype=activation_dtype,
attention_dtype=attention_dtype,
use_hf=args.use_hf,
tensor_parallelism_size=args.tensor_parallelism_size,
)
if config.tensor_parallelism_size > 1:
dataset.root_theta = shard_theta(dataset.root_theta, config)

if config.hp.expert_count:
model = PagedMixtralModelV1(dataset.root_theta, config)
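
For intuition about the new `--tensor-parallelism-size` flag and the `shard_theta` call above: tensor-parallel sharding splits each eligible weight along one dimension into one piece per device, which is why dimensions such as the attention head count must be divisible by the parallelism size. The sketch below is an illustration only, not sharktank's `shard_theta`; the names and shapes are assumptions.

# Toy illustration (not sharktank's shard_theta): split each weight along a
# chosen dim into tp_size pieces, one per device.
import torch

def shard_weights(weights: dict, dim: int, tp_size: int) -> dict:
    sharded = {}
    for name, w in weights.items():
        if w.shape[dim] % tp_size != 0:
            raise ValueError(f"{name}: dim {dim} ({w.shape[dim]}) not divisible by {tp_size}")
        # A real implementation would place each chunk on its own device and
        # wrap the pieces in a sharded tensor type.
        sharded[name] = list(torch.chunk(w, tp_size, dim=dim))
    return sharded

weights = {"attn.wq": torch.randn(4096, 4096), "attn.wk": torch.randn(1024, 4096)}
shards = shard_weights(weights, dim=0, tp_size=2)
print([s.shape for s in shards["attn.wq"]])  # two [2048, 4096] shards
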
3 changes: 2 additions & 1 deletion sharktank/sharktank/layers/ffn_block.py
@@ -8,6 +8,7 @@

import torch
import torch.nn.functional as F
from .. import ops

from .base import Theta, ThetaLayer
from .linear import LinearLayer
@@ -32,7 +33,7 @@ def forward(
self,
h: torch.Tensor,
):
ffn_gate = F.silu(self.ffn_gate(h))
ffn_gate = ops.elementwise(F.silu, self.ffn_gate(h))
ffn_up = self.ffn_up(h)
ffn_down = self.ffn_down(ffn_gate * ffn_up)
return ffn_down
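
The change from `F.silu(...)` to `ops.elementwise(F.silu, ...)` routes the activation through sharktank's dispatching op layer so the same layer code can run on plain tensors or on sharded tensor types. Below is a rough sketch of that dispatch idea under simplified assumptions; `ToyShardedTensor` and this `elementwise` helper are illustrative, not the actual `ops` implementation.

# Rough sketch of type-based elementwise dispatch; NOT the sharktank ops module.
from dataclasses import dataclass
from typing import Callable, List

import torch
import torch.nn.functional as F

@dataclass
class ToyShardedTensor:
    shards: List[torch.Tensor]  # one piece per device in a real system
    shard_dim: int

def elementwise(fn: Callable[[torch.Tensor], torch.Tensor], x):
    """Apply fn directly to a plain tensor, or shard-by-shard to a sharded one."""
    if isinstance(x, ToyShardedTensor):
        return ToyShardedTensor([fn(s) for s in x.shards], x.shard_dim)
    return fn(x)

plain = torch.randn(2, 8)
sharded = ToyShardedTensor(list(torch.chunk(plain, 2, dim=1)), shard_dim=1)
out_plain = elementwise(F.silu, plain)
out_sharded = elementwise(F.silu, sharded)
assert torch.allclose(out_plain, torch.cat(out_sharded.shards, dim=1))
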
124 changes: 98 additions & 26 deletions sharktank/sharktank/layers/kv_cache.py
@@ -11,14 +11,16 @@
and dims floating around everywhere.
"""

from typing import Optional
from typing import Optional, Union, List

import abc
import math

import torch

from ..utils.debugging import trace_tensor
from ..types import SplitPrimitiveTensor, ReplicatedTensor
from .. import ops

__all__ = [
"BaseKVCache",
@@ -138,6 +140,11 @@ class PagedKVCache(BaseKVCache):
Note that the internal page structure matches the organization of the
model, allowing contiguous individual local reads and writes at a sub-block
granularity if indexing deeply into the structure.
When `shard_count > 1`, the `attn_head_count` dimension is split across shards.
The page slab is a 1D sharded split tensor that is reinterpreted as this 6D
structure, working around the lack of a block-cyclic sharded tensor type.
"""

def __init__(
@@ -150,30 +157,70 @@ def __init__(
block_seq_stride: int = 16,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
shard_count: int = 1,
):
self.transformer_block_count = transformer_block_count
self.attn_head_count = attn_head_count
self.attn_head_dim = attn_head_dim
self.cache_partition_count = cache_partition_count
self.block_seq_stride = block_seq_stride
self.shard_count = shard_count
if attn_head_count % shard_count != 0:
raise ValueError(
f"The attention head count {attn_head_count} must be a multiple of the tensor parallelism size {shard_count}."
)

# Some derived values based on attributes.
self.sub_page_dims = [
self.transformer_block_count,
self.cache_partition_count,
self.block_seq_stride,
self.attn_head_count,
self.attn_head_count // self.shard_count,
self.attn_head_dim,
]
self.page_slab_flat_dim = math.prod(self.sub_page_dims)
self.device = device
self.dtype = dtype

def unflatten_page_table(self, state: list[torch.Tensor]) -> torch.Tensor:
def unflatten_page_table(
self, state: list[Union[torch.Tensor, SplitPrimitiveTensor]]
) -> Union[torch.Tensor, SplitPrimitiveTensor]:
"""Unflattens the 2D page table to a 6D tensor."""
assert len(state) == 1, f"Expected 1-element state. Got: {len(state)}"
page_slab = state[0]
return page_slab.reshape(
if self.shard_count == 1:
assert not isinstance(page_slab, SplitPrimitiveTensor)
return page_slab.reshape(
[
-1,
]
+ self.sub_page_dims
)
else:
assert self.shard_count == page_slab.shard_count
shards = [
shard.reshape(
[
-1,
]
+ self.sub_page_dims
)
for shard in page_slab.shards
]
return SplitPrimitiveTensor(ts=shards, shard_dim=4)

def shard_state(
self, state: List[torch.Tensor]
) -> List[Union[torch.Tensor, SplitPrimitiveTensor]]:
"""Shard an unsharded state.
We can't just split the slab on the sub page dims.
First it needs to be reinterpreted into the actual shape.
Then split the head dimension and flatten each shard.
This is a work-around for the lack of block-cyclic sharded tensor type."""
if self.shard_count == 1:
return state

page_table = state[0].reshape(
[
-1,
self.transformer_block_count,
@@ -183,30 +230,51 @@ def unflatten_page_table(self, state: list[torch.Tensor]) -> torch.Tensor:
self.attn_head_dim,
]
)
sharded_page_table = ops.reshard_split(
page_table, dim=4, count=self.shard_count
)
shards = [
ops.flatten(shard, start_dim=1) for shard in sharded_page_table.shards
]
flat_sharded_page_table = SplitPrimitiveTensor(ts=shards, shard_dim=1)
return [flat_sharded_page_table]

@property
def pad_sequence_stride(self) -> int:
return self.block_seq_stride

def allocate(self, page_count: int) -> list[torch.Tensor]:
def allocate(
self, page_count: int
) -> list[Union[torch.Tensor, SplitPrimitiveTensor]]:
"""Allocates tensor state for a page table for the given capacity in
pages.
"""
return [
torch.empty(
[page_count, self.page_slab_flat_dim],
dtype=self.dtype,
device=self.device,
)
]
if self.shard_count == 1:
return [
torch.empty(
[page_count, self.page_slab_flat_dim],
dtype=self.dtype,
device=self.device,
)
]
else:
shards = [
torch.empty(
[page_count, self.page_slab_flat_dim],
dtype=self.dtype,
device=self.device,
)
for _ in range(self.shard_count)
]
return [SplitPrimitiveTensor(ts=shards, shard_dim=1)]

def read(
self,
state: list[torch.Tensor],
state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
*,
read_into_partitions: list[torch.Tensor],
read_into_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
transformer_block_index: int,
page_ids: torch.Tensor,
page_ids: Union[torch.Tensor, ReplicatedTensor],
):
"""Reads cache partitions from the page table for the given page_ids.
Expand All @@ -231,7 +299,7 @@ def read(
bs,
block_seq_len,
self.block_seq_stride,
self.attn_head_count,
self.attn_head_count // self.shard_count,
self.attn_head_dim,
]

@@ -249,7 +317,9 @@ def read(
transformer_block_index * transformer_block_stride
)

def read_cache_partition(index: int, into_partition: torch.Tensor):
def read_cache_partition(
index: int, into_partition: Union[torch.Tensor, SplitPrimitiveTensor]
):
subblock_ids = (
(base_subblock_ids + index) if index > 0 else base_subblock_ids
)
@@ -262,7 +332,7 @@ def read_cache_partition(index: int, into_partition: torch.Tensor):
# a linear list.
# TODO: Can be rewritten into inplace with out= on index_select.
selected = (
torch.index_select(subblock_table, 0, subblock_ids.flatten(0, 1))
ops.index_select(subblock_table, 0, subblock_ids.flatten(0, 1))
.unflatten(0, blocked_shape[0:2])
.flatten(1, 2)
)
@@ -274,15 +344,15 @@ def read_cache_partition(index: int, into_partition: torch.Tensor):

def write_timestep(
self,
state: list[torch.Tensor],
state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
# List of [bs, 1, attn_head_count, attn_head_dim]
cache_partitions: list[torch.Tensor],
cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
*,
transformer_block_index: int,
# [bs]
seq_positions: torch.Tensor,
seq_positions: Union[torch.Tensor, ReplicatedTensor],
# [bs, max_seqlen // block_pos_stride]
page_ids: torch.Tensor,
page_ids: Union[torch.Tensor, ReplicatedTensor],
):
"""Writes a single batched timestep across all cache partitions.
@@ -311,11 +381,11 @@ def write_timestep(

def write(
self,
state: list[torch.Tensor],
cache_partitions: list[torch.Tensor],
state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
*,
transformer_block_index: int,
page_ids: torch.Tensor,
page_ids: Union[torch.Tensor, ReplicatedTensor],
):
"""Writes cache partitions from a linear layout to the page table.
@@ -348,7 +418,9 @@ def write(
transformer_block_index * transformer_block_stride
)

def write_cache_partition(index: int, part: torch.Tensor):
def write_cache_partition(
index: int, part: Union[torch.Tensor, SplitPrimitiveTensor]
):
part_block_view = part.reshape(blocked_shape)
subblock_ids = (
(base_subblock_ids + index) if index > 0 else base_subblock_ids
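
To make the page-slab bookkeeping above concrete, here is a standalone sketch in plain torch of how the flat 2D page table is viewed as a 6D structure, split on the attention-head dimension, and flattened back into one slab per shard, mirroring what `shard_state` and `unflatten_page_table` do. The sizes are illustrative assumptions, and plain lists stand in for `SplitPrimitiveTensor`.

# Standalone sketch of the page-slab reshaping in PagedKVCache (plain torch,
# illustrative sizes; the real code wraps the shards in SplitPrimitiveTensor).
import math
import torch

page_count = 4
transformer_block_count = 2
cache_partition_count = 2   # K and V
block_seq_stride = 16
attn_head_count = 8
attn_head_dim = 64
shard_count = 2

sub_page_dims = [
    transformer_block_count,
    cache_partition_count,
    block_seq_stride,
    attn_head_count,  # unsharded view; each shard holds attn_head_count // shard_count
    attn_head_dim,
]

# Unsharded flat slab: [page_count, prod(sub_page_dims)].
slab = torch.empty(page_count, math.prod(sub_page_dims))

# "shard_state": view as 6D, split the head dim (dim=4), flatten each shard back to 2D.
table_6d = slab.reshape([-1] + sub_page_dims)
shards = [s.flatten(start_dim=1) for s in torch.chunk(table_6d, shard_count, dim=4)]

# "unflatten_page_table" per shard: back to 6D with the reduced head count.
shard_dims = sub_page_dims.copy()
shard_dims[3] = attn_head_count // shard_count
shards_6d = [s.reshape([-1] + shard_dims) for s in shards]
print(shards_6d[0].shape)  # torch.Size([4, 2, 2, 16, 4, 64])
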
4 changes: 2 additions & 2 deletions sharktank/sharktank/layers/norm.py
@@ -34,9 +34,9 @@ def __init__(

def forward(self, x: torch.Tensor):
orig_dtype = x.dtype
x = x.to(self.dtype)
x = ops.to(x, self.dtype)
norm = ops.rms_norm(x, self.weight, epsilon=self.epsilon)
# Will automatically upcast to the dtype of the weight, which is
# often in higher precision. Downcast back to expected.
norm = norm.to(orig_dtype)
norm = ops.to(norm, orig_dtype)
return norm
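
For reference, the cast-up/cast-down pattern that `forward` wraps around `ops.rms_norm` follows the usual RMSNorm recipe: normalize in a wider dtype, scale by the weight, then cast back to the activation dtype. The sketch below is a plain-torch approximation; the exact epsilon placement and weight handling inside sharktank's kernel are assumptions here.

# Plain-torch RMSNorm sketch showing the dtype round-trip; epsilon placement
# and weight handling are assumptions, not sharktank's rms_norm kernel.
import torch

def rms_norm_sketch(x: torch.Tensor, weight: torch.Tensor, epsilon: float = 1e-6) -> torch.Tensor:
    orig_dtype = x.dtype
    x32 = x.to(torch.float32)  # upcast for a numerically stable reduction
    variance = x32.pow(2).mean(dim=-1, keepdim=True)
    normed = x32 * torch.rsqrt(variance + epsilon)
    out = normed * weight.to(torch.float32)  # weight is often stored in higher precision
    return out.to(orig_dtype)  # downcast back to the expected activation dtype

x = torch.randn(2, 4096, dtype=torch.float16)
w = torch.ones(4096, dtype=torch.float32)
print(rms_norm_sketch(x, w).dtype)  # torch.float16
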
8 changes: 5 additions & 3 deletions sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -16,6 +16,7 @@
from .norm import RMSNormLayer
from .rotary_embedding import RotaryEmbeddingLayer
from .kv_cache import PagedKVCache
from .. import ops

__all__ = [
"PagedLlamaAttentionBlock",
@@ -140,7 +141,7 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
values = xv.transpose(1, 2)

# Flash attention.
attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
attn_weights = ops.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
self.assert_not_nan(attn_weights)

# Apply attention mask.
@@ -149,8 +150,9 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
# self.trace_tensor("attn_mask", attention_mask)
attn_weights = attn_weights + attention_mask

attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(xq)
attn_output = torch.matmul(attn_weights, values) # (bs, heads, slen, head_dim)
attn_weights = ops.softmax(ops.to(attn_weights, dtype=torch.float32), dim=-1)
attn_weights = ops.to(attn_weights, dtype=xq.dtype)
attn_output = ops.matmul(attn_weights, values) # (bs, heads, slen, head_dim)
attn_output = attn_output.transpose(1, 2).reshape(bs, batch_seq_len, -1)

# Project.
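
The attention math that these lines route through `ops.matmul` / `ops.softmax` / `ops.to` is standard scaled dot-product attention with the softmax computed in float32 and cast back to the query dtype. A self-contained plain-torch sketch follows; the shapes are illustrative and the paged KV cache and grouped-query handling are omitted.

# Plain-torch sketch of the attention computation above: scale by sqrt(head_dim),
# add the mask, softmax in float32, cast back, then reshape for the projection.
# Illustrative shapes only; no paged KV cache or GQA handling here.
import math
import torch

def attention_sketch(xq, keys, values, attention_mask=None):
    # xq, keys, values: [bs, heads, seq, head_dim]
    head_dim = xq.shape[-1]
    attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = torch.softmax(attn_weights.to(torch.float32), dim=-1).to(xq.dtype)
    attn_output = torch.matmul(attn_weights, values)   # [bs, heads, seq, head_dim]
    return attn_output.transpose(1, 2).flatten(2)      # [bs, seq, heads * head_dim]

bs, heads, seq, head_dim = 2, 8, 16, 64
q = torch.randn(bs, heads, seq, head_dim)
print(attention_sketch(q, q, q).shape)  # torch.Size([2, 16, 512])
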