Add tensor parallelism to the paged llama model
This adds one test that checks the sharded against the unsharded variant.
Some numerics in the cache writing and decode step are wrong.

This change adds many sharded variants of PyTorch-API-equivalent ops,
but they lack automated testing.
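
As a minimal illustration of the property the sharded-vs-unsharded test exercises (this is not the test added by this commit; names and sizes below are made up): splitting a projection along its head dimension, computing each shard independently, and concatenating should reproduce the unsharded result.

import torch

# Toy sharded-vs-unsharded parity check (assumed sizes; not the actual test):
# split a projection weight along the head dimension, compute per shard,
# concatenate, and compare against the unsharded result.
torch.manual_seed(0)
heads, head_dim, model_dim, shard_count = 8, 16, 64, 2
x = torch.randn(2, 10, model_dim)             # [bs, seq, model_dim]
w = torch.randn(heads * head_dim, model_dim)  # unsharded projection weight

unsharded = x @ w.T                           # [bs, seq, heads * head_dim]
shards = w.chunk(shard_count, dim=0)          # split the head dimension
sharded = torch.cat([x @ s.T for s in shards], dim=-1)

torch.testing.assert_close(unsharded, sharded)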
sogartar committed Sep 23, 2024
1 parent c15b751 commit 0f82101
Showing 23 changed files with 1,936 additions and 151 deletions.
12 changes: 11 additions & 1 deletion sharktank/sharktank/examples/paged_llm_v1.py
@@ -19,6 +19,7 @@
# TODO: Should be using a base class with the protocol supported.
from ..models.mixtral.mixtral import *
from ..models.llama.llama import *
from ..models.llama.sharding import shard_theta
from ..utils.debugging import trace_tensor
from ..utils.tokenizer import InferenceTokenizer, load_tokenizer

@@ -38,9 +39,9 @@ def __init__(
self.tokenizer = tokenizer
if model.cache.is_paged:
self.shared_cache_state = model.cache.paged.allocate(page_cache_size)
self.free_pages = list(range(1, page_cache_size))
else:
self.shared_cache_state = None
self.free_pages = list(range(1, 128))
self.end_token = end_token

@property
@@ -218,6 +219,12 @@ def main():
help="DType to use for activations in the model",
default="float32",
)
parser.add_argument(
"--tensor-parallelism-size",
type=int,
default=1,
help="How many devices are involved for tensor parallel sharding.",
)
cli.add_input_dataset_options(parser)
cli.add_tokenizer_options(parser)
args = cli.parse(parser)
@@ -236,7 +243,10 @@ def main():
device=device,
activation_dtype=activation_dtype,
attention_dtype=activation_dtype,
tensor_parallelism_size=args.tensor_parallelism_size,
)
if config.tensor_parallelism_size > 1:
dataset.root_theta = shard_theta(dataset.root_theta, config)

if config.hp.expert_count:
model = PagedMixtralModelV1(dataset.root_theta, config)
3 changes: 2 additions & 1 deletion sharktank/sharktank/layers/ffn_block.py
@@ -8,6 +8,7 @@

import torch
import torch.nn.functional as F
from .. import ops

from .base import Theta, ThetaLayer
from .linear import LinearLayer
@@ -32,7 +33,7 @@ def forward(
self,
h: torch.Tensor,
):
ffn_gate = F.silu(self.ffn_gate(h))
ffn_gate = ops.elementwise(F.silu, self.ffn_gate(h))
ffn_up = self.ffn_up(h)
ffn_down = self.ffn_down(ffn_gate * ffn_up)
return ffn_down
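
The switch from calling F.silu directly to ops.elementwise(F.silu, ...) presumably routes the activation through the op dispatcher so sharded tensor types can be handled shard-by-shard. A rough sketch of that idea (not sharktank's actual ops.elementwise; a plain list of tensors stands in for a sharded tensor's shards):

import torch
import torch.nn.functional as F

# Sketch of an elementwise dispatch (assumed semantics, not the real ops.elementwise):
# apply the function per shard when given a sharded value, else apply it directly.
def elementwise(fn, x):
    if isinstance(x, list):            # stand-in for a sharded tensor's shards
        return [fn(shard) for shard in x]
    return fn(x)

dense = torch.randn(4, 8)
sharded = list(dense.chunk(2, dim=1))  # "shards" split along the feature dim
out = elementwise(F.silu, sharded)
torch.testing.assert_close(torch.cat(out, dim=1), F.silu(dense))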
118 changes: 84 additions & 34 deletions sharktank/sharktank/layers/kv_cache.py
@@ -11,14 +11,16 @@
and dims floating around everywhere.
"""

from typing import Optional
from typing import Optional, Union

import abc
import math

import torch

from ..utils.debugging import trace_tensor
from ..types import SplitPrimitiveTensor, ReplicatedTensor
from .. import ops

__all__ = [
"BaseKVCache",
@@ -138,6 +140,11 @@ class PagedKVCache(BaseKVCache):
Note that the internal page structure matches the organization of the
model, allowing contiguous individual local reads and writes at a sub-block
granularity if indexing deeply into the structure.
When `shard_count > 1`, the cache is split along the `attn_head_count`
dimension. The page slab is then a flat, sharded split tensor that is
reinterpreted as a 6D tensor, working around the lack of a block-cyclic
sharded tensor type.
"""

def __init__(
@@ -150,63 +157,102 @@ def __init__(
block_seq_stride: int = 16,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
shard_count: int = 1,
):
self.transformer_block_count = transformer_block_count
self.attn_head_count = attn_head_count
self.attn_head_dim = attn_head_dim
self.cache_partition_count = cache_partition_count
self.block_seq_stride = block_seq_stride
self.shard_count = shard_count
if attn_head_count % shard_count != 0:
raise ValueError(
f"The attention head count {attn_head_count} must be a multiple of the tensor parallelism size {shard_count}."
)

# Some derived values based on attributes.
self.sub_page_dims = [
self.transformer_block_count,
self.cache_partition_count,
self.block_seq_stride,
self.attn_head_count,
self.attn_head_count // self.shard_count,
self.attn_head_dim,
]
self.page_slab_flat_dim = math.prod(self.sub_page_dims)
self.device = device
self.dtype = dtype
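
For a concrete sense of the derived sizes (the hyperparameter values below are assumptions, not a real model config): the per-shard flat page dimension shrinks by the shard count, since only attn_head_count // shard_count heads live on each device.

import math

# Assumed example sizes; mirrors the sub_page_dims computation above.
transformer_block_count, cache_partition_count = 32, 2
block_seq_stride, attn_head_count, attn_head_dim = 16, 32, 128

for shard_count in (1, 4):
    sub_page_dims = [
        transformer_block_count,
        cache_partition_count,
        block_seq_stride,
        attn_head_count // shard_count,
        attn_head_dim,
    ]
    print(shard_count, math.prod(sub_page_dims))  # 1 -> 4194304, 4 -> 1048576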

def unflatten_page_table(self, state: list[torch.Tensor]) -> torch.Tensor:
def unflatten_page_table(
self, state: list[Union[torch.Tensor, SplitPrimitiveTensor]]
) -> Union[torch.Tensor, SplitPrimitiveTensor]:
"""Unflattens the 2D page table to a 6D tensor."""
assert len(state) == 1, f"Expected 1-element state. Got: {len(state)}"
page_slab = state[0]
return page_slab.reshape(
[
-1,
self.transformer_block_count,
self.cache_partition_count,
self.block_seq_stride,
self.attn_head_count,
self.attn_head_dim,
if self.shard_count == 1:
assert not isinstance(page_slab, SplitPrimitiveTensor)
return page_slab.reshape(
[
-1,
self.transformer_block_count,
self.cache_partition_count,
self.block_seq_stride,
self.attn_head_count,
self.attn_head_dim,
]
)
else:
assert self.shard_count == page_slab.shard_count
shards = [
shard.reshape(
[
-1,
self.transformer_block_count,
self.cache_partition_count,
self.block_seq_stride,
self.attn_head_count // self.shard_count,
self.attn_head_dim,
]
)
for shard in page_slab.shards
]
)
return SplitPrimitiveTensor(ts=shards, shard_dim=4)

@property
def pad_sequence_stride(self) -> int:
return self.block_seq_stride

def allocate(self, page_count: int) -> list[torch.Tensor]:
def allocate(
self, page_count: int
) -> list[Union[torch.Tensor, SplitPrimitiveTensor]]:
"""Allocates tensor state for a page table for the given capacity in
pages.
"""
return [
torch.empty(
[page_count, self.page_slab_flat_dim],
dtype=self.dtype,
device=self.device,
)
]
if self.shard_count == 1:
return [
torch.empty(
[page_count, self.page_slab_flat_dim],
dtype=self.dtype,
device=self.device,
)
]
else:
shards = [
torch.empty(
[page_count, self.page_slab_flat_dim],
dtype=self.dtype,
device=self.device,
)
for _ in range(self.shard_count)
]
return [SplitPrimitiveTensor(ts=shards, shard_dim=1)]
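
A toy shape check of the allocate/unflatten round trip in the sharded case (plain tensors stand in for SplitPrimitiveTensor shards; sizes are assumptions): each shard's flat [page_count, flat_dim] slab reshapes to the 6D blocked view, with attn_head_count // shard_count at the head dimension (the dimension the split tensor shards over).

import math
import torch

# Assumed toy sizes; plain tensors stand in for SplitPrimitiveTensor shards.
page_count, shard_count = 3, 2
T, P, S, H, D = 4, 2, 16, 8, 32          # blocks, partitions, stride, heads, head_dim
sub_page_dims = [T, P, S, H // shard_count, D]
flat_dim = math.prod(sub_page_dims)

shards = [torch.empty(page_count, flat_dim) for _ in range(shard_count)]  # "allocate"
views = [s.reshape([-1] + sub_page_dims) for s in shards]                 # "unflatten"
assert all(v.shape == (page_count, T, P, S, H // shard_count, D) for v in views)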

def read(
self,
state: list[torch.Tensor],
state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
*,
read_into_partitions: list[torch.Tensor],
read_into_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
transformer_block_index: int,
page_ids: torch.Tensor,
page_ids: Union[torch.Tensor, ReplicatedTensor],
):
"""Reads cache partitions from the page table for the given page_ids.
@@ -231,7 +277,7 @@ def read(
bs,
block_seq_len,
self.block_seq_stride,
self.attn_head_count,
self.attn_head_count // self.shard_count,
self.attn_head_dim,
]

@@ -249,7 +295,9 @@ def read(
transformer_block_index * transformer_block_stride
)

def read_cache_partition(index: int, into_partition: torch.Tensor):
def read_cache_partition(
index: int, into_partition: Union[torch.Tensor, SplitPrimitiveTensor]
):
subblock_ids = (
(base_subblock_ids + index) if index > 0 else base_subblock_ids
)
@@ -262,7 +310,7 @@ def read_cache_partition(index: int, into_partition: torch.Tensor):
# a linear list.
# TODO: Can be rewritten into inplace with out= on index_select.
selected = (
torch.index_select(subblock_table, 0, subblock_ids.flatten(0, 1))
ops.index_select(subblock_table, 0, subblock_ids.flatten(0, 1))
.unflatten(0, blocked_shape[0:2])
.flatten(1, 2)
)
@@ -274,15 +322,15 @@ def read_cache_partition(index: int, into_partition: torch.Tensor):

def write_timestep(
self,
state: list[torch.Tensor],
state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
# List of [bs, 1, attn_head_count, attn_head_dim]
cache_partitions: list[torch.Tensor],
cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
*,
transformer_block_index: int,
# [bs]
seq_positions: torch.Tensor,
seq_positions: Union[torch.Tensor, ReplicatedTensor],
# [bs, max_seqlen // block_pos_stride]
page_ids: torch.Tensor,
page_ids: Union[torch.Tensor, ReplicatedTensor],
):
"""Writes a single batched timestep across all cache partitions.
@@ -311,11 +359,11 @@ def write_timestep(

def write(
self,
state: list[torch.Tensor],
cache_partitions: list[torch.Tensor],
state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
*,
transformer_block_index: int,
page_ids: torch.Tensor,
page_ids: Union[torch.Tensor, ReplicatedTensor],
):
"""Writes cache partitions from a linear layout to the page table.
@@ -348,7 +396,9 @@ def write(
transformer_block_index * transformer_block_stride
)

def write_cache_partition(index: int, part: torch.Tensor):
def write_cache_partition(
index: int, part: Union[torch.Tensor, SplitPrimitiveTensor]
):
part_block_view = part.reshape(blocked_shape)
subblock_ids = (
(base_subblock_ids + index) if index > 0 else base_subblock_ids
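
For context on the gather pattern used by read_cache_partition above, here is a small standalone example of the same index_select / unflatten / flatten sequence with toy shapes (shapes and values are assumptions):

import torch

# Toy gather: select sub-blocks by flattened ids, restore the [bs, block_seq_len]
# batch structure, then merge block and stride dims into one sequence dim.
bs, block_seq_len, stride, feat = 2, 3, 4, 5
table = torch.arange(8 * stride * feat, dtype=torch.float32).reshape(8, stride, feat)
ids = torch.tensor([[0, 2, 4], [1, 3, 5]])            # [bs, block_seq_len]
selected = (
    torch.index_select(table, 0, ids.flatten(0, 1))   # [bs*block_seq_len, stride, feat]
    .unflatten(0, (bs, block_seq_len))                 # [bs, block_seq_len, stride, feat]
    .flatten(1, 2)                                     # [bs, block_seq_len*stride, feat]
)
assert selected.shape == (bs, block_seq_len * stride, feat)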
4 changes: 2 additions & 2 deletions sharktank/sharktank/layers/norm.py
@@ -33,9 +33,9 @@ def __init__(

def forward(self, x: torch.Tensor):
orig_dtype = x.dtype
x = x.to(self.dtype)
x = ops.to(x, self.dtype)
norm = ops.rms_norm(x, self.weight, epsilon=self.epsilon)
# Will automatically upcast to the dtype of the weight, which is
# often in higher precision. Downcast back to expected.
norm = norm.to(orig_dtype)
norm = ops.to(norm, orig_dtype)
return norm
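
For reference, RMS normalization computes x * rsqrt(mean(x^2) + eps) scaled by the weight; a plain-torch sketch of what ops.rms_norm is expected to compute here (the exact reduction dims and dtype handling in sharktank may differ):

import torch

# Plain-torch RMS norm reference (assumed semantics; sharktank's ops.rms_norm
# may differ in reduction dims and dtype handling).
def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, epsilon: float) -> torch.Tensor:
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + epsilon) * weight

x = torch.randn(2, 5, 16, dtype=torch.float32)
w = torch.ones(16)
y = rms_norm_ref(x, w, epsilon=1e-6)
assert y.shape == x.shape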
8 changes: 5 additions & 3 deletions sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -16,6 +16,7 @@
from .norm import RMSNormLayer
from .rotary_embedding import RotaryEmbeddingLayer
from .kv_cache import PagedKVCache
from .. import ops

__all__ = [
"PagedLlamaAttentionBlock",
@@ -140,7 +141,7 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
values = xv.transpose(1, 2)

# Flash attention.
attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
attn_weights = ops.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
self.assert_not_nan(attn_weights)

# Apply attention mask.
@@ -149,8 +150,9 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
# self.trace_tensor("attn_mask", attention_mask)
attn_weights = attn_weights + attention_mask

attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(xq)
attn_output = torch.matmul(attn_weights, values) # (bs, heads, slen, head_dim)
attn_weights = ops.softmax(ops.to(attn_weights, dtype=torch.float32), dim=-1)
attn_weights = ops.to(attn_weights, dtype=xq.dtype)
attn_output = ops.matmul(attn_weights, values) # (bs, heads, slen, head_dim)
attn_output = attn_output.transpose(1, 2).reshape(bs, batch_seq_len, -1)

# Project.
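
The block above is standard scaled-dot-product attention, rewritten to go through ops.matmul / ops.softmax, presumably so sharded tensor types dispatch to their sharded implementations. A compact plain-torch reference of the same math (shapes assumed):

import math
import torch

# Plain-torch reference for the attention math above (assumed shapes):
# attn = softmax(Q K^T / sqrt(head_dim) + mask) V
bs, heads, seq, head_dim = 2, 4, 8, 16
xq = torch.randn(bs, heads, seq, head_dim)
keys = torch.randn(bs, heads, seq, head_dim)
values = torch.randn(bs, heads, seq, head_dim)
mask = torch.zeros(bs, 1, seq, seq)        # additive mask; zeros = no masking

attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)
attn_weights = torch.softmax((attn_weights + mask).float(), dim=-1).to(xq.dtype)
attn_output = torch.matmul(attn_weights, values)   # [bs, heads, seq, head_dim]
assert attn_output.shape == (bs, heads, seq, head_dim)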