Support moving theta and models to a specific device.
* Threads an explicit device through the models.
* Implements functional InferenceTensor, Theta and Dataset transformations and uses them to implement `to(device=)` (a standalone sketch of this style follows the file summary below).
* Adds `--device foo` to the example runner.
* With iree-org/iree-turbine#3 and supporting patches, this allows custom ops and kernels to be used transparently on CUDA/ROCM devices (instead of just CPU).
stellaraccident committed Apr 25, 2024
1 parent 714dedf commit 473798c
Showing 10 changed files with 331 additions and 41 deletions.
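
The functional transformation machinery named in the commit message (InferenceTensor, Theta, Dataset and `to(device=)`) lives in files whose diffs are not expanded below. As a rough standalone illustration of that style, and not the sharktank API itself, a tree transform over tensors might look like this (all names here are hypothetical):

```python
from typing import Any, Callable, Optional

import torch


def transform_tensors(tree: Any, fn: Callable[[torch.Tensor], torch.Tensor]) -> Any:
    """Functionally applies `fn` to every tensor in a nested structure,
    returning a new structure and leaving the original untouched."""
    if isinstance(tree, torch.Tensor):
        return fn(tree)
    if isinstance(tree, dict):
        return {k: transform_tensors(v, fn) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(transform_tensors(v, fn) for v in tree)
    return tree


def to_device(tree: Any, device: Optional[torch.device]) -> Any:
    """`to(device=)` expressed as one such transformation."""
    return transform_tensors(tree, lambda t: t.to(device=device))


# A toy "theta" of named parameters, moved to CUDA when available.
theta = {"attn": {"wq": torch.randn(4, 4), "wk": torch.randn(4, 4)}}
target = torch.device("cuda:0") if torch.cuda.is_available() else None
theta_on_device = to_device(theta, target)
```
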
20 changes: 14 additions & 6 deletions sharktank/sharktank/examples/paged_llm_v1.py
@@ -6,6 +6,8 @@

"""Inference support for the PagedLLMV1 protocol of models."""

from typing import Optional

import math
import sys

@@ -50,8 +52,8 @@ def begin_batch(self, prompts: list[str]):
token_ids, seq_lens = self.tokenizer.encode(
prompts, pad_to_multiple_of=self.model.cache.pad_sequence_stride
)
token_ids = torch.tensor(token_ids)
seq_lens = torch.tensor(seq_lens)
token_ids = torch.tensor(token_ids, device=self.model.device)
seq_lens = torch.tensor(seq_lens, device=self.model.device)
if self.shared_cache_state is not None:
cache_state = self.shared_cache_state
else:
@@ -153,14 +155,16 @@ def prefill(self):
seq_block_ids=seq_block_ids_tensor,
cache_state=self.cache_state,
)

# TODO: Generalize the sampling and don't make it swap on/off cpu.
# TODO: Normalize the output of extract_tokens_from_logits into
# tensor [bs, 1].
tokens = torch.tensor(
model.extract_tokens_from_logits(logits, self.seq_lens)
).unsqueeze(1)
print(f":: Prefill results:\n{tokens.tolist()}")
self.add_result_token(tokens)
self.next_tokens = tokens
self.next_tokens = tokens.to(device=model.device)

def decode(self):
model = self.parent.model
@@ -191,15 +195,16 @@ def decode(self):
# TODO: Normalize the output of extract_tokens_from_logits into
# tensor [bs, 1].
tokens = torch.tensor(
model.extract_tokens_from_logits(logits, [1] * self.bs)
model.extract_tokens_from_logits(logits, [1] * self.bs),
device=self.parent.model.device,
).unsqueeze(1)
self.add_result_token(tokens)
self.next_tokens = tokens

def pad_block_ids(self) -> torch.Tensor:
max_length = max(len(r) for r in self.seq_block_ids)
rows = [r + (max_length - len(r)) * [0] for r in self.seq_block_ids]
return torch.tensor(rows)
return torch.tensor(rows, device=self.parent.model.device)


def main():
@@ -208,19 +213,22 @@ def main():
parser = cli.create_parser()
parser.add_argument("prompt", nargs="+", help="Prompt strings")
parser.add_argument("--kv-cache-type", default="paged", help="KV cache type")
parser.add_argument("--device", help="Torch device (or default)")
cli.add_gguf_dataset_options(parser)
cli.add_tokenizer_options(parser)
args = cli.parse(parser)

device = torch.device(args.device) if args.device else None
data_files = cli.get_gguf_data_files(args)
tokenizer = cli.get_tokenizer(args, data_files=data_files)
dataset = Dataset.load(data_files["gguf"])
dataset = Dataset.load(data_files["gguf"], device=device)
prompts = args.prompt

config = LlamaModelConfig(
hp=configs.LlamaHParams.from_gguf_props(dataset.properties),
block_seq_stride=16,
kv_cache_type=args.kv_cache_type,
device=device,
)
model = PagedLlamaModelV1(dataset.root_theta, config)
generator = TorchGenerator(model, tokenizer)
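
The new `--device` flag is resolved into a `torch.device` exactly as in the hunk above and then threaded through the dataset, config, and model. A minimal self-contained check of that resolution step, using only PyTorch (the flag value here is illustrative):

```python
import argparse

import torch

parser = argparse.ArgumentParser()
parser.add_argument("--device", help="Torch device (or default)")
argv = ["--device", "cuda:0"] if torch.cuda.is_available() else []
args = parser.parse_args(argv)

# Mirrors the runner: fall back to None (PyTorch's default placement) when unset.
device = torch.device(args.device) if args.device else None

# Downstream tensors such as token_ids and seq_lens are then created directly
# on this device instead of being moved after the fact.
token_ids = torch.zeros(2, 16, dtype=torch.int64, device=device)
print(token_ids.device)
```
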
27 changes: 21 additions & 6 deletions sharktank/sharktank/layers/causal_llm.py
@@ -24,9 +24,15 @@ class BaseCausalLMModel(ThetaLayer):
"""

def __init__(
self, theta: Theta, *, context_length: int, static_context_mask: bool = True
self,
theta: Theta,
*,
context_length: int,
static_context_mask: bool = True,
device: Optional[torch.device] = None,
):
super().__init__(theta)
self.device = device
self.context_length = context_length

if static_context_mask:
@@ -36,6 +42,13 @@ def __init__(self):
else:
self.causal_context_mask = None

def _assert_device(self, *ts: torch.Tensor):
if self.device is not None:
for t in ts:
assert (
t.device == self.device
), f"Expected tensor to be on device {self.device} but it is on {t.device}"

def _maximally_negative_value(self, dtype):
"""Returns a maximally negative value for the given dtype.
@@ -46,7 +59,9 @@ def _maximally_negative_value(self, dtype):
def generate_causal_context_mask(self) -> torch.Tensor:
context_length = self.context_length
causal_context_mask = torch.triu(
torch.ones([context_length, context_length], dtype=torch.bool),
torch.ones(
[context_length, context_length], dtype=torch.bool, device=self.device
),
diagonal=1,
)[None, None, :, :]
return causal_context_mask
@@ -62,7 +77,7 @@ def input_mask(
The mask will be [bs, batch_seqlen] with True at any position that is
masked.
"""
range_vector = torch.arange(0, batch_seqlen, 1)
range_vector = torch.arange(0, batch_seqlen, 1, device=self.device)
matrix = torch.unsqueeze(seq_lens, dim=-1)
mask = range_vector >= matrix
return mask
@@ -74,14 +89,14 @@ def decode_attention_mask(
numeric_mask.masked_fill_(
boolean_input_mask, self._maximally_negative_value(dtype)
)
return numeric_mask.unsqueeze(1).unsqueeze(1)
return numeric_mask.unsqueeze(1).unsqueeze(1).to(self.device)

def attention_mask(
self,
input_mask: torch.Tensor,
*,
dtype: torch.dtype,
causal_context_mask: Optional[torch.Tensor] = None
causal_context_mask: Optional[torch.Tensor] = None,
):
"""Generates a causal attention mask of [1, 1, sl, sl] of activation dtype.
@@ -103,7 +118,7 @@
boolean_mask = causal_mask + input_mask[:, None, None, :]
numeric_mask = torch.zeros_like(boolean_mask, dtype=dtype)
numeric_mask.masked_fill_(boolean_mask, self._maximally_negative_value(dtype))
return numeric_mask
return numeric_mask.to(self.device)

def extract_tokens_from_logits(
self, logits: torch.Tensor, seq_lens: list[int]
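
Taken together, these hunks build both the causal mask and the padding mask on `self.device` and keep the numeric mask there. A condensed standalone version of that construction, with `torch.finfo(dtype).min` standing in for `_maximally_negative_value` (an assumption; the layer's exact value is not shown here):

```python
import torch

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
dtype = torch.float32
context_length = batch_seqlen = 8

# Causal mask: True strictly above the diagonal, shaped [1, 1, sl, sl].
causal_mask = torch.triu(
    torch.ones([context_length, context_length], dtype=torch.bool, device=device),
    diagonal=1,
)[None, None, :, :]

# Padding mask: True at any position beyond each sequence's length, [bs, sl].
seq_lens = torch.tensor([5, 8], device=device)
positions = torch.arange(0, batch_seqlen, 1, device=device)
input_mask = positions >= seq_lens.unsqueeze(-1)

# Combine and fill masked positions with a maximally negative value.
boolean_mask = causal_mask + input_mask[:, None, None, :]
numeric_mask = torch.zeros_like(boolean_mask, dtype=dtype)
numeric_mask.masked_fill_(boolean_mask, torch.finfo(dtype).min)
print(numeric_mask.shape, numeric_mask.device)  # [2, 1, 8, 8] on the chosen device
```
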
18 changes: 15 additions & 3 deletions sharktank/sharktank/layers/kv_cache.py
@@ -11,6 +11,8 @@
and dims floating around everywhere.
"""

from typing import Optional

import abc
import math

@@ -88,12 +90,14 @@ def __init__(
attn_head_count: int,
attn_head_dim: int,
seq_length: int,
device: Optional[torch.device] = None,
):
self.block_seq_stride = block_seq_stride
self.transformer_block_count = transformer_block_count
self.attn_head_count = attn_head_count
self.attn_head_dim = attn_head_dim
self.seq_length = seq_length
self.device = device

@property
def pad_sequence_stride(self) -> int:
@@ -109,6 +113,7 @@ def allocate(self, *, bs: int, dtype: torch.dtype) -> list[torch.Tensor]:
torch.empty(
[bs, self.seq_length, self.attn_head_count, self.attn_head_dim],
dtype=dtype,
device=self.device,
)
for _ in range(2 * self.transformer_block_count)
]
@@ -141,6 +146,7 @@ def __init__(
attn_head_dim: int,
cache_partition_count: int = 2,
block_seq_stride: int = 16,
device: Optional[torch.device] = None,
):
self.transformer_block_count = transformer_block_count
self.attn_head_count = attn_head_count
@@ -157,6 +163,7 @@ def __init__(
self.attn_head_dim,
]
self.page_slab_flat_dim = math.prod(self.sub_page_dims)
self.device = device

def unflatten_page_table(self, state: list[torch.Tensor]) -> torch.Tensor:
"""Unflattens the 2D page table to a 6D tensor."""
@@ -181,7 +188,11 @@ def allocate(self, page_count: int, dtype: torch.dtype) -> list[torch.Tensor]:
"""Allocates tensor state for a page table for the given capacity in
pages.
"""
return [torch.empty([page_count, self.page_slab_flat_dim], dtype=dtype)]
return [
torch.empty(
[page_count, self.page_slab_flat_dim], dtype=dtype, device=self.device
)
]

def read(
self,
@@ -272,6 +283,7 @@ def write_timestep(
Note that this internally loops over the batch size, which cannot be
dynamic.
"""
device = self.device
page_table = self.unflatten_page_table(state) # 6D
bs, *_ = seq_positions.shape
assert len(cache_partitions) == self.cache_partition_count
@@ -285,8 +297,8 @@
cache_partition = cache_partitions[partition_index]
indices = (
page_id,
torch.tensor([transformer_block_index]),
torch.tensor([partition_index]),
torch.tensor([transformer_block_index], device=device),
torch.tensor([partition_index], device=device),
page_offset.unsqueeze(0),
)
page_table.index_put_(indices=indices, values=cache_partition[i, 0])
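
Since the caches now remember their device, allocation lands directly on it instead of requiring a copy afterwards. A hedged usage sketch for the paged cache (sizes are placeholders and the import path is an assumption; the keyword names follow the hunks above):

```python
import torch

from sharktank.layers.kv_cache import PagedKVCache  # import path is an assumption

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

cache = PagedKVCache(
    transformer_block_count=26,  # placeholder depth
    attn_head_count=8,           # placeholder KV head count
    attn_head_dim=128,           # placeholder head dimension
    cache_partition_count=2,     # one partition each for K and V
    block_seq_stride=16,
    device=device,
)

# A single flat [page_count, page_slab_flat_dim] slab, created on `device`.
state = cache.allocate(page_count=128, dtype=torch.float16)
print(state[0].shape, state[0].device)
```
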
25 changes: 19 additions & 6 deletions sharktank/sharktank/layers/rotary_embedding.py
@@ -4,6 +4,8 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Optional

import torch

from .base import BaseLayer
@@ -12,10 +14,18 @@
class RotaryEmbeddingLayer(BaseLayer):
"""Computes a rotary embedding in the style popularized by llama (RoPE)."""

def __init__(self, *, rope_dimension_count: int, max_seqlen: int):
def __init__(
self,
*,
rope_dimension_count: int,
max_seqlen: int,
device: Optional[torch.device] = None,
):
super().__init__()
self.device = device
self._table = self._create_rotary_embed_table(
max_seqlen=max_seqlen, dim=rope_dimension_count
max_seqlen=max_seqlen,
dim=rope_dimension_count,
)

def forward(self, *, xq: torch.Tensor, xk: torch.Tensor, start_index: int):
@@ -50,7 +60,7 @@ def compute_batch_mask(
Tensor of [bs, sl, 1, d] that will be later passed to apply_batch_mask.
"""
self.trace_tensor("rope.start_positions", start_positions)
positions_seq = torch.arange(0, batch_seq_len).unsqueeze(
positions_seq = torch.arange(0, batch_seq_len, device=self.device).unsqueeze(
0
) + start_positions.unsqueeze(1)
# Broadcast lookup to [b, ...].
@@ -81,12 +91,15 @@ def apply_batched_mask(
xk_out = torch.view_as_real(xk_ * mask).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)

@staticmethod
def _create_rotary_embed_table(
max_seqlen: int, dim: int, theta_value: float = 10000.0
self,
max_seqlen: int,
dim: int,
theta_value: float = 10000.0,
):
freqs = 1.0 / (
theta_value ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
theta_value
** (torch.arange(0, dim, 2, device=self.device)[: (dim // 2)].float() / dim)
)
t = torch.arange(max_seqlen, device=freqs.device)
freqs = torch.outer(t, freqs).float()
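
The table math itself is unchanged; only where the tensors are created differs. The same computation as a standalone snippet, stopping where the hunk above is truncated (sizes are placeholders):

```python
import torch

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
max_seqlen, dim, theta_value = 4096, 128, 10000.0

# Inverse frequencies for each pair of rotary dimensions, created on `device`.
freqs = 1.0 / (
    theta_value
    ** (torch.arange(0, dim, 2, device=device)[: (dim // 2)].float() / dim)
)

# Outer product of positions and frequencies; inherits the device from `freqs`.
t = torch.arange(max_seqlen, device=freqs.device)
freqs = torch.outer(t, freqs).float()
print(freqs.shape, freqs.device)  # torch.Size([4096, 64]) on the chosen device
```
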
20 changes: 19 additions & 1 deletion sharktank/sharktank/models/llama/llama.py
@@ -37,6 +37,9 @@ class LlamaModelConfig:
# Either "paged" or "direct".
kv_cache_type: str = "paged"

# The device on which to place intermediate state.
device: Optional[torch.device] = None

def create_kv_cache(self) -> BaseKVCache:
hp = self.hp
if self.kv_cache_type == "direct":
@@ -46,6 +49,7 @@ def create_kv_cache(self) -> BaseKVCache:
attn_head_count=hp.attention_head_count_kv,
attn_head_dim=hp.attn_head_dim,
seq_length=hp.context_length,
device=self.device,
)
elif self.kv_cache_type == "paged":
return PagedKVCache(
@@ -54,6 +58,7 @@ def create_kv_cache(self) -> BaseKVCache:
attn_head_dim=hp.attn_head_dim,
cache_partition_count=2, # One for each of K/V.
block_seq_stride=self.block_seq_stride,
device=self.device,
)
else:
raise NotImplementedError(f"kv_cache_type = {self.kv_cache_type}")
@@ -88,7 +93,9 @@ class PagedLlamaModelV1(BaseCausalLMModel):

def __init__(self, theta: Theta, config: LlamaModelConfig):
hp = config.hp
super().__init__(theta, context_length=config.hp.context_length)
super().__init__(
theta, context_length=config.hp.context_length, device=config.device
)
self.config = config
self.hp = hp
self.cache = config.create_kv_cache()
@@ -101,6 +108,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
RotaryEmbeddingLayer(
rope_dimension_count=hp.rope_dimension_count,
max_seqlen=hp.context_length,
device=self.device,
),
)
self.add_module(
@@ -137,6 +145,10 @@ def prefill(
seq_block_ids: torch.Tensor,
cache_state: list[torch.Tensor],
):
self._assert_device(tokens)
self._assert_device(attention_mask)
self._assert_device(seq_block_ids)
self._assert_device(*cache_state)
h = self.token_embedding(tokens)
self.trace_tensor("llama.token_embedding", h)

@@ -171,6 +183,10 @@ def decode(
seq_block_ids: torch.Tensor,
cache_state: list[torch.Tensor],
):
self._assert_device(tokens)
self._assert_device(attention_mask)
self._assert_device(start_positions)
self._assert_device(*cache_state)
bs, _ = tokens.shape
# Precompute a position based mask for computing rope embeddings
# as it is the same for all blocks.
@@ -189,6 +205,7 @@
self.hp.attn_head_dim,
],
dtype=self.hp.activation_dtype,
device=self.device,
)
xv_temp = torch.empty(
[
@@ -198,6 +215,7 @@
self.hp.attn_head_dim,
],
dtype=self.hp.activation_dtype,
device=self.device,
)

h = self.token_embedding(tokens)
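
End to end, the device placed on `LlamaModelConfig` is what the KV cache, rotary layer, and decode-time scratch buffers inherit. A hedged wiring sketch mirroring `paged_llm_v1.py` above (import paths, the GGUF path, and the dtype are assumptions; the constructor keywords follow the hunks):

```python
import torch

from sharktank.types import Dataset  # import path is an assumption
from sharktank.models.llama.llama import (  # import path is an assumption
    LlamaModelConfig,
    PagedLlamaModelV1,
)
from sharktank.models.llama import configs  # import path is an assumption

device = torch.device("cuda:0") if torch.cuda.is_available() else None

# Weights are loaded straight onto the device, and the config carries it onward.
dataset = Dataset.load("/path/to/model.gguf", device=device)  # placeholder path
config = LlamaModelConfig(
    hp=configs.LlamaHParams.from_gguf_props(dataset.properties),
    block_seq_stride=16,
    kv_cache_type="paged",
    device=device,
)
model = PagedLlamaModelV1(dataset.root_theta, config)

# The cache built from this config allocates its page slab on the same device,
# and prefill/decode assert that every input tensor already lives there.
cache_state = model.cache.allocate(page_count=128, dtype=torch.float16)
```
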
Diffs for the remaining 5 changed files are not expanded here.
