[WIP] add tensor parallelism to the paged llama model
sogartar committed Sep 20, 2024
1 parent c15b751 commit 971c5c9
Showing 23 changed files with 1,543 additions and 112 deletions.
12 changes: 11 additions & 1 deletion sharktank/sharktank/examples/paged_llm_v1.py
@@ -19,6 +19,7 @@
# TODO: Should be using a base class with the protocol supported.
from ..models.mixtral.mixtral import *
from ..models.llama.llama import *
from ..models.llama.sharding import shard_theta
from ..utils.debugging import trace_tensor
from ..utils.tokenizer import InferenceTokenizer, load_tokenizer

@@ -38,9 +39,9 @@ def __init__(
self.tokenizer = tokenizer
if model.cache.is_paged:
self.shared_cache_state = model.cache.paged.allocate(page_cache_size)
self.free_pages = list(range(1, page_cache_size))
else:
self.shared_cache_state = None
self.free_pages = list(range(1, 128))
self.end_token = end_token

@property
@@ -218,6 +219,12 @@ def main():
help="DType to use for activations in the model",
default="float32",
)
parser.add_argument(
"--tensor-parallelism-size",
type=int,
default=1,
help="How many devices are involved for tensor parallel sharding.",
)
cli.add_input_dataset_options(parser)
cli.add_tokenizer_options(parser)
args = cli.parse(parser)
@@ -236,7 +243,10 @@ def main():
device=device,
activation_dtype=activation_dtype,
attention_dtype=activation_dtype,
tensor_parallelism_size=args.tensor_parallelism_size,
)
if config.tensor_parallelism_size > 1:
dataset.root_theta = shard_theta(dataset.root_theta, config)

if config.hp.expert_count:
model = PagedMixtralModelV1(dataset.root_theta, config)
112 changes: 79 additions & 33 deletions sharktank/sharktank/layers/kv_cache.py
@@ -11,14 +11,15 @@
and dims floating around everywhere.
"""

from typing import Optional
from typing import Optional, Union

import abc
import math

import torch

from ..utils.debugging import trace_tensor
from ..types import SplitPrimitiveTensor, ReplicatedTensor

__all__ = [
"BaseKVCache",
@@ -138,6 +139,8 @@ class PagedKVCache(BaseKVCache):
Note that the internal page structure matches the organization of the
model, allowing contiguous individual local reads and writes at a sub-block
granularity if indexing deeply into the structure.
    When `shard_count > 1`, the `attn_head_count` dimension is split across shards.
"""

def __init__(
@@ -150,63 +153,102 @@ def __init__(
block_seq_stride: int = 16,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
shard_count: int = 1,
):
self.transformer_block_count = transformer_block_count
self.attn_head_count = attn_head_count
self.attn_head_dim = attn_head_dim
self.cache_partition_count = cache_partition_count
self.block_seq_stride = block_seq_stride
self.shard_count = shard_count
if attn_head_count % shard_count != 0:
raise ValueError(
f"The attention head count {attn_head_count} must be a multiple of the tensor parallelism size {shard_count}."
)

# Some derived values based on attributes.
self.sub_page_dims = [
self.transformer_block_count,
self.cache_partition_count,
self.block_seq_stride,
self.attn_head_count,
self.attn_head_count // self.shard_count,
self.attn_head_dim,
]
self.page_slab_flat_dim = math.prod(self.sub_page_dims)
self.device = device
self.dtype = dtype
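
To make the per-shard sizing concrete, here is a small standalone sketch of the derived dimensions above, using illustrative (made-up) model sizes rather than values from this commit:

```python
import math

# Illustrative sizes only (not taken from this commit).
transformer_block_count = 26
cache_partition_count = 2    # K and V
block_seq_stride = 16
attn_head_count = 32
attn_head_dim = 128
shard_count = 4              # tensor parallelism size

# Each shard holds attn_head_count // shard_count heads, so the flat page
# slab on every device shrinks by the shard count.
sub_page_dims = [
    transformer_block_count,
    cache_partition_count,
    block_seq_stride,
    attn_head_count // shard_count,
    attn_head_dim,
]
page_slab_flat_dim = math.prod(sub_page_dims)
print(page_slab_flat_dim)  # 851968 elements per page, per shard
```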

def unflatten_page_table(self, state: list[torch.Tensor]) -> torch.Tensor:
def unflatten_page_table(
self, state: list[Union[torch.Tensor, SplitPrimitiveTensor]]
) -> Union[torch.Tensor, SplitPrimitiveTensor]:
"""Unflattens the 2D page table to a 6D tensor."""
assert len(state) == 1, f"Expected 1-element state. Got: {len(state)}"
page_slab = state[0]
return page_slab.reshape(
[
-1,
self.transformer_block_count,
self.cache_partition_count,
self.block_seq_stride,
self.attn_head_count,
self.attn_head_dim,
if self.shard_count == 1:
assert not isinstance(page_slab, SplitPrimitiveTensor)
return page_slab.reshape(
[
-1,
self.transformer_block_count,
self.cache_partition_count,
self.block_seq_stride,
self.attn_head_count,
self.attn_head_dim,
]
)
else:
assert self.shard_count == page_slab.shard_count
shards = [
shard.reshape(
[
-1,
self.transformer_block_count,
self.cache_partition_count,
self.block_seq_stride,
self.attn_head_count // self.shard_count,
self.attn_head_dim,
]
)
for shard in page_slab.shards
]
)
return SplitPrimitiveTensor(ts=shards, shard_dim=4)
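
For reference, the unsharded 2-D to 6-D reshape above is equivalent to the following plain-torch sketch (tiny made-up sizes); the sharded branch does the same reshape per shard with `attn_head_count // shard_count` heads:

```python
import torch

# Tiny illustrative sizes: 3 pages, 2 transformer blocks, K and V partitions,
# block stride 4, 2 heads on this shard, head dim 8.
page_count, blocks, partitions, stride, heads, head_dim = 3, 2, 2, 4, 2, 8
page_slab = torch.empty(page_count, blocks * partitions * stride * heads * head_dim)

page_table = page_slab.reshape(-1, blocks, partitions, stride, heads, head_dim)
print(page_table.shape)  # torch.Size([3, 2, 2, 4, 2, 8])
```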

@property
def pad_sequence_stride(self) -> int:
return self.block_seq_stride

def allocate(self, page_count: int) -> list[torch.Tensor]:
def allocate(
self, page_count: int
) -> list[Union[torch.Tensor, SplitPrimitiveTensor]]:
"""Allocates tensor state for a page table for the given capacity in
pages.
"""
return [
torch.empty(
[page_count, self.page_slab_flat_dim],
dtype=self.dtype,
device=self.device,
)
]
if self.shard_count == 1:
return [
torch.empty(
[page_count, self.page_slab_flat_dim],
dtype=self.dtype,
device=self.device,
)
]
else:
shards = [
torch.empty(
[page_count, self.page_slab_flat_dim],
dtype=self.dtype,
device=self.device,
)
for _ in range(self.shard_count)
]
return [SplitPrimitiveTensor(ts=shards, shard_dim=1)]

def read(
self,
state: list[torch.Tensor],
state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
*,
read_into_partitions: list[torch.Tensor],
read_into_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
transformer_block_index: int,
page_ids: torch.Tensor,
page_ids: Union[torch.Tensor, ReplicatedTensor],
):
"""Reads cache partitions from the page table for the given page_ids.
@@ -231,7 +273,7 @@ def read(
bs,
block_seq_len,
self.block_seq_stride,
self.attn_head_count,
self.attn_head_count // self.shard_count,
self.attn_head_dim,
]

@@ -249,7 +291,9 @@ def read(
transformer_block_index * transformer_block_stride
)

def read_cache_partition(index: int, into_partition: torch.Tensor):
def read_cache_partition(
index: int, into_partition: Union[torch.Tensor, SplitPrimitiveTensor]
):
subblock_ids = (
(base_subblock_ids + index) if index > 0 else base_subblock_ids
)
@@ -274,15 +318,15 @@ def read_cache_partition(index: int, into_partition: torch.Tensor):

def write_timestep(
self,
state: list[torch.Tensor],
state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
# List of [bs, 1, attn_head_count, attn_head_dim]
cache_partitions: list[torch.Tensor],
cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
*,
transformer_block_index: int,
# [bs]
seq_positions: torch.Tensor,
seq_positions: Union[torch.Tensor, ReplicatedTensor],
# [bs, max_seqlen // block_pos_stride]
page_ids: torch.Tensor,
page_ids: Union[torch.Tensor, ReplicatedTensor],
):
"""Writes a single batched timestep across all cache partitions.
@@ -311,11 +355,11 @@ def write_timestep(

def write(
self,
state: list[torch.Tensor],
cache_partitions: list[torch.Tensor],
state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
*,
transformer_block_index: int,
page_ids: torch.Tensor,
page_ids: Union[torch.Tensor, ReplicatedTensor],
):
"""Writes cache partitions from a linear layout to the page table.
@@ -348,7 +392,9 @@ def write(
transformer_block_index * transformer_block_stride
)

def write_cache_partition(index: int, part: torch.Tensor):
def write_cache_partition(
index: int, part: Union[torch.Tensor, SplitPrimitiveTensor]
):
part_block_view = part.reshape(blocked_shape)
subblock_ids = (
(base_subblock_ids + index) if index > 0 else base_subblock_ids
4 changes: 2 additions & 2 deletions sharktank/sharktank/layers/norm.py
@@ -33,9 +33,9 @@ def __init__(

def forward(self, x: torch.Tensor):
orig_dtype = x.dtype
x = x.to(self.dtype)
x = ops.to(x, self.dtype)
norm = ops.rms_norm(x, self.weight, epsilon=self.epsilon)
# Will automatically upcast to the dtype of the weight, which is
# often in higher precision. Downcast back to expected.
norm = norm.to(orig_dtype)
norm = ops.to(norm, orig_dtype)
return norm
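
For context, a minimal plain-torch RMS norm with the same cast pattern the comment describes (an illustrative reference only, not sharktank's `ops.rms_norm`, assuming the weight is kept in a higher-precision dtype):

```python
import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, epsilon: float = 1e-6) -> torch.Tensor:
    orig_dtype = x.dtype
    # Compute in the weight's (typically higher-precision) dtype ...
    xf = x.to(weight.dtype)
    normed = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + epsilon) * weight
    # ... then downcast back to the activation dtype.
    return normed.to(orig_dtype)

out = rms_norm_reference(torch.randn(2, 8, dtype=torch.float16), torch.ones(8))
print(out.dtype)  # torch.float16
```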
8 changes: 5 additions & 3 deletions sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -16,6 +16,7 @@
from .norm import RMSNormLayer
from .rotary_embedding import RotaryEmbeddingLayer
from .kv_cache import PagedKVCache
from .. import ops

__all__ = [
"PagedLlamaAttentionBlock",
@@ -140,7 +141,7 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
values = xv.transpose(1, 2)

# Flash attention.
attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
attn_weights = ops.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
self.assert_not_nan(attn_weights)

# Apply attention mask.
@@ -149,8 +150,9 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
# self.trace_tensor("attn_mask", attention_mask)
attn_weights = attn_weights + attention_mask

attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(xq)
attn_output = torch.matmul(attn_weights, values) # (bs, heads, slen, head_dim)
attn_weights = ops.softmax(ops.to(attn_weights, dtype=torch.float32), dim=-1)
attn_weights = ops.to(attn_weights, dtype=xq.dtype)
attn_output = ops.matmul(attn_weights, values) # (bs, heads, slen, head_dim)
attn_output = attn_output.transpose(1, 2).reshape(bs, batch_seq_len, -1)

# Project.
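The numeric pattern above — scaled scores, additive mask, softmax forced to float32, then cast back to the query dtype — looks like this in plain torch (illustrative shapes, no sharded `ops.*` dispatch):

```python
import math
import torch

bs, heads, slen, head_dim = 2, 4, 8, 16
activation_dtype = torch.float32  # in the model this is often a lower-precision dtype
xq = torch.randn(bs, heads, slen, head_dim, dtype=activation_dtype)
keys = torch.randn(bs, heads, slen, head_dim, dtype=activation_dtype)
values = torch.randn(bs, heads, slen, head_dim, dtype=activation_dtype)
attention_mask = torch.zeros(bs, 1, slen, slen, dtype=activation_dtype)

attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)
attn_weights = attn_weights + attention_mask
# Softmax in float32 for numerical stability, then cast back to the query dtype.
attn_weights = torch.softmax(attn_weights.to(torch.float32), dim=-1).to(xq.dtype)
attn_output = torch.matmul(attn_weights, values)  # (bs, heads, slen, head_dim)
print(attn_output.shape)
```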
56 changes: 53 additions & 3 deletions sharktank/sharktank/layers/rotary_embedding.py
@@ -9,6 +9,8 @@
import torch

from .base import BaseLayer
from .. import ops
from ..types import SplitPrimitiveTensor, ReplicatedTensor, unbox_tensor


class RotaryEmbeddingLayer(BaseLayer):
@@ -23,6 +25,7 @@ def __init__(
device: Optional[torch.device] = None,
use_hf: bool = False,
static_tables: bool = True,
tensor_parallelism_size: int = 1,
):
super().__init__()
# Force static_tables until compiler limitations are solved.
@@ -33,9 +36,10 @@ def __init__(
self.max_seqlen = max_seqlen
self.use_hf = use_hf
self.rope_freq_base = rope_freq_base if rope_freq_base is not None else 10000.0
self.tensor_parallelism_size = tensor_parallelism_size
if static_tables:
self.register_buffer(
"static_rotary_embed_table", self._create_rotary_embed_table()
ops.module_register_buffer(
self, "static_rotary_embed_table", self._create_rotary_embed_table()
)
else:
self.static_rotary_embed_table = None
@@ -48,6 +52,48 @@ def rotary_embed_table(self):
return self.static_rotary_embed_table

def forward(self, *, xq: torch.Tensor, xk: torch.Tensor, start_index: int):
if isinstance(xq, SplitPrimitiveTensor):
assert (
isinstance(xk, SplitPrimitiveTensor)
and xq.shard_count == xk.shard_count
and xk.shard_dim == xq.shard_dim
)
assert (
isinstance(self.rotary_embed_table, ReplicatedTensor)
and xq.shard_count == self.rotary_embed_table.shard_count
)
xqk_shards = [
self.forward_unsharded(
xq=unbox_tensor(xq_shard),
xk=unbox_tensor(xk_shard),
start_index=start_index,
rotary_embed_table=unbox_tensor(rotary_embed_table_shard),
)
for xq_shard, xk_shard, rotary_embed_table_shard in zip(
xq.shards, xk.shards, self.rotary_embed_table.shards
)
]
xq_shards = [xqk[0] for xqk in xqk_shards]
xk_shards = [xqk[1] for xqk in xqk_shards]
xq = SplitPrimitiveTensor(ts=xq_shards, shard_dim=xq.shard_dim)
xk = SplitPrimitiveTensor(ts=xk_shards, shard_dim=xk.shard_dim)
return xq, xk
else:
return self.forward_unsharded(
xq=xq,
xk=xk,
start_index=start_index,
rotary_embed_table=self.rotary_embed_table,
)

def forward_unsharded(
self,
*,
xq: torch.Tensor,
xk: torch.Tensor,
start_index: int,
rotary_embed_table: torch.Tensor,
):
# xq_, xk_ shape: bs, sl, _, dim
# freqs_cis shape: max_sl, dim

@@ -97,7 +143,7 @@ def create_ordering_tensor(dim):
_, sl, _, dim = xq_.shape

# Offset the table based on starting position.
freqs_cis = self.rotary_embed_table[start_index : start_index + sl, :]
freqs_cis = rotary_embed_table[start_index : start_index + sl, :]
assert freqs_cis.shape[-1] == dim
assert (
freqs_cis.shape[0] >= sl
@@ -200,4 +246,8 @@ def _create_rotary_embed_table(
else torch.polar(torch.ones_like(freqs), freqs)
)

if self.tensor_parallelism_size > 1:
            # Replicate across all devices; the table is small and cheap to compute.
freqs_cis = ops.replicate(freqs_cis, self.tensor_parallelism_size)

return freqs_cis
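
The body of `_create_rotary_embed_table` is elided above, but the standard recipe for a complex rotary table ending in `torch.polar` looks roughly like the sketch below (an assumption, not the exact sharktank code); the commit's change is only the final `ops.replicate` of that table when `tensor_parallelism_size > 1`:

```python
import torch

# Standard RoPE-style frequency table with illustrative sizes.
max_seqlen, rope_dim, rope_freq_base = 16, 8, 10000.0
inv_freqs = 1.0 / (rope_freq_base ** (torch.arange(0, rope_dim, 2).float() / rope_dim))
positions = torch.arange(max_seqlen).float()
angles = torch.outer(positions, inv_freqs)                # (max_seqlen, rope_dim // 2)
freqs_cis = torch.polar(torch.ones_like(angles), angles)  # complex64 table
print(freqs_cis.shape)  # torch.Size([16, 4])
```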