Add tensor parallelism to the paged llama model
This adds one test that checks the sharded variant against the unsharded one.

Make `sharktank.examples.paged_llm_v1` support a tensor parallelism
CLI option.

This change adds many sharded variants of PyTorch API-equivalent ops, though
some of them still lack automated tests.
`index_copy_`, `index_put_`, slicing, `flatten`, `unflatten`, and `reshape` have tests.

Check that replicating or splitting an unsharded tensor does not make an
actual copy. It is probably unintuitive that, when run through PyTorch,
the sharded result shares the same memory as the source.
It may be better to change the semantics and require an actual copy.
During export this would insert copies that the compiler
would need to optimize out.

Add test for sharded paged KV cache.
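
The no-copy behavior described above can be illustrated with a minimal sketch in
plain PyTorch (not the sharktank sharding ops themselves, which build on the same
view semantics): splitting a tensor eagerly yields views over the original
storage, so writes through a shard are visible in the unsharded tensor.

import torch

t = torch.arange(12).reshape(3, 4)
shards = t.split(2, dim=1)  # two (3, 2) views, not copies

# The first shard starts at the same address as the original tensor.
assert shards[0].data_ptr() == t.data_ptr()

# Mutating a shard is visible through the original tensor, confirming no copy.
shards[0][0, 0] = 100
assert t[0, 0].item() == 100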
sogartar committed Sep 25, 2024
1 parent 61eacac commit cfa705a
Showing 27 changed files with 2,408 additions and 151 deletions.
12 changes: 11 additions & 1 deletion sharktank/sharktank/examples/paged_llm_v1.py
@@ -21,6 +21,7 @@
# TODO: Should be using a base class with the protocol supported.
from ..models.mixtral.mixtral import *
from ..models.llama.llama import *
from ..models.llama.sharding import shard_theta
from ..utils.debugging import trace_tensor
from ..utils.tokenizer import InferenceTokenizer

@@ -40,9 +41,9 @@ def __init__(
self.tokenizer = tokenizer
if model.cache.is_paged:
self.shared_cache_state = model.cache.paged.allocate(page_cache_size)
self.free_pages = list(range(1, page_cache_size))
else:
self.shared_cache_state = None
self.free_pages = list(range(1, 128))
self.end_token = end_token

@property
@@ -231,6 +232,12 @@ def main():
action="store_true",
default=False,
)
parser.add_argument(
"--tensor-parallelism-size",
type=int,
default=1,
help="How many devices are involved for tensor parallel sharding.",
)
cli.add_input_dataset_options(parser)
cli.add_tokenizer_options(parser)
args = cli.parse(parser)
@@ -252,7 +259,10 @@ def main():
activation_dtype=activation_dtype,
attention_dtype=attention_dtype,
use_hf=args.use_hf,
tensor_parallelism_size=args.tensor_parallelism_size,
)
if config.tensor_parallelism_size > 1:
dataset.root_theta = shard_theta(dataset.root_theta, config)

if config.hp.expert_count:
model = PagedMixtralModelV1(dataset.root_theta, config)
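As a standalone illustration of the new flag added above (a hypothetical sketch
using argparse directly; the real script builds its parser through
sharktank.utils.cli and passes the value into the model config):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--tensor-parallelism-size",
    type=int,
    default=1,
    help="How many devices are involved for tensor parallel sharding.",
)
args = parser.parse_args(["--tensor-parallelism-size", "2"])

# Mirrors the condition in main(): the theta is only sharded when more than
# one device participates.
needs_sharding = args.tensor_parallelism_size > 1
assert needs_sharding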
3 changes: 2 additions & 1 deletion sharktank/sharktank/layers/ffn_block.py
@@ -8,6 +8,7 @@

import torch
import torch.nn.functional as F
from .. import ops

from .base import Theta, ThetaLayer
from .linear import LinearLayer
@@ -32,7 +33,7 @@ def forward(
self,
h: torch.Tensor,
):
ffn_gate = F.silu(self.ffn_gate(h))
ffn_gate = ops.elementwise(F.silu, self.ffn_gate(h))
ffn_up = self.ffn_up(h)
ffn_down = self.ffn_down(ffn_gate * ffn_up)
return ffn_down
124 changes: 98 additions & 26 deletions sharktank/sharktank/layers/kv_cache.py
@@ -11,14 +11,16 @@
and dims floating around everywhere.
"""

from typing import Optional
from typing import Optional, Union, List

import abc
import math

import torch

from ..utils.debugging import trace_tensor
from ..types import SplitPrimitiveTensor, ReplicatedTensor
from .. import ops

__all__ = [
"BaseKVCache",
@@ -138,6 +140,11 @@ class PagedKVCache(BaseKVCache):
Note that the internal page structure matches the organization of the
model, allowing contiguous individual local reads and writes at a sub-block
granularity if indexing deeply into the structure.
When `shard_count > 1`, the `attn_head_count` dimension is split across shards.
The page slab is a 2D sharded split tensor.
It is reinterpreted as a 6D tensor, working around the lack of a
block-cyclic sharded tensor type.
"""

def __init__(
@@ -150,30 +157,70 @@ def __init__(
block_seq_stride: int = 16,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
shard_count: int = 1,
):
self.transformer_block_count = transformer_block_count
self.attn_head_count = attn_head_count
self.attn_head_dim = attn_head_dim
self.cache_partition_count = cache_partition_count
self.block_seq_stride = block_seq_stride
self.shard_count = shard_count
if attn_head_count % shard_count != 0:
raise ValueError(
f"The attention head count {attn_head_count} must be a multiple of the tensor parallelism size {shard_count}."
)

# Some derived values based on attributes.
self.sub_page_dims = [
self.transformer_block_count,
self.cache_partition_count,
self.block_seq_stride,
self.attn_head_count,
self.attn_head_count // self.shard_count,
self.attn_head_dim,
]
self.page_slab_flat_dim = math.prod(self.sub_page_dims)
self.device = device
self.dtype = dtype

def unflatten_page_table(self, state: list[torch.Tensor]) -> torch.Tensor:
def unflatten_page_table(
self, state: list[Union[torch.Tensor, SplitPrimitiveTensor]]
) -> Union[torch.Tensor, SplitPrimitiveTensor]:
"""Unflattens the 2D page table to a 6D tensor."""
assert len(state) == 1, f"Expected 1-element state. Got: {len(state)}"
page_slab = state[0]
return page_slab.reshape(
if self.shard_count == 1:
assert not isinstance(page_slab, SplitPrimitiveTensor)
return page_slab.reshape(
[
-1,
]
+ self.sub_page_dims
)
else:
assert self.shard_count == page_slab.shard_count
shards = [
shard.reshape(
[
-1,
]
+ self.sub_page_dims
)
for shard in page_slab.shards
]
return SplitPrimitiveTensor(ts=shards, shard_dim=4)

def shard_state(
self, state: List[torch.Tensor]
) -> List[Union[torch.Tensor, SplitPrimitiveTensor]]:
"""Shard an unsharded state.
We can't just split the slab on the sub page dims.
First it needs to be reinterpreted into the actual shape.
Then split the head dimension and flatten each shard.
This is a work-around for the lack of block-cyclic sharded tensor type."""
if self.shard_count == 1:
return state

page_table = state[0].reshape(
[
-1,
self.transformer_block_count,
@@ -183,30 +230,51 @@ def unflatten_page_table(self, state: list[torch.Tensor]) -> torch.Tensor:
self.attn_head_dim,
]
)
sharded_page_table = ops.reshard_split(
page_table, dim=4, count=self.shard_count
)
shards = [
ops.flatten(shard, start_dim=1) for shard in sharded_page_table.shards
]
flat_sharded_page_table = SplitPrimitiveTensor(ts=shards, shard_dim=1)
return [flat_sharded_page_table]

@property
def pad_sequence_stride(self) -> int:
return self.block_seq_stride

def allocate(self, page_count: int) -> list[torch.Tensor]:
def allocate(
self, page_count: int
) -> list[Union[torch.Tensor, SplitPrimitiveTensor]]:
"""Allocates tensor state for a page table for the given capacity in
pages.
"""
return [
torch.empty(
[page_count, self.page_slab_flat_dim],
dtype=self.dtype,
device=self.device,
)
]
if self.shard_count == 1:
return [
torch.empty(
[page_count, self.page_slab_flat_dim],
dtype=self.dtype,
device=self.device,
)
]
else:
shards = [
torch.empty(
[page_count, self.page_slab_flat_dim],
dtype=self.dtype,
device=self.device,
)
for _ in range(self.shard_count)
]
return [SplitPrimitiveTensor(ts=shards, shard_dim=1)]

def read(
self,
state: list[torch.Tensor],
state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
*,
read_into_partitions: list[torch.Tensor],
read_into_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
transformer_block_index: int,
page_ids: torch.Tensor,
page_ids: Union[torch.Tensor, ReplicatedTensor],
):
"""Reads cache partitions from the page table for the given page_ids.
@@ -231,7 +299,7 @@ def read(
bs,
block_seq_len,
self.block_seq_stride,
self.attn_head_count,
self.attn_head_count // self.shard_count,
self.attn_head_dim,
]

@@ -249,7 +317,9 @@ def read(
transformer_block_index * transformer_block_stride
)

def read_cache_partition(index: int, into_partition: torch.Tensor):
def read_cache_partition(
index: int, into_partition: Union[torch.Tensor, SplitPrimitiveTensor]
):
subblock_ids = (
(base_subblock_ids + index) if index > 0 else base_subblock_ids
)
@@ -262,7 +332,7 @@ def read_cache_partition(index: int, into_partition: torch.Tensor):
# a linear list.
# TODO: Can be rewritten into inplace with out= on index_select.
selected = (
torch.index_select(subblock_table, 0, subblock_ids.flatten(0, 1))
ops.index_select(subblock_table, 0, subblock_ids.flatten(0, 1))
.unflatten(0, blocked_shape[0:2])
.flatten(1, 2)
)
@@ -274,15 +344,15 @@ def read_cache_partition(index: int, into_partition: torch.Tensor):

def write_timestep(
self,
state: list[torch.Tensor],
state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
# List of [bs, 1, attn_head_count, attn_head_dim]
cache_partitions: list[torch.Tensor],
cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
*,
transformer_block_index: int,
# [bs]
seq_positions: torch.Tensor,
seq_positions: Union[torch.Tensor, ReplicatedTensor],
# [bs, max_seqlen // block_pos_stride]
page_ids: torch.Tensor,
page_ids: Union[torch.Tensor, ReplicatedTensor],
):
"""Writes a single batched timestep across all cache partitions.
@@ -311,11 +381,11 @@ def write_timestep(

def write(
self,
state: list[torch.Tensor],
cache_partitions: list[torch.Tensor],
state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
*,
transformer_block_index: int,
page_ids: torch.Tensor,
page_ids: Union[torch.Tensor, ReplicatedTensor],
):
"""Writes cache partitions from a linear layout to the page table.
@@ -348,7 +418,9 @@ def write(
transformer_block_index * transformer_block_stride
)

def write_cache_partition(index: int, part: torch.Tensor):
def write_cache_partition(
index: int, part: Union[torch.Tensor, SplitPrimitiveTensor]
):
part_block_view = part.reshape(blocked_shape)
subblock_ids = (
(base_subblock_ids + index) if index > 0 else base_subblock_ids
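The paged KV cache geometry changed above can be reproduced with a plain-torch
sketch (hypothetical sizes, no sharktank types): the flat 2D page slab is
reinterpreted as a 6D table, split along the attention-head dimension, and each
shard is flattened back to 2D, mirroring what shard_state does to work around
the missing block-cyclic sharded tensor type.

import math
import torch

# Hypothetical cache geometry.
page_count = 4
transformer_block_count = 2
cache_partition_count = 2  # K and V
block_seq_stride = 16
attn_head_count = 8
attn_head_dim = 64
shard_count = 2  # tensor parallelism size

sub_page_dims = [
    transformer_block_count,
    cache_partition_count,
    block_seq_stride,
    attn_head_count,
    attn_head_dim,
]
flat_dim = math.prod(sub_page_dims)

# Unsharded 2D slab: one flattened row per page.
page_slab = torch.empty(page_count, flat_dim)

# Reinterpret as 6D, split the head dimension (dim=4), flatten each shard.
page_table = page_slab.reshape([-1] + sub_page_dims)
shards = [
    s.flatten(start_dim=1)
    for s in page_table.split(attn_head_count // shard_count, dim=4)
]
assert all(s.shape == (page_count, flat_dim // shard_count) for s in shards)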
4 changes: 2 additions & 2 deletions sharktank/sharktank/layers/norm.py
@@ -34,9 +34,9 @@ def __init__(

def forward(self, x: torch.Tensor):
orig_dtype = x.dtype
x = x.to(self.dtype)
x = ops.to(x, self.dtype)
norm = ops.rms_norm(x, self.weight, epsilon=self.epsilon)
# Will automatically upcast to the dtype of the weight, which is
# often in higher precision. Downcast back to expected.
norm = norm.to(orig_dtype)
norm = ops.to(norm, orig_dtype)
return norm
8 changes: 5 additions & 3 deletions sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -16,6 +16,7 @@
from .norm import RMSNormLayer
from .rotary_embedding import RotaryEmbeddingLayer
from .kv_cache import PagedKVCache
from .. import ops

__all__ = [
"PagedLlamaAttentionBlock",
@@ -140,7 +141,7 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
values = xv.transpose(1, 2)

# Flash attention.
attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
attn_weights = ops.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
self.assert_not_nan(attn_weights)

# Apply attention mask.
@@ -149,8 +150,9 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
# self.trace_tensor("attn_mask", attention_mask)
attn_weights = attn_weights + attention_mask

attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(xq)
attn_output = torch.matmul(attn_weights, values) # (bs, heads, slen, head_dim)
attn_weights = ops.softmax(ops.to(attn_weights, dtype=torch.float32), dim=-1)
attn_weights = ops.to(attn_weights, dtype=xq.dtype)
attn_output = ops.matmul(attn_weights, values) # (bs, heads, slen, head_dim)
attn_output = attn_output.transpose(1, 2).reshape(bs, batch_seq_len, -1)

# Project.