Support moving theta and models to a specific device. #9

Merged · 2 commits · Apr 25, 2024
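The change threads an optional `torch.device` through the tensor factory calls (`torch.tensor`, `torch.empty`, `torch.arange`, `torch.ones`) so intermediate state is created where the model's parameters live instead of defaulting to CPU and being copied afterwards. A minimal sketch of the pattern in plain torch, independent of sharktank:

```python
import torch

# Pick a device when one is available; None keeps torch's default placement,
# matching the behavior when no device is requested.
device = torch.device("cuda:0") if torch.cuda.is_available() else None

# Factory calls receive the device directly, so no post-hoc .to() copies are needed.
token_ids = torch.tensor([[1, 2, 3, 4]], device=device)
positions = torch.arange(0, token_ids.shape[1], device=device)
print(token_ids.device, positions.device)
```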
20 changes: 14 additions & 6 deletions sharktank/sharktank/examples/paged_llm_v1.py
@@ -6,6 +6,8 @@

"""Inference support for the PagedLLMV1 protocol of models."""

from typing import Optional

import math
import sys

@@ -50,8 +52,8 @@ def begin_batch(self, prompts: list[str]):
token_ids, seq_lens = self.tokenizer.encode(
prompts, pad_to_multiple_of=self.model.cache.pad_sequence_stride
)
token_ids = torch.tensor(token_ids)
seq_lens = torch.tensor(seq_lens)
token_ids = torch.tensor(token_ids, device=self.model.device)
seq_lens = torch.tensor(seq_lens, device=self.model.device)
if self.shared_cache_state is not None:
cache_state = self.shared_cache_state
else:
@@ -153,14 +155,16 @@ def prefill(self):
seq_block_ids=seq_block_ids_tensor,
cache_state=self.cache_state,
)

# TODO: Generalize the sampling and don't make it swap on/off cpu.
# TODO: Normalize the output of extract_tokens_from_logits into
# tensor [bs, 1].
tokens = torch.tensor(
model.extract_tokens_from_logits(logits, self.seq_lens)
).unsqueeze(1)
print(f":: Prefill results:\n{tokens.tolist()}")
self.add_result_token(tokens)
self.next_tokens = tokens
self.next_tokens = tokens.to(device=model.device)

def decode(self):
model = self.parent.model
@@ -191,15 +195,16 @@ def decode(self):
# TODO: Normalize the output of extract_tokens_from_logits into
# tensor [bs, 1].
tokens = torch.tensor(
model.extract_tokens_from_logits(logits, [1] * self.bs)
model.extract_tokens_from_logits(logits, [1] * self.bs),
device=self.parent.model.device,
).unsqueeze(1)
self.add_result_token(tokens)
self.next_tokens = tokens

def pad_block_ids(self) -> torch.Tensor:
max_length = max(len(r) for r in self.seq_block_ids)
rows = [r + (max_length - len(r)) * [0] for r in self.seq_block_ids]
return torch.tensor(rows)
return torch.tensor(rows, device=self.parent.model.device)


def main():
@@ -208,19 +213,22 @@ def main():
parser = cli.create_parser()
parser.add_argument("prompt", nargs="+", help="Prompt strings")
parser.add_argument("--kv-cache-type", default="paged", help="KV cache type")
parser.add_argument("--device", help="Torch device (or default)")
cli.add_gguf_dataset_options(parser)
cli.add_tokenizer_options(parser)
args = cli.parse(parser)

device = torch.device(args.device) if args.device else None
data_files = cli.get_gguf_data_files(args)
tokenizer = cli.get_tokenizer(args, data_files=data_files)
dataset = Dataset.load(data_files["gguf"])
dataset = Dataset.load(data_files["gguf"], device=device)
prompts = args.prompt

config = LlamaModelConfig(
hp=configs.LlamaHParams.from_gguf_props(dataset.properties),
block_seq_stride=16,
kv_cache_type=args.kv_cache_type,
device=device,
)
model = PagedLlamaModelV1(dataset.root_theta, config)
generator = TorchGenerator(model, tokenizer)
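In `main()`, the new `--device` flag is converted to a `torch.device` only when supplied and is then forwarded to both `Dataset.load` and `LlamaModelConfig`. A small sketch of the `begin_batch()` change, with made-up token ids, showing that the prompt tensors now start out on the model's device:

```python
import torch

device = torch.device("cuda:0") if torch.cuda.is_available() else None

# Hypothetical tokenizer output, already padded to the cache's sequence stride.
token_ids = [[101, 7592, 2088, 0], [101, 2425, 0, 0]]
seq_lens = [3, 2]

# As in begin_batch(): the tensors are created directly on the target device.
token_ids_t = torch.tensor(token_ids, device=device)
seq_lens_t = torch.tensor(seq_lens, device=device)
print(token_ids_t.device, seq_lens_t.device)
```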
27 changes: 21 additions & 6 deletions sharktank/sharktank/layers/causal_llm.py
@@ -24,9 +24,15 @@ class BaseCausalLMModel(ThetaLayer):
"""

def __init__(
self, theta: Theta, *, context_length: int, static_context_mask: bool = True
self,
theta: Theta,
*,
context_length: int,
static_context_mask: bool = True,
device: Optional[torch.device] = None,
):
super().__init__(theta)
self.device = device
self.context_length = context_length

if static_context_mask:
@@ -36,6 +42,13 @@ def __init__(
else:
self.causal_context_mask = None

def _assert_device(self, *ts: torch.Tensor):
if self.device is not None:
for t in ts:
assert (
t.device == self.device
), f"Expected tensor to be on device {self.device} but it is on {t.device}"

def _maximally_negative_value(self, dtype):
"""Returns a maximally negative value for the given dtype.

@@ -46,7 +59,9 @@ def _maximally_negative_value(self, dtype):
def generate_causal_context_mask(self) -> torch.Tensor:
context_length = self.context_length
causal_context_mask = torch.triu(
torch.ones([context_length, context_length], dtype=torch.bool),
torch.ones(
[context_length, context_length], dtype=torch.bool, device=self.device
),
diagonal=1,
)[None, None, :, :]
return causal_context_mask
@@ -62,7 +77,7 @@ def input_mask(
The mask will be [bs, batch_seqlen] with True at any position that is
masked.
"""
range_vector = torch.arange(0, batch_seqlen, 1)
range_vector = torch.arange(0, batch_seqlen, 1, device=self.device)
matrix = torch.unsqueeze(seq_lens, dim=-1)
mask = range_vector >= matrix
return mask
@@ -74,14 +89,14 @@ def decode_attention_mask(
numeric_mask.masked_fill_(
boolean_input_mask, self._maximally_negative_value(dtype)
)
return numeric_mask.unsqueeze(1).unsqueeze(1)
return numeric_mask.unsqueeze(1).unsqueeze(1).to(self.device)

def attention_mask(
self,
input_mask: torch.Tensor,
*,
dtype: torch.dtype,
causal_context_mask: Optional[torch.Tensor] = None
causal_context_mask: Optional[torch.Tensor] = None,
):
"""Generates a causal attention mask of [1, 1, sl, sl] of activation dtype.

@@ -103,7 +118,7 @@
boolean_mask = causal_mask + input_mask[:, None, None, :]
numeric_mask = torch.zeros_like(boolean_mask, dtype=dtype)
numeric_mask.masked_fill_(boolean_mask, self._maximally_negative_value(dtype))
return numeric_mask
return numeric_mask.to(self.device)

def extract_tokens_from_logits(
self, logits: torch.Tensor, seq_lens: list[int]
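A standalone sketch of the mask helpers after this change, in plain torch: the range vector and the causal triangle are both created on `device`, and the numeric mask is filled with a very negative value (assumed here to be `torch.finfo(dtype).min`, since `_maximally_negative_value` is not shown in the diff):

```python
import torch

device = torch.device("cpu")  # substitute e.g. torch.device("cuda:0")
dtype = torch.float32
batch_seqlen, context_length = 5, 5
seq_lens = torch.tensor([2, 4], device=device)

# input_mask: True wherever the position is at or beyond the sequence length.
range_vector = torch.arange(0, batch_seqlen, 1, device=device)
boolean_mask = range_vector >= seq_lens.unsqueeze(-1)
# [[False, False,  True,  True,  True],
#  [False, False, False, False,  True]]

# generate_causal_context_mask: upper-triangular "future" positions are True.
causal = torch.triu(
    torch.ones([context_length, context_length], dtype=torch.bool, device=device),
    diagonal=1,
)[None, None, :, :]

# decode_attention_mask-style numeric mask: masked slots get a large negative bias.
numeric = torch.zeros_like(boolean_mask, dtype=dtype)
numeric.masked_fill_(boolean_mask, torch.finfo(dtype).min)
attention_bias = numeric.unsqueeze(1).unsqueeze(1).to(device)  # [bs, 1, 1, sl]
print(attention_bias.shape, attention_bias.device)
```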
18 changes: 15 additions & 3 deletions sharktank/sharktank/layers/kv_cache.py
@@ -11,6 +11,8 @@
and dims floating around everywhere.
"""

from typing import Optional

import abc
import math

@@ -88,12 +90,14 @@ def __init__(
attn_head_count: int,
attn_head_dim: int,
seq_length: int,
device: Optional[torch.device] = None,
):
self.block_seq_stride = block_seq_stride
self.transformer_block_count = transformer_block_count
self.attn_head_count = attn_head_count
self.attn_head_dim = attn_head_dim
self.seq_length = seq_length
self.device = device

@property
def pad_sequence_stride(self) -> int:
@@ -109,6 +113,7 @@ def allocate(self, *, bs: int, dtype: torch.dtype) -> list[torch.Tensor]:
torch.empty(
[bs, self.seq_length, self.attn_head_count, self.attn_head_dim],
dtype=dtype,
device=self.device,
)
for _ in range(2 * self.transformer_block_count)
]
@@ -141,6 +146,7 @@ def __init__(
attn_head_dim: int,
cache_partition_count: int = 2,
block_seq_stride: int = 16,
device: Optional[torch.device] = None,
):
self.transformer_block_count = transformer_block_count
self.attn_head_count = attn_head_count
Expand All @@ -157,6 +163,7 @@ def __init__(
self.attn_head_dim,
]
self.page_slab_flat_dim = math.prod(self.sub_page_dims)
self.device = device

def unflatten_page_table(self, state: list[torch.Tensor]) -> torch.Tensor:
"""Unflattens the 2D page table to a 6D tensor."""
@@ -181,7 +188,11 @@ def allocate(self, page_count: int, dtype: torch.dtype) -> list[torch.Tensor]:
"""Allocates tensor state for a page table for the given capacity in
pages.
"""
return [torch.empty([page_count, self.page_slab_flat_dim], dtype=dtype)]
return [
torch.empty(
[page_count, self.page_slab_flat_dim], dtype=dtype, device=self.device
)
]

def read(
self,
@@ -272,6 +283,7 @@ def write_timestep(
Note that this internally loops over the batch size, which cannot be
dynamic.
"""
device = self.device
page_table = self.unflatten_page_table(state) # 6D
bs, *_ = seq_positions.shape
assert len(cache_partitions) == self.cache_partition_count
@@ -285,8 +297,8 @@
cache_partition = cache_partitions[partition_index]
indices = (
page_id,
torch.tensor([transformer_block_index]),
torch.tensor([partition_index]),
torch.tensor([transformer_block_index], device=device),
torch.tensor([partition_index], device=device),
page_offset.unsqueeze(0),
)
page_table.index_put_(indices=indices, values=cache_partition[i, 0])
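A minimal sketch of allocating a paged cache on a device, assuming the import path follows the file layout and that the constructor takes only the parameters visible in this diff; the sizes below are toy values rather than real hyperparameters:

```python
import torch

from sharktank.layers.kv_cache import PagedKVCache  # import path assumed from the file layout

device = torch.device("cuda:0") if torch.cuda.is_available() else None

cache = PagedKVCache(
    transformer_block_count=2,   # toy sizes, not real hyperparameters
    attn_head_count=4,
    attn_head_dim=16,
    cache_partition_count=2,     # one partition each for K and V
    block_seq_stride=16,
    device=device,
)

# The whole page slab is a single flat tensor allocated directly on `device`.
pages = cache.allocate(page_count=8, dtype=torch.float16)
print(pages[0].shape, pages[0].device)  # [8, page_slab_flat_dim]
```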
25 changes: 19 additions & 6 deletions sharktank/sharktank/layers/rotary_embedding.py
@@ -4,6 +4,8 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Optional

import torch

from .base import BaseLayer
@@ -12,10 +14,18 @@
class RotaryEmbeddingLayer(BaseLayer):
"""Computes a rotary embedding in the style popularized by llama (RoPE)."""

def __init__(self, *, rope_dimension_count: int, max_seqlen: int):
def __init__(
self,
*,
rope_dimension_count: int,
max_seqlen: int,
device: Optional[torch.device] = None,
):
super().__init__()
self.device = device
self._table = self._create_rotary_embed_table(
max_seqlen=max_seqlen, dim=rope_dimension_count
max_seqlen=max_seqlen,
dim=rope_dimension_count,
)

def forward(self, *, xq: torch.Tensor, xk: torch.Tensor, start_index: int):
@@ -50,7 +60,7 @@ def compute_batch_mask(
Tensor of [bs, sl, 1, d] that will be later passed to apply_batch_mask.
"""
self.trace_tensor("rope.start_positions", start_positions)
positions_seq = torch.arange(0, batch_seq_len).unsqueeze(
positions_seq = torch.arange(0, batch_seq_len, device=self.device).unsqueeze(
0
) + start_positions.unsqueeze(1)
# Broadcast lookup to [b, ...].
@@ -81,12 +91,15 @@ def apply_batched_mask(
xk_out = torch.view_as_real(xk_ * mask).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)

@staticmethod
def _create_rotary_embed_table(
max_seqlen: int, dim: int, theta_value: float = 10000.0
self,
max_seqlen: int,
dim: int,
theta_value: float = 10000.0,
):
freqs = 1.0 / (
theta_value ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
theta_value
** (torch.arange(0, dim, 2, device=self.device)[: (dim // 2)].float() / dim)
)
t = torch.arange(max_seqlen, device=freqs.device)
freqs = torch.outer(t, freqs).float()
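With `_create_rotary_embed_table` now an instance method, the frequency basis is built on `self.device`, and `t` inherits that placement via `freqs.device`. A standalone sketch of the visible part of the computation, with small arbitrary sizes:

```python
import torch

device = torch.device("cpu")  # or torch.device("cuda:0")
dim, max_seqlen, theta_value = 8, 16, 10000.0

freqs = 1.0 / (
    theta_value
    ** (torch.arange(0, dim, 2, device=device)[: (dim // 2)].float() / dim)
)
t = torch.arange(max_seqlen, device=freqs.device)  # follows freqs onto the device
table = torch.outer(t, freqs).float()
print(table.shape, table.device)  # torch.Size([16, 4]) on the chosen device
```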
20 changes: 19 additions & 1 deletion sharktank/sharktank/models/llama/llama.py
@@ -37,6 +37,9 @@ class LlamaModelConfig:
# Either "paged" or "direct".
kv_cache_type: str = "paged"

# The device on which to place intermediate state.
device: Optional[torch.device] = None

def create_kv_cache(self) -> BaseKVCache:
hp = self.hp
if self.kv_cache_type == "direct":
@@ -46,6 +49,7 @@ def create_kv_cache(self) -> BaseKVCache:
attn_head_count=hp.attention_head_count_kv,
attn_head_dim=hp.attn_head_dim,
seq_length=hp.context_length,
device=self.device,
)
elif self.kv_cache_type == "paged":
return PagedKVCache(
@@ -54,6 +58,7 @@
attn_head_dim=hp.attn_head_dim,
cache_partition_count=2, # One for each of K/V.
block_seq_stride=self.block_seq_stride,
device=self.device,
)
else:
raise NotImplementedError(f"kv_cache_type = {self.kv_cache_type}")
@@ -88,7 +93,9 @@ class PagedLlamaModelV1(BaseCausalLMModel):

def __init__(self, theta: Theta, config: LlamaModelConfig):
hp = config.hp
super().__init__(theta, context_length=config.hp.context_length)
super().__init__(
theta, context_length=config.hp.context_length, device=config.device
)
self.config = config
self.hp = hp
self.cache = config.create_kv_cache()
@@ -101,6 +108,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
RotaryEmbeddingLayer(
rope_dimension_count=hp.rope_dimension_count,
max_seqlen=hp.context_length,
device=self.device,
),
)
self.add_module(
@@ -137,6 +145,10 @@ def prefill(
seq_block_ids: torch.Tensor,
cache_state: list[torch.Tensor],
):
self._assert_device(tokens)
self._assert_device(attention_mask)
self._assert_device(seq_block_ids)
self._assert_device(*cache_state)
h = self.token_embedding(tokens)
self.trace_tensor("llama.token_embedding", h)

@@ -171,6 +183,10 @@ def decode(
seq_block_ids: torch.Tensor,
cache_state: list[torch.Tensor],
):
self._assert_device(tokens)
self._assert_device(attention_mask)
self._assert_device(start_positions)
self._assert_device(*cache_state)
bs, _ = tokens.shape
# Precompute a position based mask for computing rope embeddings
# as it is the same for all blocks.
@@ -189,6 +205,7 @@ def decode(
self.hp.attn_head_dim,
],
dtype=self.hp.activation_dtype,
device=self.device,
)
xv_temp = torch.empty(
[
Expand All @@ -198,6 +215,7 @@ def decode(
self.hp.attn_head_dim,
],
dtype=self.hp.activation_dtype,
device=self.device,
)

h = self.token_embedding(tokens)
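Putting the pieces together, roughly following `main()` in `paged_llm_v1.py`: the device flows from the config into the KV cache, the rotary table, and every intermediate tensor, and `prefill()`/`decode()` now assert that their inputs already live there. The import paths and GGUF path below are assumptions, not taken from the diff:

```python
import torch

# These import paths are guesses based on the repository layout shown above;
# the real ones are in the (collapsed) import block of paged_llm_v1.py.
from sharktank.types import Dataset
from sharktank.layers import configs
from sharktank.models.llama.llama import LlamaModelConfig, PagedLlamaModelV1

device = torch.device("cuda:0") if torch.cuda.is_available() else None

dataset = Dataset.load("/path/to/model.gguf", device=device)  # placeholder path
config = LlamaModelConfig(
    hp=configs.LlamaHParams.from_gguf_props(dataset.properties),
    block_seq_stride=16,
    kv_cache_type="paged",
    device=device,
)
model = PagedLlamaModelV1(dataset.root_theta, config)

# prefill()/decode() now start with _assert_device(...) checks, so passing
# tensors that live on a different device fails fast with an AssertionError
# instead of surfacing as a cryptic cross-device error deep inside attention.
```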