From 473798c6d86af825a9400bcdbfc84616c146b467 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Wed, 24 Apr 2024 19:06:07 -0700 Subject: [PATCH] Support moving theta and models to a specific device. * Threads explicit device through models. * Implements functional InferenceTensor, Theta and Dataset transformations and uses it to implement `to(device=)`. * Adds `--device foo` to example runner. * With https://github.com/iree-org/iree-turbine/pull/3 and supporting patches, this allows custom ops and kernels to be transparently be used on CUDA/ROCM devices (instead of just CPU). --- sharktank/sharktank/examples/paged_llm_v1.py | 20 ++-- sharktank/sharktank/layers/causal_llm.py | 27 ++++-- sharktank/sharktank/layers/kv_cache.py | 18 +++- .../sharktank/layers/rotary_embedding.py | 25 +++-- sharktank/sharktank/models/llama/llama.py | 20 +++- sharktank/sharktank/types/tensors.py | 92 +++++++++++++++++-- sharktank/sharktank/types/theta.py | 68 ++++++++++++-- sharktank/sharktank/utils/tokenizer.py | 7 +- sharktank/tests/types/dataset_test.py | 36 ++++++++ sharktank/tests/types/tensors_test.py | 59 ++++++++++++ 10 files changed, 331 insertions(+), 41 deletions(-) create mode 100644 sharktank/tests/types/tensors_test.py diff --git a/sharktank/sharktank/examples/paged_llm_v1.py b/sharktank/sharktank/examples/paged_llm_v1.py index 48c126b9d..6ceb1cf90 100644 --- a/sharktank/sharktank/examples/paged_llm_v1.py +++ b/sharktank/sharktank/examples/paged_llm_v1.py @@ -6,6 +6,8 @@ """Inference support for the PagedLLMV1 protocol of models.""" +from typing import Optional + import math import sys @@ -50,8 +52,8 @@ def begin_batch(self, prompts: list[str]): token_ids, seq_lens = self.tokenizer.encode( prompts, pad_to_multiple_of=self.model.cache.pad_sequence_stride ) - token_ids = torch.tensor(token_ids) - seq_lens = torch.tensor(seq_lens) + token_ids = torch.tensor(token_ids, device=self.model.device) + seq_lens = torch.tensor(seq_lens, device=self.model.device) if self.shared_cache_state is not None: cache_state = self.shared_cache_state else: @@ -153,6 +155,8 @@ def prefill(self): seq_block_ids=seq_block_ids_tensor, cache_state=self.cache_state, ) + + # TODO: Generalize the sampling and don't make it swap on/off cpu. # TODO: Normalize the output of extract_tokens_from_logits into # tensor [bs, 1]. tokens = torch.tensor( @@ -160,7 +164,7 @@ def prefill(self): ).unsqueeze(1) print(f":: Prefill results:\n{tokens.tolist()}") self.add_result_token(tokens) - self.next_tokens = tokens + self.next_tokens = tokens.to(device=model.device) def decode(self): model = self.parent.model @@ -191,7 +195,8 @@ def decode(self): # TODO: Normalize the output of extract_tokens_from_logits into # tensor [bs, 1]. 
tokens = torch.tensor( - model.extract_tokens_from_logits(logits, [1] * self.bs) + model.extract_tokens_from_logits(logits, [1] * self.bs), + device=self.parent.model.device, ).unsqueeze(1) self.add_result_token(tokens) self.next_tokens = tokens @@ -199,7 +204,7 @@ def decode(self): def pad_block_ids(self) -> torch.Tensor: max_length = max(len(r) for r in self.seq_block_ids) rows = [r + (max_length - len(r)) * [0] for r in self.seq_block_ids] - return torch.tensor(rows) + return torch.tensor(rows, device=self.parent.model.device) def main(): @@ -208,19 +213,22 @@ def main(): parser = cli.create_parser() parser.add_argument("prompt", nargs="+", help="Prompt strings") parser.add_argument("--kv-cache-type", default="paged", help="KV cache type") + parser.add_argument("--device", help="Torch device (or default)") cli.add_gguf_dataset_options(parser) cli.add_tokenizer_options(parser) args = cli.parse(parser) + device = torch.device(args.device) if args.device else None data_files = cli.get_gguf_data_files(args) tokenizer = cli.get_tokenizer(args, data_files=data_files) - dataset = Dataset.load(data_files["gguf"]) + dataset = Dataset.load(data_files["gguf"], device=device) prompts = args.prompt config = LlamaModelConfig( hp=configs.LlamaHParams.from_gguf_props(dataset.properties), block_seq_stride=16, kv_cache_type=args.kv_cache_type, + device=device, ) model = PagedLlamaModelV1(dataset.root_theta, config) generator = TorchGenerator(model, tokenizer) diff --git a/sharktank/sharktank/layers/causal_llm.py b/sharktank/sharktank/layers/causal_llm.py index bf8042df0..e3752da30 100644 --- a/sharktank/sharktank/layers/causal_llm.py +++ b/sharktank/sharktank/layers/causal_llm.py @@ -24,9 +24,15 @@ class BaseCausalLMModel(ThetaLayer): """ def __init__( - self, theta: Theta, *, context_length: int, static_context_mask: bool = True + self, + theta: Theta, + *, + context_length: int, + static_context_mask: bool = True, + device: Optional[torch.device] = None, ): super().__init__(theta) + self.device = device self.context_length = context_length if static_context_mask: @@ -36,6 +42,13 @@ def __init__( else: self.causal_context_mask = None + def _assert_device(self, *ts: torch.Tensor): + if self.device is not None: + for t in ts: + assert ( + t.device == self.device + ), f"Expected tensor to be on device {self.device} but it is on {t.device}" + def _maximally_negative_value(self, dtype): """Returns a maximally negative value for the given dtype. @@ -46,7 +59,9 @@ def _maximally_negative_value(self, dtype): def generate_causal_context_mask(self) -> torch.Tensor: context_length = self.context_length causal_context_mask = torch.triu( - torch.ones([context_length, context_length], dtype=torch.bool), + torch.ones( + [context_length, context_length], dtype=torch.bool, device=self.device + ), diagonal=1, )[None, None, :, :] return causal_context_mask @@ -62,7 +77,7 @@ def input_mask( The mask will be [bs, batch_seqlen] with True at any position that is masked. 
""" - range_vector = torch.arange(0, batch_seqlen, 1) + range_vector = torch.arange(0, batch_seqlen, 1, device=self.device) matrix = torch.unsqueeze(seq_lens, dim=-1) mask = range_vector >= matrix return mask @@ -74,14 +89,14 @@ def decode_attention_mask( numeric_mask.masked_fill_( boolean_input_mask, self._maximally_negative_value(dtype) ) - return numeric_mask.unsqueeze(1).unsqueeze(1) + return numeric_mask.unsqueeze(1).unsqueeze(1).to(self.device) def attention_mask( self, input_mask: torch.Tensor, *, dtype: torch.dtype, - causal_context_mask: Optional[torch.Tensor] = None + causal_context_mask: Optional[torch.Tensor] = None, ): """Generates a causal attention mask of [1, 1, sl, sl] of activation dtype. @@ -103,7 +118,7 @@ def attention_mask( boolean_mask = causal_mask + input_mask[:, None, None, :] numeric_mask = torch.zeros_like(boolean_mask, dtype=dtype) numeric_mask.masked_fill_(boolean_mask, self._maximally_negative_value(dtype)) - return numeric_mask + return numeric_mask.to(self.device) def extract_tokens_from_logits( self, logits: torch.Tensor, seq_lens: list[int] diff --git a/sharktank/sharktank/layers/kv_cache.py b/sharktank/sharktank/layers/kv_cache.py index 7cb0cae2e..a1bfa55a4 100644 --- a/sharktank/sharktank/layers/kv_cache.py +++ b/sharktank/sharktank/layers/kv_cache.py @@ -11,6 +11,8 @@ and dims floating around everywhere. """ +from typing import Optional + import abc import math @@ -88,12 +90,14 @@ def __init__( attn_head_count: int, attn_head_dim: int, seq_length: int, + device: Optional[torch.device] = None, ): self.block_seq_stride = block_seq_stride self.transformer_block_count = transformer_block_count self.attn_head_count = attn_head_count self.attn_head_dim = attn_head_dim self.seq_length = seq_length + self.device = device @property def pad_sequence_stride(self) -> int: @@ -109,6 +113,7 @@ def allocate(self, *, bs: int, dtype: torch.dtype) -> list[torch.Tensor]: torch.empty( [bs, self.seq_length, self.attn_head_count, self.attn_head_dim], dtype=dtype, + device=self.device, ) for _ in range(2 * self.transformer_block_count) ] @@ -141,6 +146,7 @@ def __init__( attn_head_dim: int, cache_partition_count: int = 2, block_seq_stride: int = 16, + device: Optional[torch.device] = None, ): self.transformer_block_count = transformer_block_count self.attn_head_count = attn_head_count @@ -157,6 +163,7 @@ def __init__( self.attn_head_dim, ] self.page_slab_flat_dim = math.prod(self.sub_page_dims) + self.device = device def unflatten_page_table(self, state: list[torch.Tensor]) -> torch.Tensor: """Unflattens the 2D page table to a 6D tensor.""" @@ -181,7 +188,11 @@ def allocate(self, page_count: int, dtype: torch.dtype) -> list[torch.Tensor]: """Allocates tensor state for a page table for the given capacity in pages. """ - return [torch.empty([page_count, self.page_slab_flat_dim], dtype=dtype)] + return [ + torch.empty( + [page_count, self.page_slab_flat_dim], dtype=dtype, device=self.device + ) + ] def read( self, @@ -272,6 +283,7 @@ def write_timestep( Note that this internally loops over the batch size, which cannot be dynamic. 
""" + device = self.device page_table = self.unflatten_page_table(state) # 6D bs, *_ = seq_positions.shape assert len(cache_partitions) == self.cache_partition_count @@ -285,8 +297,8 @@ def write_timestep( cache_partition = cache_partitions[partition_index] indices = ( page_id, - torch.tensor([transformer_block_index]), - torch.tensor([partition_index]), + torch.tensor([transformer_block_index], device=device), + torch.tensor([partition_index], device=device), page_offset.unsqueeze(0), ) page_table.index_put_(indices=indices, values=cache_partition[i, 0]) diff --git a/sharktank/sharktank/layers/rotary_embedding.py b/sharktank/sharktank/layers/rotary_embedding.py index fc4ef9e4d..20f7c2f91 100644 --- a/sharktank/sharktank/layers/rotary_embedding.py +++ b/sharktank/sharktank/layers/rotary_embedding.py @@ -4,6 +4,8 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +from typing import Optional + import torch from .base import BaseLayer @@ -12,10 +14,18 @@ class RotaryEmbeddingLayer(BaseLayer): """Computes a rotary embedding in the style popularized by llama (RoPE).""" - def __init__(self, *, rope_dimension_count: int, max_seqlen: int): + def __init__( + self, + *, + rope_dimension_count: int, + max_seqlen: int, + device: Optional[torch.device] = None, + ): super().__init__() + self.device = device self._table = self._create_rotary_embed_table( - max_seqlen=max_seqlen, dim=rope_dimension_count + max_seqlen=max_seqlen, + dim=rope_dimension_count, ) def forward(self, *, xq: torch.Tensor, xk: torch.Tensor, start_index: int): @@ -50,7 +60,7 @@ def compute_batch_mask( Tensor of [bs, sl, 1, d] that will be later passed to apply_batch_mask. """ self.trace_tensor("rope.start_positions", start_positions) - positions_seq = torch.arange(0, batch_seq_len).unsqueeze( + positions_seq = torch.arange(0, batch_seq_len, device=self.device).unsqueeze( 0 ) + start_positions.unsqueeze(1) # Broadcast lookup to [b, ...]. @@ -81,12 +91,15 @@ def apply_batched_mask( xk_out = torch.view_as_real(xk_ * mask).flatten(3) return xq_out.type_as(xq), xk_out.type_as(xk) - @staticmethod def _create_rotary_embed_table( - max_seqlen: int, dim: int, theta_value: float = 10000.0 + self, + max_seqlen: int, + dim: int, + theta_value: float = 10000.0, ): freqs = 1.0 / ( - theta_value ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) + theta_value + ** (torch.arange(0, dim, 2, device=self.device)[: (dim // 2)].float() / dim) ) t = torch.arange(max_seqlen, device=freqs.device) freqs = torch.outer(t, freqs).float() diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py index 6829e0ed2..672759b1b 100644 --- a/sharktank/sharktank/models/llama/llama.py +++ b/sharktank/sharktank/models/llama/llama.py @@ -37,6 +37,9 @@ class LlamaModelConfig: # Either "paged" or "direct". kv_cache_type: str = "paged" + # The device on which to place intermediate state. + device: Optional[torch.device] = None + def create_kv_cache(self) -> BaseKVCache: hp = self.hp if self.kv_cache_type == "direct": @@ -46,6 +49,7 @@ def create_kv_cache(self) -> BaseKVCache: attn_head_count=hp.attention_head_count_kv, attn_head_dim=hp.attn_head_dim, seq_length=hp.context_length, + device=self.device, ) elif self.kv_cache_type == "paged": return PagedKVCache( @@ -54,6 +58,7 @@ def create_kv_cache(self) -> BaseKVCache: attn_head_dim=hp.attn_head_dim, cache_partition_count=2, # One for each of K/V. 
block_seq_stride=self.block_seq_stride, + device=self.device, ) else: raise NotImplementedError(f"kv_cache_type = {self.kv_cache_type}") @@ -88,7 +93,9 @@ class PagedLlamaModelV1(BaseCausalLMModel): def __init__(self, theta: Theta, config: LlamaModelConfig): hp = config.hp - super().__init__(theta, context_length=config.hp.context_length) + super().__init__( + theta, context_length=config.hp.context_length, device=config.device + ) self.config = config self.hp = hp self.cache = config.create_kv_cache() @@ -101,6 +108,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): RotaryEmbeddingLayer( rope_dimension_count=hp.rope_dimension_count, max_seqlen=hp.context_length, + device=self.device, ), ) self.add_module( @@ -137,6 +145,10 @@ def prefill( seq_block_ids: torch.Tensor, cache_state: list[torch.Tensor], ): + self._assert_device(tokens) + self._assert_device(attention_mask) + self._assert_device(seq_block_ids) + self._assert_device(*cache_state) h = self.token_embedding(tokens) self.trace_tensor("llama.token_embedding", h) @@ -171,6 +183,10 @@ def decode( seq_block_ids: torch.Tensor, cache_state: list[torch.Tensor], ): + self._assert_device(tokens) + self._assert_device(attention_mask) + self._assert_device(start_positions) + self._assert_device(*cache_state) bs, _ = tokens.shape # Precompute a position based mask for computing rope embeddings # as it is the same for all blocks. @@ -189,6 +205,7 @@ def decode( self.hp.attn_head_dim, ], dtype=self.hp.activation_dtype, + device=self.device, ) xv_temp = torch.empty( [ @@ -198,6 +215,7 @@ def decode( self.hp.attn_head_dim, ], dtype=self.hp.activation_dtype, + device=self.device, ) h = self.token_embedding(tokens) diff --git a/sharktank/sharktank/types/tensors.py b/sharktank/sharktank/types/tensors.py index fc37faacb..e7a2c2ea1 100644 --- a/sharktank/sharktank/types/tensors.py +++ b/sharktank/sharktank/types/tensors.py @@ -4,14 +4,17 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import Any, Optional, Union, TypeVar, Generic, Type +from typing import Any, Callable, Optional, Union, TypeVar, Generic, Type from abc import ABC, abstractmethod from dataclasses import dataclass import torch -from shark_turbine.aot import ParameterArchiveBuilder +from shark_turbine.aot import ( + ExternalTensorTrait, + ParameterArchiveBuilder, +) __all__ = [ "register_quantized_layout", @@ -30,8 +33,7 @@ class QuantizedLayout(ABC): @abstractmethod - def dequant(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor: - ... + def dequant(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor: ... @classmethod @abstractmethod @@ -46,8 +48,7 @@ def create( shape: list[int], metadata: Optional[dict[str, MetaDataValueType]], planes: dict[str, torch.Tensor], - ) -> "QuantizedLayout": - ... + ) -> "QuantizedLayout": ... @property @abstractmethod @@ -185,6 +186,48 @@ def add_to_archive( """Adds this tensor to the global archive.""" ... + def transform_globals( + self, *transforms: Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] + ) -> "InferenceTensor": + """Appplies transformation functions to the InferenceTensors backing + globals. + + Each transformation must produce a new dict of a form that the subclass + can handle. Practically, this means that placement and layout related + changes are always allowed, while more invasive changes (like dtype) + are more case by case. + + Returns a new InferenceTensor, mutated. 
+ """ + prev_globals = self.globals + for transform in transforms: + next_globals = transform(prev_globals) + # Copy any metadata from prior to next. + for k, prev_t in prev_globals.items(): + new_t = next_globals.get(k) + if new_t is None: + continue + if new_t is not prev_t: + ext_trait = ExternalTensorTrait.get(prev_t) + if ext_trait is not None: + ext_trait.set(new_t) + prev_globals = next_globals + return self._clone_with_globals(prev_globals) + + def to( + self, *, device: Optional[Union[str, torch.device]] = None + ) -> "InferenceTensor": + return self.transform_globals( + lambda d: {k: t.to(device=device) for k, t in d.items()} + ) + + def _clone_with_globals( + self, new_globals: dict[str, torch.Tensor] + ) -> "InferenceTensor": + raise NotImplementedError( + f"InferenceTensor {type(self)} does not implement _clone_with_globals" + ) + REGISTERED_INFERENCE_TENSOR_CLASSES: dict[str, Type[InferenceTensor]] = {} @@ -263,6 +306,11 @@ def add_to_archive( builder.add_tensor(self.name, self._data) return InferenceTensorMetadata(self.serialized_name(), {"": self.name}) + def _clone_with_globals( + self, new_globals: dict[str, torch.Tensor] + ) -> "InferenceTensor": + return DefaultPrimitiveTensor(name=self.name, data=new_globals[self.name]) + def __repr__(self): return f"PrimitiveTensor({self.name}, {self.shape}, {self._data.dtype})" @@ -281,8 +329,7 @@ def __init__( self.layout_type = layout_type @abstractmethod - def unpack(self) -> QuantizedLayoutT: - ... + def unpack(self) -> QuantizedLayoutT: ... def to_planar(self) -> "PlanarQuantizedTensor": """Converts this QuantizedTensor to a generic planar form. @@ -333,6 +380,35 @@ def globals(self) -> dict[str, torch.Tensor]: planes = self.layout.planes return {f"{global_name}:{k}": v for k, v in planes.items()} + def _clone_with_globals( + self, new_globals: dict[str, torch.Tensor] + ) -> "InferenceTensor": + # Clone it via layout serialization. + serialized_name = self.layout.serialized_name() + global_prefix = f"{self.name}:" + orig_planes = self.layout.planes + new_planes = {} + for plane_name in orig_planes.keys(): + # Planes are stored in the globals dict with the inference + # tensor's name and colon prepended, so look up that way. + new_planes[plane_name] = new_globals[f"{global_prefix}{plane_name}"] + + # Create a new layout via the serialization adaptor. + try: + layout_clazz = REGISTERED_LAYOUT_CLASSES[serialized_name] + except KeyError: + raise IOError( + f"Cannot deserialize PlanarQuantizedTensor because of unregistered layout " + f"{serialized_name}" + ) + new_layout = layout_clazz.create(self.shape, self.layout.metadata, new_planes) + + return PlanarQuantizedTensor( + name=self.name, + shape=self.shape, + layout=new_layout, + ) + @classmethod def create( cls, diff --git a/sharktank/sharktank/types/theta.py b/sharktank/sharktank/types/theta.py index 2a5095368..386a30694 100644 --- a/sharktank/sharktank/types/theta.py +++ b/sharktank/sharktank/types/theta.py @@ -4,7 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import Any, Callable, Optional, Union, Collection +from typing import Any, Callable, Optional, Union, Collection, Sequence import json from pathlib import Path @@ -54,18 +54,41 @@ # implementation for operating on the InferenceTensors. 
################################################################################ +InferenceTensorTransform = Callable[[InferenceTensor], InferenceTensor] + + +class InferenceTensorTransforms: + """Container for common transformations on InferenceTensors.""" + + @staticmethod + def identity() -> InferenceTensorTransform: + return lambda x: x + + @staticmethod + def to_device( + device: Optional[Union[str, torch.device]] + ) -> InferenceTensorTransform: + if device is not None: + return lambda it: it.to(device=device) + return InferenceTensorTransforms.identity() + class Theta: """Subset of parameter tensors used for inference.""" def __init__( self, - tensors: dict, + tensors: Union[Sequence[InferenceTensor], dict[str, InferenceTensor]], *, ops: Optional["BaseInferenceOps"] = None, - already_nested: bool = False, ): - self._tensors = tensors if already_nested else _flat_to_nested_dict(tensors) + if not isinstance(tensors, dict): + tensors = {t.name: t for t in tensors} + self._tensors = _flat_to_nested_dict(tensors) + assert ( + isinstance(k, str) and isinstance(v, InferenceTensor) + for k, v in tensors.items() + ) if ops is None: # Use the custom op library by default. Note that since the ops # namespace depends on types, we have to lazy load it. @@ -74,6 +97,22 @@ def __init__( ops = CustomInferenceOps() self.ops = ops + def transform(self, *transforms: InferenceTensorTransform) -> "Theta": + """Transforms all inference tensors by applying transform functions. + + Returns a modified theta. + """ + orig_flat_tensors = self.flatten().values() + tran_flat_tensors = [] + for it in orig_flat_tensors: + for transform in transforms: + it = transform(it) + tran_flat_tensors.append(it) + return Theta(tran_flat_tensors) + + def to(self, *, device: Optional[Union[str, torch.device]] = None) -> "Theta": + return self.transform(InferenceTensorTransforms.to_device(device)) + def flatten(self) -> dict[str, InferenceTensor]: results = {} @@ -212,10 +251,27 @@ def save( @staticmethod def load( - path: Union[str, Path], *, file_type: Optional[str] = None, mmap: bool = True + path: Union[str, Path], + *, + file_type: Optional[str] = None, + mmap: bool = True, + device: Optional[Union[str, torch.device]] = None, ) -> "Dataset": """Loads a dataset from a parameter archive constructed with save.""" - return _dataset_load_helper(path, file_type=file_type, mmap=mmap) + ds = _dataset_load_helper(path, file_type=file_type, mmap=mmap) + if device is not None: + ds.to(device=device) + return ds + + def transform(self, *transforms: InferenceTensorTransform): + """Does an in-place transformation of `root_theta`. + + The result of the transformation is stored back into `root_theta`. + """ + self.root_theta = self.root_theta.transform(*transforms) + + def to(self, *, device: Optional[Union[str, torch.device]] = None): + self.transform(InferenceTensorTransforms.to_device(device)) ################################################################################ diff --git a/sharktank/sharktank/utils/tokenizer.py b/sharktank/sharktank/utils/tokenizer.py index 92a3a6893..ae0789891 100644 --- a/sharktank/sharktank/utils/tokenizer.py +++ b/sharktank/sharktank/utils/tokenizer.py @@ -12,7 +12,6 @@ import math import os - __all__ = [ "load_tokenizer", "InferenceTokenizer", @@ -53,12 +52,10 @@ def decode(self, tokens: Union[list[list[int]]], lens: Optional[list[int]] = Non return self._decode(tokens) @abstractmethod - def _encode(self, texts: list[str]) -> list[list[int]]: - ... 
+ def _encode(self, texts: list[str]) -> list[list[int]]: ... @abstractmethod - def _decode(self, tokens: list[list[int]]) -> list[str]: - ... + def _decode(self, tokens: list[list[int]]) -> list[str]: ... def load_tokenizer(*posargs, tokenizer_type: str = "transformers", **kwargs): diff --git a/sharktank/tests/types/dataset_test.py b/sharktank/tests/types/dataset_test.py index 331a59644..b71f6e5c4 100644 --- a/sharktank/tests/types/dataset_test.py +++ b/sharktank/tests/types/dataset_test.py @@ -55,6 +55,28 @@ def testThetaAccess(self): sub_sub_theta = theta("a", "b") self.assertEqual("a.b.c", sub_sub_theta.tensors[0].name) + def testTransform(self): + t1 = Theta( + _flat_t_dict( + _t("a.b.c", 1, 2), + _t("a.c.d", 10, 11), + _t("1.2.3", 3, 4), + ) + ) + + # We are mainly seeing that the structure/tensors were changed. + # Without a second device, it is otherwise hard to see an effect. + t2 = t1.to(device="cpu:1") + self.assertIsNot(t1, t2) + it1 = t1.tensor("a", "b", "c") + it2 = t2.tensor("a", "b", "c") + self.assertIsNot(it1, it2) + for k in it1.globals.keys(): + pt1 = it1.globals[k] + pt2 = it2.globals[k] + self.assertIsNot(pt1, pt2) + torch.testing.assert_close(pt1, pt2) + class DatasetTest(unittest.TestCase): def setUp(self): @@ -63,6 +85,20 @@ def setUp(self): def tearDown(self): shutil.rmtree(self.temp_dir) + def testDatasetTransform(self): + t1 = Theta( + _flat_t_dict( + _t("a.b.c", 1, 2), + _t("a.c.d", 10, 11), + _t("1.2.3", 3, 4), + ) + ) + ds = Dataset({}, t1) + ds.to(device="cpu:1") + # Just checking that it was in fact transformed. Rely on other + # unit tests for leaves transformed correctly. + self.assertIsNot(t1, ds.root_theta) + def testDatasetRoundtrip(self): theta = Theta( _flat_t_dict( diff --git a/sharktank/tests/types/tensors_test.py b/sharktank/tests/types/tensors_test.py new file mode 100644 index 000000000..2d9d238e1 --- /dev/null +++ b/sharktank/tests/types/tensors_test.py @@ -0,0 +1,59 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import unittest + +import torch + +from sharktank.types import * + + +def _createTestLayout(): + n = 128 + k = 1024 + bs = 32 + + return BlockScaledLayout( + [n, k], + d=torch.empty(n, k // bs, 1, dtype=torch.float32), + qs=torch.empty(n, k // bs, bs, dtype=torch.int8), + m=torch.empty(n, k // bs, bs, dtype=torch.float32), + ) + + +class PlanarQuantizedTensorTest(unittest.TestCase): + def testTransform(self): + pqt1 = PlanarQuantizedTensor("t1", [128, 1024], _createTestLayout()) + + def transform1(d): + new_d = {} + for k, t in d.items(): + if k.endswith(":qs"): + t = t.to(torch.int16) + new_d[k] = t + return new_d + + def transform2(d): + new_d = {} + for k, t in d.items(): + if k.endswith(":d") or k.endswith(":m"): + t = t.to(torch.float16) + new_d[k] = t + return new_d + + pqt2 = pqt1.transform_globals(transform1, transform2) + self.assertIsNot(pqt1, pqt2) + print(pqt2) + self.assertEqual(pqt2.name, pqt1.name) + self.assertEqual(pqt2.shape, pqt1.shape) + new_planes = pqt2.layout.planes + self.assertEqual(new_planes["qs"].dtype, torch.int16) + self.assertEqual(new_planes["m"].dtype, torch.float16) + self.assertEqual(new_planes["d"].dtype, torch.float16) + + +if __name__ == "__main__": + unittest.main()
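
Usage sketch for the new device plumbing: the `--device` flag on the example runner maps to the same optional torch device that `Dataset.load` now accepts and forwards into `Dataset.to`. The snippet below is illustrative only; the GGUF path is a placeholder and the import path follows the files touched in this patch.

    import torch

    from sharktank.types.theta import Dataset

    device = torch.device("cuda:0")  # assumes a CUDA-enabled torch build

    # Load the parameter archive and place its root theta on the device in one step.
    dataset = Dataset.load("/path/to/model.gguf", device=device)

    # Equivalently, an already-loaded dataset can be moved in place afterwards.
    cpu_dataset = Dataset.load("/path/to/model.gguf")
    cpu_dataset.to(device=device)

From the command line, the same placement is selected with `--device cuda:0` (any torch device string works) alongside the runner's existing GGUF and tokenizer flags.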
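
Threading the device further: `LlamaModelConfig` now carries it into the KV cache and rotary embedding table it creates, and `PagedLlamaModelV1` asserts that `prefill`/`decode` inputs live on the same device. A sketch continuing from the previous one; `configs` is assumed to be the module `paged_llm_v1.py` already imports for `LlamaHParams` (its import line is outside the hunks shown).

    import torch

    from sharktank.models.llama.llama import LlamaModelConfig, PagedLlamaModelV1

    # `dataset` and `device` come from the previous sketch; `configs` matches
    # what paged_llm_v1.py uses for LlamaHParams.
    config = LlamaModelConfig(
        hp=configs.LlamaHParams.from_gguf_props(dataset.properties),
        block_seq_stride=16,
        kv_cache_type="paged",
        device=device,
    )
    model = PagedLlamaModelV1(dataset.root_theta, config)

    # prefill()/decode() now call _assert_device on their arguments, so inputs
    # built outside the bundled runner must also be created on model.device.
    tokens = torch.tensor([[1, 2, 3, 4]], device=model.device)

Cache state follows the same rule: `DirectKVCache.allocate` and `PagedKVCache.allocate` now place their buffers on the configured device.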
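
`Theta.transform` applies one or more `InferenceTensorTransform` callables to every tensor and returns a new `Theta`, while `InferenceTensor.transform_globals` rewrites the globals backing a single tensor. A minimal sketch composing the built-in device transform with a custom one; `cast_to_bfloat16` is a hypothetical helper, and the dtype change is only appropriate for plain float tensors, per the case-by-case caveat on `transform_globals`.

    import torch

    from sharktank.types.tensors import DefaultPrimitiveTensor, InferenceTensor
    from sharktank.types.theta import InferenceTensorTransforms, Theta


    def cast_to_bfloat16(it: InferenceTensor) -> InferenceTensor:
        # Hypothetical helper: rewrites every global backing the tensor and
        # returns a new InferenceTensor.
        return it.transform_globals(
            lambda d: {k: t.to(dtype=torch.bfloat16) for k, t in d.items()}
        )


    theta = Theta([DefaultPrimitiveTensor(name="w", data=torch.rand(4, 4))])
    moved = theta.transform(
        InferenceTensorTransforms.to_device("cuda:0"),  # placement: always allowed
        cast_to_bfloat16,                               # dtype: case by case
    )
    print(moved.tensor("w"))

`Dataset.transform` accepts the same callables but replaces `root_theta` in place rather than returning a new object.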