From 971c5c987c17d96af736511a388c9281a6715732 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Thu, 5 Sep 2024 23:50:00 -0500 Subject: [PATCH] [WIP] add tensor parallelism to the paged llama model --- sharktank/sharktank/examples/paged_llm_v1.py | 12 +- sharktank/sharktank/layers/kv_cache.py | 112 ++++-- sharktank/sharktank/layers/norm.py | 4 +- .../layers/paged_llama_attention_block.py | 8 +- .../sharktank/layers/rotary_embedding.py | 56 ++- sharktank/sharktank/models/llama/llama.py | 69 +++- sharktank/sharktank/models/llama/sharding.py | 114 ++++++ sharktank/sharktank/models/llama/testing.py | 85 ++++ sharktank/sharktank/models/llama/theta.py | 10 + sharktank/sharktank/models/punet/sharding.py | 2 +- sharktank/sharktank/ops/default_impls.py | 87 +++- sharktank/sharktank/ops/shape.py | 52 ++- sharktank/sharktank/ops/sharded_impls.py | 370 ++++++++++++++++-- sharktank/sharktank/ops/signatures.py | 270 ++++++++++++- sharktank/sharktank/types/sharding.py | 95 ++++- sharktank/sharktank/types/tensors.py | 131 ++++++- sharktank/sharktank/utils/__init__.py | 7 + sharktank/sharktank/utils/cli.py | 2 + sharktank/sharktank/utils/hf_datasets.py | 17 + sharktank/sharktank/utils/math.py | 6 + sharktank/sharktank/utils/misc.py | 16 + .../tests/models/llama/sharded_llama_test.py | 115 ++++++ sharktank/tests/ops/sharded_test.py | 15 + 23 files changed, 1543 insertions(+), 112 deletions(-) create mode 100644 sharktank/sharktank/models/llama/sharding.py create mode 100644 sharktank/sharktank/models/llama/theta.py create mode 100644 sharktank/sharktank/utils/__init__.py create mode 100644 sharktank/sharktank/utils/misc.py create mode 100644 sharktank/tests/models/llama/sharded_llama_test.py diff --git a/sharktank/sharktank/examples/paged_llm_v1.py b/sharktank/sharktank/examples/paged_llm_v1.py index 47d281565..66526d542 100644 --- a/sharktank/sharktank/examples/paged_llm_v1.py +++ b/sharktank/sharktank/examples/paged_llm_v1.py @@ -19,6 +19,7 @@ # TODO: Should be using a base class with the protocol supported. 
from ..models.mixtral.mixtral import * from ..models.llama.llama import * +from ..models.llama.sharding import shard_theta from ..utils.debugging import trace_tensor from ..utils.tokenizer import InferenceTokenizer, load_tokenizer @@ -38,9 +39,9 @@ def __init__( self.tokenizer = tokenizer if model.cache.is_paged: self.shared_cache_state = model.cache.paged.allocate(page_cache_size) + self.free_pages = list(range(1, page_cache_size)) else: self.shared_cache_state = None - self.free_pages = list(range(1, 128)) self.end_token = end_token @property @@ -218,6 +219,12 @@ def main(): help="DType to use for activations in the model", default="float32", ) + parser.add_argument( + "--tensor-parallelism-size", + type=int, + default=1, + help="How many devices are involved for tensor parallel sharding.", + ) cli.add_input_dataset_options(parser) cli.add_tokenizer_options(parser) args = cli.parse(parser) @@ -236,7 +243,10 @@ def main(): device=device, activation_dtype=activation_dtype, attention_dtype=activation_dtype, + tensor_parallelism_size=args.tensor_parallelism_size, ) + if config.tensor_parallelism_size > 1: + dataset.root_theta = shard_theta(dataset.root_theta, config) if config.hp.expert_count: model = PagedMixtralModelV1(dataset.root_theta, config) diff --git a/sharktank/sharktank/layers/kv_cache.py b/sharktank/sharktank/layers/kv_cache.py index 47b465bd2..64dba1a69 100644 --- a/sharktank/sharktank/layers/kv_cache.py +++ b/sharktank/sharktank/layers/kv_cache.py @@ -11,7 +11,7 @@ and dims floating around everywhere. """ -from typing import Optional +from typing import Optional, Union import abc import math @@ -19,6 +19,7 @@ import torch from ..utils.debugging import trace_tensor +from ..types import SplitPrimitiveTensor, ReplicatedTensor __all__ = [ "BaseKVCache", @@ -138,6 +139,8 @@ class PagedKVCache(BaseKVCache): Note that the internal page structure matches the organization of the model, allowing contiguous individual local reads and writes at a sub-block granularity if indexing deeply into the structure. + + `shard_count` would split the attn_head_count dimension. """ def __init__( @@ -150,63 +153,102 @@ def __init__( block_seq_stride: int = 16, dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, + shard_count: int = 1, ): self.transformer_block_count = transformer_block_count self.attn_head_count = attn_head_count self.attn_head_dim = attn_head_dim self.cache_partition_count = cache_partition_count self.block_seq_stride = block_seq_stride + self.shard_count = shard_count + if attn_head_count % shard_count != 0: + raise ValueError( + f"The attention head count {attn_head_count} must be a multiple of the tensor parallelism size {shard_count}." + ) # Some derived values based on attributes. self.sub_page_dims = [ self.transformer_block_count, self.cache_partition_count, self.block_seq_stride, - self.attn_head_count, + self.attn_head_count // self.shard_count, self.attn_head_dim, ] self.page_slab_flat_dim = math.prod(self.sub_page_dims) self.device = device self.dtype = dtype - def unflatten_page_table(self, state: list[torch.Tensor]) -> torch.Tensor: + def unflatten_page_table( + self, state: list[Union[torch.Tensor, SplitPrimitiveTensor]] + ) -> Union[torch.Tensor, SplitPrimitiveTensor]: """Unflattens the 2D page table to a 6D tensor.""" assert len(state) == 1, f"Expected 1-element state. 
Got: {len(state)}" page_slab = state[0] - return page_slab.reshape( - [ - -1, - self.transformer_block_count, - self.cache_partition_count, - self.block_seq_stride, - self.attn_head_count, - self.attn_head_dim, + if self.shard_count == 1: + assert not isinstance(page_slab, SplitPrimitiveTensor) + return page_slab.reshape( + [ + -1, + self.transformer_block_count, + self.cache_partition_count, + self.block_seq_stride, + self.attn_head_count, + self.attn_head_dim, + ] + ) + else: + assert self.shard_count == page_slab.shard_count + shards = [ + shard.reshape( + [ + -1, + self.transformer_block_count, + self.cache_partition_count, + self.block_seq_stride, + self.attn_head_count // self.shard_count, + self.attn_head_dim, + ] + ) + for shard in page_slab.shards ] - ) + return SplitPrimitiveTensor(ts=shards, shard_dim=4) @property def pad_sequence_stride(self) -> int: return self.block_seq_stride - def allocate(self, page_count: int) -> list[torch.Tensor]: + def allocate( + self, page_count: int + ) -> list[Union[torch.Tensor, SplitPrimitiveTensor]]: """Allocates tensor state for a page table for the given capacity in pages. """ - return [ - torch.empty( - [page_count, self.page_slab_flat_dim], - dtype=self.dtype, - device=self.device, - ) - ] + if self.shard_count == 1: + return [ + torch.empty( + [page_count, self.page_slab_flat_dim], + dtype=self.dtype, + device=self.device, + ) + ] + else: + shards = [ + torch.empty( + [page_count, self.page_slab_flat_dim], + dtype=self.dtype, + device=self.device, + ) + for _ in range(self.shard_count) + ] + return [SplitPrimitiveTensor(ts=shards, shard_dim=1)] def read( self, - state: list[torch.Tensor], + state: list[Union[torch.Tensor, SplitPrimitiveTensor]], *, - read_into_partitions: list[torch.Tensor], + read_into_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]], transformer_block_index: int, - page_ids: torch.Tensor, + page_ids: Union[torch.Tensor, ReplicatedTensor], ): """Reads cache partitions from the page table for the given page_ids. @@ -231,7 +273,7 @@ def read( bs, block_seq_len, self.block_seq_stride, - self.attn_head_count, + self.attn_head_count // self.shard_count, self.attn_head_dim, ] @@ -249,7 +291,9 @@ def read( transformer_block_index * transformer_block_stride ) - def read_cache_partition(index: int, into_partition: torch.Tensor): + def read_cache_partition( + index: int, into_partition: Union[torch.Tensor, SplitPrimitiveTensor] + ): subblock_ids = ( (base_subblock_ids + index) if index > 0 else base_subblock_ids ) @@ -274,15 +318,15 @@ def read_cache_partition(index: int, into_partition: torch.Tensor): def write_timestep( self, - state: list[torch.Tensor], + state: list[Union[torch.Tensor, SplitPrimitiveTensor]], # List of [bs, 1, attn_head_count, attn_head_dim] - cache_partitions: list[torch.Tensor], + cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]], *, transformer_block_index: int, # [bs] - seq_positions: torch.Tensor, + seq_positions: Union[torch.Tensor, ReplicatedTensor], # [bs, max_seqlen // block_pos_stride] - page_ids: torch.Tensor, + page_ids: Union[torch.Tensor, ReplicatedTensor], ): """Writes a single batched timestep across all cache partitions. 
@@ -311,11 +355,11 @@ def write_timestep( def write( self, - state: list[torch.Tensor], - cache_partitions: list[torch.Tensor], + state: list[Union[torch.Tensor, SplitPrimitiveTensor]], + cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]], *, transformer_block_index: int, - page_ids: torch.Tensor, + page_ids: Union[torch.Tensor, ReplicatedTensor], ): """Writes cache partitions from a linear layout to the page table. @@ -348,7 +392,9 @@ def write( transformer_block_index * transformer_block_stride ) - def write_cache_partition(index: int, part: torch.Tensor): + def write_cache_partition( + index: int, part: Union[torch.Tensor, SplitPrimitiveTensor] + ): part_block_view = part.reshape(blocked_shape) subblock_ids = ( (base_subblock_ids + index) if index > 0 else base_subblock_ids diff --git a/sharktank/sharktank/layers/norm.py b/sharktank/sharktank/layers/norm.py index d062f1ffb..4fa08050a 100644 --- a/sharktank/sharktank/layers/norm.py +++ b/sharktank/sharktank/layers/norm.py @@ -33,9 +33,9 @@ def __init__( def forward(self, x: torch.Tensor): orig_dtype = x.dtype - x = x.to(self.dtype) + x = ops.to(x, self.dtype) norm = ops.rms_norm(x, self.weight, epsilon=self.epsilon) # Will automatically upcast to the dtype of the weight, which is # often in higher precision. Downcast back to expected. - norm = norm.to(orig_dtype) + norm = ops.to(norm, orig_dtype) return norm diff --git a/sharktank/sharktank/layers/paged_llama_attention_block.py b/sharktank/sharktank/layers/paged_llama_attention_block.py index 473c7cb78..5ffd9999d 100644 --- a/sharktank/sharktank/layers/paged_llama_attention_block.py +++ b/sharktank/sharktank/layers/paged_llama_attention_block.py @@ -16,6 +16,7 @@ from .norm import RMSNormLayer from .rotary_embedding import RotaryEmbeddingLayer from .kv_cache import PagedKVCache +from .. import ops __all__ = [ "PagedLlamaAttentionBlock", @@ -140,7 +141,7 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor: values = xv.transpose(1, 2) # Flash attention. - attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) + attn_weights = ops.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) self.assert_not_nan(attn_weights) # Apply attention mask. @@ -149,8 +150,9 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor: # self.trace_tensor("attn_mask", attention_mask) attn_weights = attn_weights + attention_mask - attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(xq) - attn_output = torch.matmul(attn_weights, values) # (bs, heads, slen, head_dim) + attn_weights = ops.softmax(ops.to(attn_weights, dtype=torch.float32), dim=-1) + attn_weights = ops.to(attn_weights, dtype=xq.dtype) + attn_output = ops.matmul(attn_weights, values) # (bs, heads, slen, head_dim) attn_output = attn_output.transpose(1, 2).reshape(bs, batch_seq_len, -1) # Project. diff --git a/sharktank/sharktank/layers/rotary_embedding.py b/sharktank/sharktank/layers/rotary_embedding.py index a5f0eed09..232dc2e1d 100644 --- a/sharktank/sharktank/layers/rotary_embedding.py +++ b/sharktank/sharktank/layers/rotary_embedding.py @@ -9,6 +9,8 @@ import torch from .base import BaseLayer +from .. import ops +from ..types import SplitPrimitiveTensor, ReplicatedTensor, unbox_tensor class RotaryEmbeddingLayer(BaseLayer): @@ -23,6 +25,7 @@ def __init__( device: Optional[torch.device] = None, use_hf: bool = False, static_tables: bool = True, + tensor_parallelism_size: int = 1, ): super().__init__() # Force static_tables until compiler limitations are solved. 
@@ -33,9 +36,10 @@ def __init__( self.max_seqlen = max_seqlen self.use_hf = use_hf self.rope_freq_base = rope_freq_base if rope_freq_base is not None else 10000.0 + self.tensor_parallelism_size = tensor_parallelism_size if static_tables: - self.register_buffer( - "static_rotary_embed_table", self._create_rotary_embed_table() + ops.module_register_buffer( + self, "static_rotary_embed_table", self._create_rotary_embed_table() ) else: self.static_rotary_embed_table = None @@ -48,6 +52,48 @@ def rotary_embed_table(self): return self.static_rotary_embed_table def forward(self, *, xq: torch.Tensor, xk: torch.Tensor, start_index: int): + if isinstance(xq, SplitPrimitiveTensor): + assert ( + isinstance(xk, SplitPrimitiveTensor) + and xq.shard_count == xk.shard_count + and xk.shard_dim == xq.shard_dim + ) + assert ( + isinstance(self.rotary_embed_table, ReplicatedTensor) + and xq.shard_count == self.rotary_embed_table.shard_count + ) + xqk_shards = [ + self.forward_unsharded( + xq=unbox_tensor(xq_shard), + xk=unbox_tensor(xk_shard), + start_index=start_index, + rotary_embed_table=unbox_tensor(rotary_embed_table_shard), + ) + for xq_shard, xk_shard, rotary_embed_table_shard in zip( + xq.shards, xk.shards, self.rotary_embed_table.shards + ) + ] + xq_shards = [xqk[0] for xqk in xqk_shards] + xk_shards = [xqk[1] for xqk in xqk_shards] + xq = SplitPrimitiveTensor(ts=xq_shards, shard_dim=xq.shard_dim) + xk = SplitPrimitiveTensor(ts=xk_shards, shard_dim=xk.shard_dim) + return xq, xk + else: + return self.forward_unsharded( + xq=xq, + xk=xk, + start_index=start_index, + rotary_embed_table=self.rotary_embed_table, + ) + + def forward_unsharded( + self, + *, + xq: torch.Tensor, + xk: torch.Tensor, + start_index: int, + rotary_embed_table: torch.Tensor, + ): # xq_, xk_ shape: bs, sl, _, dim # freqs_cis shape: max_sl, dim @@ -97,7 +143,7 @@ def create_ordering_tensor(dim): _, sl, _, dim = xq_.shape # Offset the table based on starting position. - freqs_cis = self.rotary_embed_table[start_index : start_index + sl, :] + freqs_cis = rotary_embed_table[start_index : start_index + sl, :] assert freqs_cis.shape[-1] == dim assert ( freqs_cis.shape[0] >= sl @@ -200,4 +246,8 @@ def _create_rotary_embed_table( else torch.polar(torch.ones_like(freqs), freqs) ) + if self.tensor_parallelism_size > 1: + # Replicate across all devices, the data is not a lot and the computation is cheap. + freqs_cis = ops.replicate(freqs_cis, self.tensor_parallelism_size) + return freqs_cis diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py index 8266872ad..528067d33 100644 --- a/sharktank/sharktank/models/llama/llama.py +++ b/sharktank/sharktank/models/llama/llama.py @@ -8,6 +8,7 @@ from dataclasses import dataclass import math +from typing import Union import torch import torch.nn as nn @@ -60,6 +61,11 @@ class LlamaModelConfig: # the program and not. static_tables: bool = True + # How many devices are involved for tensor parallel sharding. + # If greater than 1, the model will expect sharded model parameters and function + # arguments. 
+ tensor_parallelism_size: int = 1 + def create_kv_cache(self) -> BaseKVCache: hp = self.hp if self.kv_cache_type == "direct": @@ -81,6 +87,7 @@ def create_kv_cache(self) -> BaseKVCache: block_seq_stride=self.block_seq_stride, device=self.device, dtype=self.attention_dtype, + shard_count=self.tensor_parallelism_size, ) else: raise NotImplementedError(f"kv_cache_type = {self.kv_cache_type}") @@ -111,6 +118,13 @@ class PagedLlamaModelV1(BaseCausalLMModel): to be serviced. Various samplers and schedulers can be interleaved throughout. + + In the case of tensor sharding (config.tensor_parallelism_size > 1) the model's KV + cache head dimension is sharded. + The number of KV cache heads must be divisible by the parallelism size. + With this sharding approach the KV cache is not replicated across devices. + The cache is split across the devices while the indexing logic/computation is + replicated. """ def __init__(self, theta: Theta, config: LlamaModelConfig): @@ -142,6 +156,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): device=self.device, use_hf=self.use_hf, static_tables=config.static_tables, + tensor_parallelism_size=config.tensor_parallelism_size, ), ) self.add_module( @@ -171,18 +186,31 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): def prefill( self, # [bs, batch_seq_len] - tokens: torch.Tensor, + tokens: Union[torch.Tensor, ReplicatedTensor], *, # [1, 1, batch_seq_len, batch_seq_len] - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, ReplicatedTensor], # [bs, batch_seq_len // block_seq_stride] - seq_block_ids: torch.Tensor, - cache_state: list[torch.Tensor], + seq_block_ids: Union[torch.Tensor, ReplicatedTensor], + cache_state: list[Union[torch.Tensor, SplitPrimitiveTensor]], ): self._assert_device(tokens) self._assert_device(attention_mask, dtype=self.activation_dtype) self._assert_device(seq_block_ids) self._assert_device(*cache_state, dtype=self.activation_dtype) + if self.config.tensor_parallelism_size > 1: + if not isinstance(tokens, ReplicatedTensor): + tokens = ops.replicate( + tokens, count=self.config.tensor_parallelism_size + ) + if not isinstance(attention_mask, ReplicatedTensor): + attention_mask = ops.replicate( + attention_mask, count=self.config.tensor_parallelism_size + ) + if not isinstance(seq_block_ids, ReplicatedTensor): + seq_block_ids = ops.replicate( + seq_block_ids, count=self.config.tensor_parallelism_size + ) h = self.token_embedding(tokens) self.trace_tensor("llama.token_embedding", h) @@ -207,20 +235,37 @@ def prefill( def decode( self, # [bs, 1] - tokens: torch.Tensor, + tokens: Union[torch.Tensor, ReplicatedTensor], *, # [bs, 1, 1, batch_seq_len] - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, ReplicatedTensor], # [bs] of starting positions - start_positions: torch.Tensor, + start_positions: Union[torch.Tensor, ReplicatedTensor], # [bs, batch_seq_len // block_seq_stride] - seq_block_ids: torch.Tensor, - cache_state: list[torch.Tensor], + seq_block_ids: Union[torch.Tensor, ReplicatedTensor], + cache_state: list[Union[torch.Tensor, SplitPrimitiveTensor]], ): self._assert_device(tokens) self._assert_device(attention_mask, dtype=self.activation_dtype) self._assert_device(start_positions) self._assert_device(*cache_state, dtype=self.activation_dtype) + if self.config.tensor_parallelism_size > 1: + if not isinstance(tokens, ReplicatedTensor): + tokens = ops.replicate( + tokens, count=self.config.tensor_parallelism_size + ) + if not isinstance(attention_mask, ReplicatedTensor): + 
attention_mask = ops.replicate( + attention_mask, count=self.config.tensor_parallelism_size + ) + if not isinstance(start_positions, ReplicatedTensor): + start_positions = ops.replicate( + start_positions, count=self.config.tensor_parallelism_size + ) + if not isinstance(seq_block_ids, ReplicatedTensor): + seq_block_ids = ops.replicate( + seq_block_ids, count=self.config.tensor_parallelism_size + ) bs, _ = tokens.shape # Precompute a position based mask for computing rope embeddings # as it is the same for all blocks. @@ -252,6 +297,10 @@ def decode( device=self.device, ) + if self.config.tensor_parallelism_size > 1: + tokens = ops.replicate( + tokens, shard_count=self.config.tensor_parallelism_size + ) h = self.token_embedding(tokens) self.trace_tensor("llama.token_embedding", h) @@ -324,7 +373,7 @@ def __init__( def forward( self, - h: torch.Tensor, + h: AnyTensor, *, embedding: RotaryEmbeddingLayer, # [bs, batch_seq_len // block_seq_stride] diff --git a/sharktank/sharktank/models/llama/sharding.py b/sharktank/sharktank/models/llama/sharding.py new file mode 100644 index 000000000..202bdc0d7 --- /dev/null +++ b/sharktank/sharktank/models/llama/sharding.py @@ -0,0 +1,114 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +"""Specifications describing how blocks/layers of llama are sharded.""" + +from ...types.sharding import * +from ...types import Theta +from ... import ops + + +class PagedLlamaAttentionBlockSharding(ThetaLayerSharding): + def __init__(self, shard_count: int): + super().__init__() + self.shard_count = shard_count + + def theta_sharding(self) -> ThetaSharding: + return ThetaSharding( + { + # The size of this is the token embedding length, which is not a memory + # space concern if replicated even for all attention blocks. + "attn_norm": RmsNormReplicatedSharding( + self.shard_count + ).theta_sharding(), + "attn_q": LinearReplicatedInputSplitWeightAndBiasSharding( + shard_count=self.shard_count + ).theta_sharding(), + "attn_k": LinearReplicatedInputSplitWeightAndBiasSharding( + shard_count=self.shard_count + ).theta_sharding(), + "attn_v": LinearReplicatedInputSplitWeightAndBiasSharding( + shard_count=self.shard_count + ).theta_sharding(), + "attn_output": LinearSplitReductionDimSharding( + shard_count=self.shard_count + ).theta_sharding(), + } + ) + + +class AttentionFFNBlockSharding(ThetaLayerSharding): + def __init__(self, shard_count: int): + super().__init__() + self.shard_count = shard_count + + def theta_sharding(self) -> ThetaSharding: + result = PagedLlamaAttentionBlockSharding(self.shard_count).theta_sharding() + result.update(FFNSharding(self.shard_count).theta_sharding()) + result.update( + { + # The size of this is the token embedding length, which is not a memory + # space concern if replicated. + "ffn_norm": RmsNormReplicatedSharding(self.shard_count).theta_sharding() + } + ) + return result + + +class LlamaSharding(ThetaLayerSharding): + """Shards the input channel and output channels of the convolutions.""" + + def __init__(self, shard_count: int, attention_block_count: int): + super().__init__() + self.shard_count = shard_count + self.attention_block_count = attention_block_count + + def theta_sharding(self) -> ThetaSharding: + result = ThetaSharding( + { + # Replicate the vocabulary. For llama 1-3 this will require 0.5 GiB. 
+ # For devices with large memory this may be an acceptable tradeoff where + # we save on communication by not all-gathering the result afterwards. + # The computation is just indexing and replication is not a concern. + # Alternatively, we can try splitting the index dimension, + # this would require custom logic for indexing partitioning and gathering. + "token_embd": TokenEmbeddingLayerReplicatedSharding( + self.shard_count + ).theta_sharding(), + "rope_freqs": Ignore(), + "output_norm": RmsNormReplicatedSharding( + self.shard_count + ).theta_sharding(), + "output": LinearSplitReductionDimSharding( + self.shard_count + ).theta_sharding(), + } + ) + result.update( + { + "blk": ThetaSharding( + { + f"{i}": AttentionFFNBlockSharding( + self.shard_count + ).theta_sharding() + for i in range(self.attention_block_count) + } + ) + } + ) + return result + + +def shard_theta( + theta: Theta, config: "sharktank.models.llama.llama.LlamaModelConfig" +) -> Theta: + return ops.reshard( + theta, + LlamaSharding( + shard_count=config.tensor_parallelism_size, + attention_block_count=config.hp.block_count, + ), + ) diff --git a/sharktank/sharktank/models/llama/testing.py b/sharktank/sharktank/models/llama/testing.py index b63fd5d07..ffe11928f 100644 --- a/sharktank/sharktank/models/llama/testing.py +++ b/sharktank/sharktank/models/llama/testing.py @@ -10,6 +10,9 @@ from ...types.tensors import * from ...types.theta import Theta +from typing import Optional +from .llama import LlamaModelConfig +import torch # Range of torch.rand() is [0,1) @@ -56,6 +59,60 @@ def make_attention_block_theta( ) +def make_attention_block_theta_v2( + *, + head_count: int, + head_count_kv: int, + head_dim: int, + embedding_length: int, + feed_forward_length: int, + dtype: torch.dtype | None = None, +) -> Theta: + return Theta( + { + "attn_q.weight": DefaultPrimitiveTensor( + data=make_rand_torch( + (head_count * head_dim, embedding_length), dtype=dtype + ) + ), + "attn_k.weight": DefaultPrimitiveTensor( + data=make_rand_torch( + (head_count_kv * head_dim, embedding_length), dtype=dtype + ) + ), + "attn_v.weight": DefaultPrimitiveTensor( + data=make_rand_torch( + (head_count_kv * head_dim, embedding_length), dtype=dtype + ) + ), + "attn_output.weight": DefaultPrimitiveTensor( + data=make_rand_torch((embedding_length, embedding_length), dtype=dtype) + ), + "attn_norm.weight": DefaultPrimitiveTensor( + data=make_rand_torch((embedding_length), dtype=dtype) + ), + "ffn_norm.weight": DefaultPrimitiveTensor( + data=make_rand_torch((head_count * head_dim), dtype=dtype) + ), + "ffn_gate.weight": DefaultPrimitiveTensor( + data=make_rand_torch( + (feed_forward_length, embedding_length), dtype=dtype + ) + ), + "ffn_up.weight": DefaultPrimitiveTensor( + data=make_rand_torch( + (feed_forward_length, embedding_length), dtype=dtype + ) + ), + "ffn_down.weight": DefaultPrimitiveTensor( + data=make_rand_torch( + (embedding_length, feed_forward_length), dtype=dtype + ) + ), + } + ) + + def make_moe_block_theta(feature_dim=1024, ffn_dim=6144, num_experts=8) -> Theta: return Theta( { @@ -79,3 +136,31 @@ def make_moe_block_theta(feature_dim=1024, ffn_dim=6144, num_experts=8) -> Theta ), } ) + + +def make_random_llama_theta( + config: LlamaModelConfig, vocab_size: int, dtype: Optional[torch.dtype] = None +) -> Theta: + res = { + "token_embd.weight": DefaultPrimitiveTensor( + data=make_rand_torch((vocab_size, config.hp.embedding_length), dtype=dtype) + ) + } + for i in range(config.hp.block_count): + res[f"blk.{i}"] = make_attention_block_theta_v2( + 
head_count=config.hp.attention_head_count, + head_count_kv=config.hp.attention_head_count_kv, + head_dim=config.hp.attn_head_dim, + embedding_length=config.hp.embedding_length, + feed_forward_length=config.hp.feed_forward_length, + dtype=dtype, + ).tree + + res[f"output.weight"] = DefaultPrimitiveTensor( + data=make_rand_torch((vocab_size, config.hp.embedding_length), dtype=dtype) + ) + res[f"output_norm.weight"] = DefaultPrimitiveTensor( + data=make_rand_torch((1, config.hp.embedding_length), dtype=dtype) + ) + + return Theta(res) diff --git a/sharktank/sharktank/models/llama/theta.py b/sharktank/sharktank/models/llama/theta.py new file mode 100644 index 000000000..19861d808 --- /dev/null +++ b/sharktank/sharktank/models/llama/theta.py @@ -0,0 +1,10 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Optional +from .llama import LlamaModelConfig +from ...types import Theta, DefaultPrimitiveTensor +import torch diff --git a/sharktank/sharktank/models/punet/sharding.py b/sharktank/sharktank/models/punet/sharding.py index 22237b8bb..cb4fc6ff3 100644 --- a/sharktank/sharktank/models/punet/sharding.py +++ b/sharktank/sharktank/models/punet/sharding.py @@ -4,7 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -"""Specifications describing how block/layers of punet are sharded.""" +"""Specifications describing how blocks/layers of punet are sharded.""" from ...types.sharding import * diff --git a/sharktank/sharktank/ops/default_impls.py b/sharktank/sharktank/ops/default_impls.py index 0ab6053c2..83cca9c9c 100644 --- a/sharktank/sharktank/ops/default_impls.py +++ b/sharktank/sharktank/ops/default_impls.py @@ -64,9 +64,9 @@ def conv2d_default( # Elementwise @elementwise.override(Tensor) -def elementwise_unary(operator, x): +def elementwise_unary(operator, x, *args, **kwargs): x = unbox_tensor(x) - return operator(x) + return operator(x, *args, **kwargs) @elementwise.override( @@ -74,11 +74,11 @@ def elementwise_unary(operator, x): IsOfType(Tensor, PrimitiveTensor), IsOfType(Tensor, PrimitiveTensor, Number) ) ) -def elementwise_binary(operator, x, y): +def elementwise_binary(operator, x, y, *args, **kwargs): x = unbox_tensor(x) if isinstance(y, PrimitiveTensor): y = unbox_tensor(y) - return operator(x, y) + return operator(x, y, *args, **kwargs) @elementwise.override( @@ -133,6 +133,18 @@ def equal_default(a, b) -> bool: return torch.equal(unbox_tensor(a), unbox_tensor(b)) +@expand.override(Tensor) +def expand_default(tensor: AnyTensor, shape: List[int]) -> AnyTensor: + return unbox_tensor(tensor).expand(*shape) + + +@flatten.override(Tensor) +def flatten_default( + input: Union[Tensor, PrimitiveTensor], start_dim: int, end_dim: int +) -> Tensor: + return torch.flatten(unbox_tensor(input), start_dim, end_dim) + + @gemm.override(AllOfType(Tensor, InferenceTensor)) def gemm( a: AnyTensor, @@ -167,6 +179,17 @@ def group_norm_affine_default(input, weight, bias, *, num_groups, eps): return F.group_norm(input, num_groups=num_groups, weight=weight, bias=bias, eps=eps) +@index_copy_.override(Tensor, Tensor, Tensor) +def index_copy__default( + inout: Union[Tensor, PrimitiveTensor], + dim: int, + index: Union[Tensor, PrimitiveTensor], + tensor: Union[Tensor, PrimitiveTensor], +) -> Union[Tensor, PrimitiveTensor]: + 
unbox_tensor(inout).index_copy_(dim, unbox_tensor(index), unbox_tensor(tensor)) + return inout + + @interpolate.override(Tensor) def interpolate_default( input: Tensor, @@ -242,13 +265,30 @@ def scaled_dot_product_attention(q, k, v, a) -> Tensor: ) +@mean.override(Tensor) +def mean_default( + x: Tensor, dim: Union[int, List[int]], keepdim: bool, *, dtype: torch.dtype +) -> None: + return torch.mean(unbox_tensor(x), dim=dim, keepdim=keepdim, dtype=dtype) + + +@module_register_buffer.override(torch.nn.Module, Tensor) +def module_register_buffer_default( + module: torch.nn.Module, name: str, tensor: Union[Tensor, InferenceTensor] +) -> None: + return module.register_buffer(name, unbox_tensor(tensor)) + + +@reshape.override(Tensor) +def reshape_default(input: Union[PrimitiveTensor, Tensor], shape: List[int]) -> Tensor: + return torch.reshape(unbox_tensor(input), shape) + + # RMS norm -@rms_norm.override(Tensor, Tensor) +@rms_norm.override(AllOfType(Tensor, InferenceTensor)) def rms_norm_default(x, weight, *, epsilon: float) -> Tensor: - x = unbox_tensor(x) - weight = unbox_tensor(weight) variance = x.pow(2).mean(-1, keepdim=True) - output = x * torch.rsqrt(variance + epsilon) + output = x * elementwise(torch.rsqrt, variance + epsilon) output = output * weight return output @@ -268,6 +308,20 @@ def permute(tensor: Tensor, dims: List[int]): return torch.permute(torch_tensor, dims) +@softmax.override(Tensor) +def softmax_default( + tensor: Union[Tensor, PrimitiveTensor], + dim: Optional[int], + dtype: Optional[torch.dtype], +) -> Tensor: + return F.softmax(unbox_tensor(tensor), dim=dim, dtype=dtype) + + +@to.override(Tensor) +def to_default(tensor: Tensor, *args, **kwargs): + return unbox_tensor(tensor).to(*args, **kwargs) + + @transfer_to_logical_device.override(Tensor) def transfer_to_logical_device_default(tensor: Tensor, ordinal: int): return shark_turbine.ops.iree.transfer_to_logical_device( @@ -275,6 +329,13 @@ def transfer_to_logical_device_default(tensor: Tensor, ordinal: int): ) +@transpose.override(Tensor) +def transpose_default( + tensor: Union[Tensor, PrimitiveTensor], dim0: int, dim1: int +) -> Tensor: + return torch.transpose(unbox_tensor(tensor), dim0, dim1) + + # Sharded default impls (do nothing). @@ -286,3 +347,13 @@ def sharded_cat_unsharded(maybe_sharded): @sharded_sum.override(Tensor) def sharded_sum_unsharded(maybe_sharded): return unbox_tensor(maybe_sharded) + + +@unsqueeze.override(Tensor) +def unsqueeze_default(tensor: Union[Tensor, PrimitiveTensor], dim: int) -> Tensor: + return torch.unsqueeze(tensor, dim) + + +@view.override(Tensor) +def view_default(tensor: Union[Tensor, PrimitiveTensor], shape: List[int]) -> Tensor: + return unbox_tensor(tensor).view(*shape) diff --git a/sharktank/sharktank/ops/shape.py b/sharktank/sharktank/ops/shape.py index 69683f84f..76a7b1f73 100644 --- a/sharktank/sharktank/ops/shape.py +++ b/sharktank/sharktank/ops/shape.py @@ -4,10 +4,33 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import Sequence +from typing import Sequence, Tuple, Optional from ..types.tensors import AnyTensor +def broadcast_shape_dims_map(*shapes: Tuple[Sequence[int]]) -> Tuple[Sequence[int]]: + """Return the correspondence between the broadcasted resulting shape and the + shapes. + For each argument shape return a list of indexes that are the dimensions in the + resulting shape. 
+ + example: + ``` + broadcast_shape_dims_map([2, 1, 4, 5], [3, 4, 5], [1, 1, 3, 4, 1]) + ``` + result + ``` + ([1, 2, 3, 4], [2, 3, 4], [0, 1, 2, 3, 4]) + ``` + The result shape is `[1, 2, 3, 4, 5]`. + """ + ranks = [len(shape) for shape in shapes] + broadcast_rank = max(ranks) + return tuple( + [dim + max(0, broadcast_rank - len(s)) for dim in range(len(s))] for s in shapes + ) + + def broadcast_dim( dim: int, shaped_or_shape: Sequence[Sequence[int] | AnyTensor] ) -> int: @@ -53,3 +76,30 @@ def broadcast_dims( ranks = [len(shape) for shape in shaped_or_shape] broadcast_rank = max(ranks) return [dim + max(0, broadcast_rank - rank) for dim, rank in zip(dims, ranks)] + + +def unbroadcast_dim(dim: int, shapes: Sequence[Sequence[int]]) -> Optional[int]: + """Returns the dimension in `shapes[0]` such that it would correspond to `dim` + after broadcasting the shapes `shapes`.""" + ranks = [len(shape) for shape in shapes] + broadcast_rank = max(ranks) + res = dim - max(0, broadcast_rank - ranks[0]) + return None if res < 0 else res + + +# def matmul_result_shape_to_args_dim_map(lhs_shape: Sequence[int], rhs_shape: Sequence[int]) -> Tuple[Sequence[int], Sequence[int]]: +# """Return the correspondence between the resulting dimensions and its +# arguments for a matmul operation. +# -1 for reduction or missing dimensions. + +# example: +# ``` +# matmul_result_shape_to_args_dim_map([2, 1, 4, 5], [3, 5, 6]) +# ``` +# result +# ``` +# ([0, 1, 2, -1], [1, -1, 3]) +# ``` +# The result shape is `[2, 3, 4, 6]`. +# """ +# assert False, "TODO" diff --git a/sharktank/sharktank/ops/sharded_impls.py b/sharktank/sharktank/ops/sharded_impls.py index 6c0c0b39f..ea9a97eb3 100644 --- a/sharktank/sharktank/ops/sharded_impls.py +++ b/sharktank/sharktank/ops/sharded_impls.py @@ -6,9 +6,11 @@ import torch from torch import Tensor -from typing import List, Optional, Sequence +from typing import List, Optional, Sequence, Union, Any, Tuple import itertools from numbers import Number +import math +import numpy as np from ..types import ( AnyTensor, @@ -24,7 +26,8 @@ from ..types.tensors import unbox_tensor from ._registry import AllOfType from .signatures import * -from .shape import broadcast_dims +from .shape import broadcast_dims, broadcast_dim, unbroadcast_dim +from ..utils import longest_equal_range @all_gather.override(SplitPrimitiveTensor) @@ -241,49 +244,67 @@ def conv2d_split_weight_and_bias( @elementwise.override(ReplicatedTensor) -def replicated_elementwise_unary(operator, x: ReplicatedTensor): - partials = [operator(unbox_tensor(pt)) for pt in x.shards] +def replicated_elementwise_unary(operator, x: ReplicatedTensor, *args, **kwargs): + partials = [operator(unbox_tensor(pt), *args, **kwargs) for pt in x.shards] return ReplicatedTensor(ts=partials) @elementwise.override(SplitPrimitiveTensor) -def split_elementwise_unary(operator, x: SplitPrimitiveTensor): - partials = [operator(unbox_tensor(pt)) for pt in x.shards] +def split_elementwise_unary(operator, x: SplitPrimitiveTensor, *args, **kwargs): + partials = [operator(unbox_tensor(pt), *args, **kwargs) for pt in x.shards] return SplitPrimitiveTensor(shard_dim=x.shard_dim, shape=x.shape, ts=partials) +@elementwise.override(ReplicatedTensor, ReplicatedTensor) +def replicated_elementwise_binary( + operator, x: ReplicatedTensor, y: ReplicatedTensor, *args, **kwargs +): + assert x.shard_count == y.shard_count + shards = [ + operator(unbox_tensor(shard_x), unbox_tensor(shard_y), *args, **kwargs) + for shard_x, shard_y in zip(x.shards, y.shards) + ] + return 
ReplicatedTensor(ts=shards) + + @elementwise.override(SplitPrimitiveTensor, SplitPrimitiveTensor) def split_elementwise_binary( - operator, x: SplitPrimitiveTensor, y: SplitPrimitiveTensor + operator, x: SplitPrimitiveTensor, y: SplitPrimitiveTensor, *args, **kwargs ): assert x.shard_count == y.shard_count x_shard_dim, y_shard_dim = broadcast_dims([x.shard_dim, y.shard_dim], [x, y]) assert x_shard_dim == y_shard_dim pt_xs = [unbox_tensor(pt) for pt in x.shards] pt_ys = [unbox_tensor(pt) for pt in y.shards] - partials = [operator(pt_x, pt_y) for pt_x, pt_y in zip(pt_xs, pt_ys)] - return SplitPrimitiveTensor(shard_dim=x.shard_dim, shape=x.shape, ts=partials) + partials = [ + operator(pt_x, pt_y, *args, **kwargs) for pt_x, pt_y in zip(pt_xs, pt_ys) + ] + return SplitPrimitiveTensor( + shard_dim=x.shard_dim, + shape=torch.broadcast_shapes(x.shape, y.shape), + ts=partials, + ) @elementwise.override(SplitPrimitiveTensor, Number) def elementwise_binary_split_lhs_scalar_rhs( - operator, x: SplitPrimitiveTensor, y: Number + operator, x: SplitPrimitiveTensor, y: Number, *args, **kwargs ): pt_xs = [unbox_tensor(pt) for pt in x.shards] - partials = [operator(pt_x, y) for pt_x in pt_xs] + partials = [operator(pt_x, y, *args, **kwargs) for pt_x in pt_xs] return SplitPrimitiveTensor(shard_dim=x.shard_dim, shape=x.shape, ts=partials) @elementwise.override(SplitPrimitiveTensor, Tensor) def elementwise_binary_split_lhs_tensor_rhs( - operator, x: SplitPrimitiveTensor, y: Tensor + operator, x: SplitPrimitiveTensor, y: Tensor, *args, **kwargs ): - return elementwise(operator, x, replicate(y, count=x.shard_count)) + return elementwise(operator, x, replicate(y, count=x.shard_count), *args, **kwargs) @elementwise.override(ReplicatedTensor, SplitPrimitiveTensor) def elementwise_binary_replicated_lhs_sharder_rhs( - operator, x: ReplicatedTensor, y: SplitPrimitiveTensor + operator, x: ReplicatedTensor, y: SplitPrimitiveTensor, *args, **kwargs ): if x.shard_count != y.shard_count: raise ValueError( @@ -292,20 +313,48 @@ def elementwise_binary_replicated_lhs_sharder_rhs( # A replicated tensor can be split with no cost. # It is natural to propagate the split instead of the replication. 
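+    # E.g. adding a replicated bias to an activation that is split along the
+    # batch dimension keeps the result split along the batch dimension.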
x_sharded = reshard_like(x, like=y) - return elementwise(operator, x_sharded, y) + return elementwise(operator, x_sharded, y, *args, **kwargs) @elementwise.override(SplitPrimitiveTensor, ReplicatedTensor) def elementwise_binary_split_lhs_replicated_rhs( - operator, x: SplitPrimitiveTensor, y: ReplicatedTensor + operator, x: SplitPrimitiveTensor, y: ReplicatedTensor, *args, **kwargs ): assert len(y.shape) > 0, "0-rank not supported" if x.shard_count != y.shard_count: raise ValueError( f"Operands' number of shards not equal ({x.shard_count} != {y.shard_count})" ) + + shard_dim_in_res = broadcast_dim(x.shard_dim, [x.shape, y.shape]) + shard_dim_in_y = unbroadcast_dim(shard_dim_in_res, [y.shape, x.shape]) + is_shard_dim_broadcasted_in_y = ( + shard_dim_in_y is None or y.shape[shard_dim_in_y] == 1 + ) + if is_shard_dim_broadcasted_in_y: + shards = [ + elementwise(operator, x_shard, y_shard) + for x_shard, y_shard in zip(x.shards, y.shards) + ] + return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim_in_res) + y_sharded = reshard_like(y, like=x) - return elementwise(operator, x, y_sharded) + return elementwise(operator, x, y_sharded, *args, **kwargs) + + +# Embedding Lookup +@embedding_lookup.override(ReplicatedTensor, ReplicatedTensor) +def embedding_lookup_default( + input: ReplicatedTensor, embedding_matrix: ReplicatedTensor, dtype: torch.dtype +): + assert input.shard_count == embedding_matrix.shard_count + shards = [ + embedding_lookup(input_shard, embedding_matrix_shard, dtype) + for input_shard, embedding_matrix_shard in zip( + input.shards, embedding_matrix.shards + ) + ] + return ReplicatedTensor(ts=shards) @equal.override(ReplicatedTensor) @@ -318,6 +367,61 @@ def equal_split(a: SplitPrimitiveTensor, b: AnyTensor) -> bool: return a.is_deep_equal(b) +@expand.override(SplitPrimitiveTensor) +def expand_split( + tensor: SplitPrimitiveTensor, shape: List[int] +) -> SplitPrimitiveTensor: + assert len(shape) == len(tensor.shape) + expanded_dims = [ + i + for i, (old_dim, new_dim) in enumerate(zip(tensor.shape, shape)) + if old_dim == 1 and new_dim != 1 + ] + assert ( + tensor.shard_dim not in expanded_dims + ), "Expanding a split dimension is not supported" + + def set_element(l: List, idx: int, el: Any) -> List: + l[idx] = el + return l + + shards = [ + expand( + shard, + set_element(list(shape), tensor.shard_dim, shard.shape[tensor.shard_dim]), + ) + for shard in tensor.shards + ] + return SplitPrimitiveTensor(ts=shards, shard_dim=tensor.shard_dim) + + +@flatten.override(ReplicatedTensor) +def flatten_replicated( + input: ReplicatedTensor, start_dim: int, end_dim: int +) -> ReplicatedTensor: + shards = [shard.flatten(start_dim, end_dim) for shard in input.shards] + return ReplicatedTensor(ts=shards) + + +@flatten.override(SplitPrimitiveTensor) +def flatten_split( + input: SplitPrimitiveTensor, start_dim: int, end_dim: int +) -> SplitPrimitiveTensor: + end_dim_resolved = len(input.shape) - 1 if end_dim == -1 else end_dim + assert input.shard_dim <= start_dim or end_dim_resolved < input.shard_dim, ( + "Flattening of a sharded dimension that is not the leading dimension in the" + " flattening dimension range is not supported. This would result in a" + " block-cyclic sharding which is not implemented." 
+ ) + shards = [shard.flatten(start_dim, end_dim) for shard in input.shards] + shard_dim = ( + input.shard_dim + if input.shard_dim <= start_dim + else input.shard_dim - (end_dim_resolved - start_dim) + ) + return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim) + + @group_norm_affine.override( SplitPrimitiveTensor, SplitPrimitiveTensor, SplitPrimitiveTensor ) @@ -338,6 +442,24 @@ def shareded_group_norm_affine(input, weight, bias, *, num_groups, eps): return SplitPrimitiveTensor(shard_dim=1, ts=result_shards) +@index_copy_.override(SplitPrimitiveTensor, ReplicatedTensor, SplitPrimitiveTensor) +def index_copy__split_replicated_split( + inout: SplitPrimitiveTensor, + dim: int, + index: ReplicatedTensor, + tensor: SplitPrimitiveTensor, +) -> SplitPrimitiveTensor: + assert ( + inout.shard_count == index.shard_count + and inout.shard_count == tensor.shard_count + ) + assert inout.shard_dim == tensor.shard_dim + assert inout.shard_dim != dim + for inout_shard, index_shard, tensor_shard in zip(inout, index, tensor): + index_copy_(inout_shard, dim, index_shard, tensor_shard) + return inout + + @interpolate.override(ReplicatedTensor) def interpolate_replicated( input: ReplicatedTensor, @@ -432,6 +554,7 @@ def matmul_replicated_lhs_split_rhs( lhs: ReplicatedTensor, rhs: SplitPrimitiveTensor, *, transpose_rhs: bool ) -> SplitPrimitiveTensor | UnreducedTensor: assert lhs.shard_count == rhs.shard_count + assert len(rhs.shape) == 2 if transpose_rhs: return matmul(lhs, rhs.T) @@ -512,9 +635,11 @@ def matmul_split( f"Cannot matmul split tensors of different shard_count: " f"({lhs.shard_count} vs {rhs.shard_count})" ) + if transpose_rhs: + return matmul(lhs, rhs.T) lhs_reduction_dim = len(lhs.shape) - 1 - rhs_reduction_dim = 1 if transpose_rhs else 0 + rhs_reduction_dim = len(rhs.shape) - 2 if len(rhs.shape) > 1 else len(rhs.shape) - 1 # The reduction dimension is split on both tensors. if lhs_reduction_dim == lhs.shard_dim and rhs_reduction_dim == rhs.shard_dim: @@ -524,10 +649,26 @@ def matmul_split( ] return UnreducedTensor(ts=partials) + is_batched_matmul = len(lhs.shape) > 2 or len(rhs.shape) > 2 + if ( + is_batched_matmul + and len(lhs.shape) == len(rhs.shape) + and lhs.shard_dim == rhs.shard_dim + ): + # The same batch dim is sharded for both arguments. + shards = [ + matmul(lhs_shard, rhs_shard) + for lhs_shard, rhs_shard in zip(lhs.shards, rhs.shards) + ] + return SplitPrimitiveTensor(ts=shards, shard_dim=lhs.shard_dim) + + # -1 for missing parallel dim. + lhs_parallel_dim = len(lhs.shape) - 2 + rhs_parallel_dim = len(rhs.shape) - 1 if len(rhs.shape) > 1 else -1 + # One parallel dimension is split for each tensor. - if lhs_reduction_dim != lhs.shard_dim and rhs_reduction_dim != rhs.shard_dim: - if transpose_rhs: - rhs = rhs.T + # Or lhs batch dim and rhs parallel dim are split. + if lhs.shard_dim <= lhs_parallel_dim and rhs_parallel_dim == rhs.shard_dim: # We gather along the rhs shard dim. # It is more natural to preserve the sharding axis of the input. 
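+        # Each lhs shard is multiplied by the split rhs and the sharded product
+        # is reassembled with sharded_cat, so the final result is split only
+        # along lhs.shard_dim.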
shards = [sharded_cat(matmul(lhs_shard, rhs)) for lhs_shard in lhs.shards] @@ -536,6 +677,30 @@ def matmul_split( assert False, "Sharding configuration not supported" +@mean.override(ReplicatedTensor) +def mean_replicated( + x: ReplicatedTensor, + dim: Union[int, List[int]], + keepdim: bool, + *, + dtype: torch.dtype, +) -> None: + shards = [ + torch.mean(unbox_tensor(shard), dim=dim, keepdim=keepdim, dtype=dtype) + for shard in x.shards + ] + return ReplicatedTensor(ts=shards) + + +@module_register_buffer.override(torch.nn.Module, ShardedTensor) +def module_register_buffer_sharded( + module: torch.nn.Module, name: str, tensor: ShardedTensor +) -> None: + for i, shard in enumerate(tensor.shards): + module_register_buffer(module, f"{name}__shard__{i}", shard) + setattr(module, name, tensor) + + @permute.override(SplitPrimitiveTensor) def permute_split(tensor: SplitPrimitiveTensor, dims: List[int]): permuted_shards = [permute(shard, dims) for shard in tensor.shards] @@ -544,7 +709,7 @@ def permute_split(tensor: SplitPrimitiveTensor, dims: List[int]): @permute.override(ReplicatedTensor) -def permute_split(tensor: ReplicatedTensor, dims: List[int]): +def permute_replicated(tensor: ReplicatedTensor, dims: List[int]): permuted_shards = [permute(shard, dims) for shard in tensor.shards] return ReplicatedTensor(ts=permuted_shards) @@ -563,6 +728,22 @@ def replicate_unsharded(input, *, count: int) -> ReplicatedTensor: return ReplicatedTensor(ts=torch_input, shard_count=count) +@reshape.override(SplitPrimitiveTensor) +def reshape_split( + tensor: SplitPrimitiveTensor, shape: List[int] +) -> SplitPrimitiveTensor: + if _reshape_get_single_split_dim(tensor.shape, shape) is not None: + return view(tensor, shape) + + flatten_dim_range = _reshape_get_flatten_dim_range(tensor.shape, shape) + if flatten_dim_range is not None: + return flatten(tensor, flatten_dim_range[0], flatten_dim_range[1] - 1) + + raise ValueError( + f"Unsupported reshaping of sharded split tensor of shape {tensor.shape} to shape {shape}" + ) + + @reshard.override(Tensor, sharding.Split) def reshard_tensor_split(input: Tensor, spec: sharding.Split) -> AnyTensor: return reshard_split(input, dim=spec.shard_dim, count=spec.shard_count) @@ -588,7 +769,13 @@ def make_value(input: Theta | InferenceTensor, spec) -> dict | InferenceTensor: result.name = input.name return result - return Theta({k: make_value(input(k), spec[k]) for k in input.keys}) + return Theta( + { + k: make_value(input(k), spec[k]) + for k in input.keys + if not isinstance(spec[k], sharding.Ignore) + } + ) @reshard.override(Theta, sharding.ThetaLayerSharding) @@ -734,6 +921,45 @@ def sharded_sum_unreduced(maybe_sharded: UnreducedTensor) -> Tensor: return _sharded_sum_sharded(maybe_sharded) +@softmax.override(SplitPrimitiveTensor) +def softmax_split( + tensor: SplitPrimitiveTensor, dim: Optional[int], dtype: Optional[torch.dtype] +) -> Tensor: + dim = dim if dim is None or dim >= 0 else len(tensor.shape) + dim + assert ( + dim is not None and dim != tensor.shard_dim + ), "Softmax along split dimension is not supported." 
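+    # Softmax reduces over the full extent of `dim`; supporting the split
+    # dimension would require a cross-shard reduction, so it is rejected above.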
+ shards = [softmax(shard, dim=dim, dtype=dtype) for shard in tensor.shards] + return SplitPrimitiveTensor( + ts=shards, shard_dim=tensor.shard_dim, shape=tensor.shape + ) + + +@to.override(ReplicatedTensor) +def to_replicated(tensor: ReplicatedTensor, *args, **kwargs): + shards = [to(shard, *args, **kwargs) for shard in tensor.shards] + return ReplicatedTensor(ts=shards) + + +@to.override(SplitPrimitiveTensor) +def to_split(tensor: SplitPrimitiveTensor, *args, **kwargs): + shards = [to(shard, *args, **kwargs) for shard in tensor.shards] + return SplitPrimitiveTensor(ts=shards, shard_dim=tensor.shard_dim) + + +@transpose.override(SplitPrimitiveTensor) +def transpose_split( + tensor: SplitPrimitiveTensor, dim0: int, dim1: int +) -> SplitPrimitiveTensor: + shards = [transpose(shard, dim0, dim1) for shard in tensor.shards] + shard_dim = tensor.shard_dim + if shard_dim == dim0: + shard_dim = dim1 + elif shard_dim == dim1: + shard_dim = dim0 + return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim) + + @unshard.override(ReplicatedTensor) def unshard_replicated(input: ReplicatedTensor) -> Tensor: return input.shards[0] @@ -747,3 +973,101 @@ def unshard_split(input: SplitPrimitiveTensor) -> Tensor: @unshard.override(Tensor) def unshard_unsharded(input: Tensor) -> Tensor: return input + + +def _reshape_get_flatten_dim_range( + from_shape: List[int], to_shape: List[int] +) -> Optional[Tuple[int, int]]: + """If a reshape would flatten a range of dimensions return that index range [begin, end). + If the reshape is not of that kind return `None`.""" + flatten_start_len = _reshape_get_single_split_dim(to_shape, from_shape) + if flatten_start_len is None: + return None + start, length = flatten_start_len + return start, start + length + + +def _reshape_infer_dynamic_dim( + shape1: List[int], shape2: List[int] +) -> Tuple[List[int], List[int]]: + assert ( + len([d for d in list(shape1) + list(shape2) if d < 0]) <= 1 + ), "Only one dynamic dimension is allowed" + shape1_dynamic_dims = [i for i, d in enumerate(shape1) if d <= 0] + if len(shape1_dynamic_dims) > 0: + s2, s1 = _reshape_infer_dynamic_dim(shape2, shape1) + return s1, s2 + + shape2_dynamic_dims = [i for i, d in enumerate(shape2) if d <= 0] + if len(shape2_dynamic_dims) == 0: + return shape1, shape2 + shape2_dynamic_dim = shape2_dynamic_dims[0] + shape1_size = math.prod(shape1) + shape2_size_without_dynamic_dim = math.prod(d for d in shape2 if d > 0) + shape2_res = list(shape2) + assert shape1_size % shape2_size_without_dynamic_dim == 0 + shape2_res[shape2_dynamic_dim] = shape1_size // shape2_size_without_dynamic_dim + assert shape2_res[shape2_dynamic_dim] > 0 + return shape1, shape2_res + + +def _reshape_get_single_split_dim( + from_shape: List[int], to_shape: List[int] +) -> Optional[Tuple[int, int]]: + """If a reshape would split a single dimension, return its index and the length of the new dimensions. 
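+    For example, reshaping from [6, 4] to [2, 3, 4] splits dimension 0 into 2
+    new dimensions, so (0, 2) is returned.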
+ If the reshape is not of that kind return `None`.""" + from_shape, to_shape = _reshape_infer_dynamic_dim(from_shape, to_shape) + + if len(to_shape) < len(from_shape): + return None + i = longest_equal_range(from_shape, to_shape) + if i == len(from_shape): + return i + j = len(to_shape) - longest_equal_range(reversed(from_shape), reversed(to_shape)) + assert i < j + expected_split_dim_size = math.prod(to_shape[i:j]) + if expected_split_dim_size != from_shape[i]: + return None + return ( + i, + j - i, + ) + + +@unsqueeze.override(SplitPrimitiveTensor) +def unsqueeze_default(tensor: SplitPrimitiveTensor, dim: int) -> SplitPrimitiveTensor: + shards = [torch.unsqueeze(unbox_tensor(shard), dim) for shard in tensor.shards] + shard_dim = tensor.shard_dim + dim_resolved = dim if dim >= 0 else dim + len(tensor.shape) + 1 + if shard_dim >= dim_resolved: + shard_dim += 1 + return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim) + + +@view.override(SplitPrimitiveTensor) +def view_split(tensor: SplitPrimitiveTensor, shape: List[int]) -> SplitPrimitiveTensor: + view_split_range = _reshape_get_single_split_dim(tensor.shape, shape) + if view_split_range is None: + raise ValueError( + "Only taking a tensor view where splitting a single dimension is supported" + ) + view_split_dim = view_split_range[0] + + if view_split_dim == tensor.shard_dim: + if tensor.shape[view_split_dim] % tensor.shard_count != 0: + raise ValueError( + "Only splitting a dimension that is multiple of the shard count is supported" + ) + if shape[view_split_dim] % tensor.shard_count != 0: + raise ValueError( + "The resulting leading splitting dimension must be multiple of the shard count" + ) + + shard_dim = tensor.shard_dim + if shard_dim > view_split_dim: + new_dims_count = len(shape) - len(tensor.shape) + shard_dim += new_dims_count + new_shard_shape = list(shape) + new_shard_shape[shard_dim] //= tensor.shard_count + shards = [view(shard, new_shard_shape) for shard in tensor.shards] + return SplitPrimitiveTensor(shard_dim=shard_dim, ts=shards, shape=shape) diff --git a/sharktank/sharktank/ops/signatures.py b/sharktank/sharktank/ops/signatures.py index d39aba71c..979eca72b 100644 --- a/sharktank/sharktank/ops/signatures.py +++ b/sharktank/sharktank/ops/signatures.py @@ -11,7 +11,7 @@ import torch import numbers from torch import Tensor, dtype -from ..types import AnyTensor, ShardedTensor, Theta, sharding +from ..types import AnyTensor, ShardedTensor, Theta, sharding, InferenceTensor from numbers import Number from ._registry import * @@ -24,23 +24,34 @@ "elementwise", "embedding_lookup", "equal", + "expand", + "flatten", "gemm", "group_norm_affine", "layer_norm", + "index_copy_", "interpolate", "linear", "matmul", + "mean", + "module_register_buffer", "permute", "rms_norm", "replicate", + "reshape", "reshard", "reshard_split", "reshard_like", "scaled_dot_product_attention", "sharded_cat", "sharded_sum", + "softmax", + "to", "transfer_to_logical_device", + "transpose", "unshard", + "unsqueeze", + "view", ] IntOrSequenceInt = Union[int, Sequence[int]] @@ -152,16 +163,21 @@ def _conv2d_trampoline( @overridable -def elementwise(operator, *args: AnyTensor) -> AnyTensor: +def elementwise(operator, *args, **kwargs) -> AnyTensor: """Applies an elementwise operator against arguments.""" raise NotImplementedError @elementwise.trampoline -def _elementwise_trampoline(d: SignatureDispatcher, operator, *args: AnyTensor): - tensors = args +def _elementwise_trampoline(d: SignatureDispatcher, operator, *args, **kwargs): + tensors = [] + for a in 
args: + if isinstance(a, (Tensor, InferenceTensor)): + tensors.append(a) + else: + break for override in d.find_overrides(tensors): - result = override(operator, *args) + result = override(operator, *args, **kwargs) if result is not NotImplemented: return override, result else: @@ -232,6 +248,44 @@ def _equal_trampoline(d: SignatureDispatcher, a: AnyTensor, b: AnyTensor): d.fail(tensors) +@overridable +def expand(tensor: AnyTensor, shape: List[int]) -> AnyTensor: + """See torch.Tensor.expand""" + ... + + +@expand.trampoline +def _expand_trampoline( + d: SignatureDispatcher, tensor: AnyTensor, shape: List[int] +) -> AnyTensor: + tensors = (tensor,) + for override in d.find_overrides(tensors): + result = override(tensor, shape) + if result is not NotImplemented: + return override, result + else: + d.fail(tensors) + + +@overridable +def flatten(input: AnyTensor, start_dim: int = 0, end_dim: int = -1) -> AnyTensor: + """See torch.flatten""" + ... + + +@flatten.trampoline +def _flatten_trampoline( + d: SignatureDispatcher, input: AnyTensor, start_dim: int = 0, end_dim: int = -1 +) -> AnyTensor: + dispatch_args = (input,) + for override in d.find_overrides(dispatch_args): + result = override(input, start_dim, end_dim) + if result is not NotImplemented: + return override, result + else: + d.fail(dispatch_args) + + @overridable def gemm( a: AnyTensor, @@ -298,6 +352,31 @@ def _group_norm_affine_trampoline( d.fail(tensors) +@overridable +def index_copy_( + inout: AnyTensor, dim: int, index: AnyTensor, tensor: AnyTensor +) -> AnyTensor: + """See torch.Tensor.index_copy_""" + ... + + +@index_copy_.trampoline +def _index_copy__trampoline( + d: SignatureDispatcher, + inout: AnyTensor, + dim: int, + index: AnyTensor, + tensor: AnyTensor, +) -> AnyTensor: + tensors = (inout, index, tensor) + for override in d.find_overrides(tensors): + result = override(inout, dim, index, tensor) + if result is not NotImplemented: + return override, result + else: + d.fail(tensors) + + @overridable def interpolate( input: AnyTensor, @@ -436,7 +515,6 @@ def _matmul_trampoline( d: SignatureDispatcher, lhs, rhs, *, transpose_rhs: bool = False ): tensors = (lhs, rhs) - assert isinstance(rhs, numbers.Number) or len(rhs.shape) == 2 for override in d.find_overrides(tensors): result = override(lhs, rhs, transpose_rhs=transpose_rhs) if result is not NotImplemented: @@ -464,6 +542,57 @@ def _permute_trampoline(d: SignatureDispatcher, tensor: AnyTensor, dims: List[in d.fail(tensors) +@overridable +def mean( + x: AnyTensor, + dim: Union[int, List[int]], + keepdim: bool = False, + *, + dtype: torch.dtype = None, +) -> AnyTensor: + """See torch.mean""" + raise NotImplementedError + + +@mean.trampoline +def _mean_trampoline( + d: SignatureDispatcher, + x: AnyTensor, + dim: Union[int, List[int]], + keepdim: bool = False, + *, + dtype: torch.dtype = None, +) -> AnyTensor: + tensors = (x,) + for override in d.find_overrides(tensors): + result = override(x, dim, keepdim, dtype=dtype) + if result is not NotImplemented: + return override, result + else: + d.fail(tensors) + + +@overridable +def module_register_buffer( + module: torch.nn.Module, name: str, tensor: AnyTensor +) -> None: + """Register the tensor into the module. See torch.nn.Module.register_buffer.""" + ... 
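+# Note: for sharded tensors the override in sharded_impls.py registers one
+# buffer per shard (as f"{name}__shard__{i}") and then sets the original name
+# as a plain attribute holding the sharded tensor.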
+ + +@module_register_buffer.trampoline +def _module_register_buffer_trampoline( + d: SignatureDispatcher, module: torch.nn.Module, name: str, tensor: AnyTensor +) -> None: + args = (module, tensor) + for override in d.find_overrides(args): + result = override(module, name, tensor) + if result is not NotImplemented: + return override, result + else: + d.fail(args) + + @overridable def rms_norm(x: AnyTensor, weight: AnyTensor, *, epsilon: float) -> AnyTensor: """Computes the full, unbiased RMS normalization of an input.""" @@ -529,6 +658,25 @@ def _scaled_dot_product_attention( d.fail(tensors) +@overridable +def reshape(input: AnyTensor, shape: List[int]) -> AnyTensor: + """Returns a tensor with the same data and number of elements as input, but with + the specified shape. + """ + ... + + +@reshape.trampoline +def _reshape_trampoline(d: SignatureDispatcher, input, shape) -> AnyTensor: + dispatch_args = (input,) + for override in d.find_overrides(dispatch_args): + result = override(input, shape) + if result is not NotImplemented: + return override, result + else: + d.fail(dispatch_args) + + @overridable def reshard( input: AnyTensor | Theta, @@ -636,6 +784,57 @@ def _sharded_sum_trampoline(d: SignatureDispatcher, maybe_sharded: AnyTensor): d.fail(tensors) +@overridable +def softmax( + tensor: AnyTensor, dim: Optional[int] = None, dtype: Optional[torch.dtype] = None +) -> AnyTensor: + """See torch.nn.functional.softmax""" + ... + + +@softmax.trampoline +def _softmax_trampoline( + d: SignatureDispatcher, + tensor: AnyTensor, + dim: Optional[int] = None, + dtype: Optional[torch.dtype] = None, +) -> AnyTensor: + dispatch_args = [tensor] + for override in d.find_overrides(dispatch_args): + result = override(tensor, dim=dim, dtype=dtype) + if result is not NotImplemented: + return override, result + else: + d.fail(dispatch_args) + + +@overridable +def to(tensor: AnyTensor, *args, **kwargs) -> AnyTensor: + """See torch.Tensor.to""" + ... + + +@to.trampoline +def _to_trampoline(d: SignatureDispatcher, tensor: AnyTensor, *args, **kwargs): + dispatch_args = [tensor] + # if len(args) > 0: + # dispatch_args.append(args[0]) + # else: + # if "dtype" in kwargs: + # dispatch_args.append(kwargs["dtype"]) + # elif "device" in kwargs: + # dispatch_args.append(kwargs["device"]) + # else: + # assert "other" in kwargs + # dispatch_args.append(kwargs["other"]) + for override in d.find_overrides(dispatch_args): + result = override(tensor, *args, **kwargs) + if result is not NotImplemented: + return override, result + else: + d.fail(dispatch_args) + + @overridable def transfer_to_logical_device(tensor: AnyTensor, ordinal: int) -> AnyTensor: """Transfer the tensor to a device with ordinal `ordinal`.""" @@ -643,7 +842,7 @@ def transfer_to_logical_device(tensor: AnyTensor, ordinal: int) -> AnyTensor: @transfer_to_logical_device.trampoline -def _transfer_to_logical_device( +def _transfer_to_logical_device_trampoline( d: SignatureDispatcher, tensor: AnyTensor, ordinal: int ): tensors = (tensor,) @@ -655,6 +854,25 @@ def _transfer_to_logical_device( d.fail(tensors) +@overridable +def transpose(tensor: AnyTensor, dim0: int, dim1: int) -> AnyTensor: + """See torch.transpose""" + ... 
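
Every overridable/trampoline pair in this file follows the same shape: the trampoline picks out the dispatch-relevant tensor arguments, `d.find_overrides` walks the registered overrides by argument type, an override may decline by returning NotImplemented, and `d.fail` raises when nothing matches. A toy, self-contained sketch of that dispatch contract (MiniDispatcher and transpose2d are illustrative names, not sharktank APIs):

from typing import Callable, Dict, Tuple, Type

class MiniDispatcher:
    """Toy dispatcher: overrides keyed by argument types, NotImplemented to decline."""

    def __init__(self, name: str):
        self.name = name
        self._overrides: Dict[Tuple[Type, ...], Callable] = {}

    def override(self, *types: Type) -> Callable:
        def decorator(fn: Callable) -> Callable:
            self._overrides[types] = fn
            return fn
        return decorator

    def __call__(self, *args):
        for types, fn in self._overrides.items():
            if len(types) == len(args) and all(
                isinstance(a, t) for a, t in zip(args, types)
            ):
                result = fn(*args)
                if result is not NotImplemented:
                    return result
        raise NotImplementedError(
            f"{self.name}: no override for {tuple(type(a) for a in args)}"
        )

transpose2d = MiniDispatcher("transpose2d")

@transpose2d.override(list)
def transpose2d_lists(rows: list) -> list:
    # Interpret a list of equal-length lists as a 2-D "tensor" and swap its dims.
    return [list(col) for col in zip(*rows)]

assert transpose2d([[1, 2, 3], [4, 5, 6]]) == [[1, 4], [2, 5], [3, 6]]

This is why each newly exported op (expand, flatten, mean, softmax, to, transpose, unsqueeze, view, ...) ships as an overridable plus a near-identical trampoline: the per-type behavior lives in the override modules, such as the SplitPrimitiveTensor overrides earlier in this patch.
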
+ + +@transpose.trampoline +def _transpose_trampoline( + d: SignatureDispatcher, tensor: AnyTensor, dim0: int, dim1: int +) -> AnyTensor: + tensors = (tensor,) + for override in d.find_overrides(tensors): + result = override(tensor, dim0, dim1) + if result is not NotImplemented: + return override, result + else: + d.fail(tensors) + + @overridable def unshard(tensor: AnyTensor) -> AnyTensor: """Return the tensor that has the same elements and shape, but is not sharded.""" @@ -670,3 +888,41 @@ def _unshard_trampoline(d: SignatureDispatcher, tensor: AnyTensor): return override, result else: d.fail(tensors) + + +@overridable +def unsqueeze(tensor: AnyTensor, dim: int) -> AnyTensor: + """See torch.unsqueeze""" + ... + + +@unsqueeze.trampoline +def _unsqueeze_trampoline( + d: SignatureDispatcher, tensor: AnyTensor, dim: int +) -> AnyTensor: + tensors = (tensor,) + for override in d.find_overrides(tensors): + result = override(tensor, dim) + if result is not NotImplemented: + return override, result + else: + d.fail(tensors) + + +@overridable +def view(tensor: AnyTensor, shape: List[int]) -> AnyTensor: + """See torch.Tensor.view""" + ... + + +@view.trampoline +def _view_trampoline( + d: SignatureDispatcher, tensor: AnyTensor, shape: List[int] +) -> AnyTensor: + tensors = (tensor,) + for override in d.find_overrides(tensors): + result = override(tensor, shape) + if result is not NotImplemented: + return override, result + else: + d.fail(tensors) diff --git a/sharktank/sharktank/types/sharding.py b/sharktank/sharktank/types/sharding.py index a670cae33..7d4eb93b4 100644 --- a/sharktank/sharktank/types/sharding.py +++ b/sharktank/sharktank/types/sharding.py @@ -18,7 +18,7 @@ def __init__(self): class TensorSharding(Sharding): - def __init__(self, *, shard_count: int): + def __init__(self, shard_count: int): super().__init__() self.shard_count = shard_count @@ -29,7 +29,7 @@ def __init__(self): class Replicated(TensorSharding): - def __init__(self, *, shard_count: int): + def __init__(self, shard_count: int): super().__init__(shard_count=shard_count) @@ -39,6 +39,17 @@ def __init__(self, *, shard_count: int, shard_dim: int): self.shard_dim = shard_dim +class Ignore(TensorSharding): + """When a theta is sharded, a tensor or a branch with this sharding type will be + ignored. + It will not appear in the resulting sharded theta. + This is not strictly a TensorSharding. It will terminate further traversal of a + branch of a theta tree as well.""" + + def __init__(self): + super().__init__(shard_count=0) + + class ThetaSharding(dict): """Sharding for each tensor in a theta. It is of type dict[str, "ThetaSharding" | TensorSharding]. 
@@ -49,7 +60,15 @@ def __init__(self, *args, **kwargs): for k, v in d.items(): d[k] = tree.map_nodes( tree=v, - f=lambda x: x if isinstance(x, TensorSharding) else ThetaSharding(x), + f=lambda x: x + if isinstance( + x, + ( + TensorSharding, + ThetaSharding, + ), + ) + else ThetaSharding(x), ) super().__init__(d) @@ -89,6 +108,27 @@ def theta_sharding(self) -> ThetaSharding: ) +class FFNSharding(ThetaLayerSharding): + def __init__(self, shard_count: int): + super().__init__() + self.shard_count = shard_count + + def theta_sharding(self) -> ThetaSharding: + return ThetaSharding( + { + "ffn_gate": LinearSplitReductionDimSharding( + shard_count=self.shard_count + ).theta_sharding(), + "ffn_up": LinearReplicatedInputSplitWeightAndBiasSharding( + shard_count=self.shard_count + ).theta_sharding(), + "ffn_down": LinearSplitReductionDimSharding( + shard_count=self.shard_count + ).theta_sharding(), + } + ) + + class GroupNormSplitChannelSharding(ThetaLayerSharding): def __init__(self, shard_count: int): super().__init__() @@ -105,6 +145,9 @@ def theta_sharding(self) -> ThetaSharding: class LinearReplicatedInputSplitWeightAndBiasSharding(ThetaLayerSharding): def __init__(self, shard_count: int, weight_and_bias_spit_dim: int = 0): + """Split one parallel dimension for both the weight and bias. + Since the weight is transposed before multiplying, the weight parallel + dimension is the same as the output(bias) dimension.""" super().__init__() self.shard_count = shard_count self.weight_and_bias_spit_dim = weight_and_bias_spit_dim @@ -123,3 +166,49 @@ def theta_sharding(self) -> ThetaSharding: ), } ) + + +class LinearSplitReductionDimSharding(ThetaLayerSharding): + def __init__(self, shard_count: int): + super().__init__() + self.shard_count = shard_count + + def theta_sharding(self) -> ThetaSharding: + return ThetaSharding( + { + "premul_input": Replicated(shard_count=self.shard_count), + "weight": Split( + shard_count=self.shard_count, + shard_dim=1, + ), + "bias": Replicated( + shard_count=self.shard_count, + ), + } + ) + + +class RmsNormReplicatedSharding(ThetaLayerSharding): + def __init__(self, shard_count: int): + super().__init__() + self.shard_count = shard_count + + def theta_sharding(self) -> ThetaSharding: + return ThetaSharding( + { + "weight": Replicated(shard_count=self.shard_count), + } + ) + + +class TokenEmbeddingLayerReplicatedSharding(ThetaLayerSharding): + def __init__(self, shard_count: int): + super().__init__() + self.shard_count = shard_count + + def theta_sharding(self) -> ThetaSharding: + return ThetaSharding( + { + "weight": Replicated(shard_count=self.shard_count), + } + ) diff --git a/sharktank/sharktank/types/tensors.py b/sharktank/sharktank/types/tensors.py index 8d59a89d4..81df53882 100644 --- a/sharktank/sharktank/types/tensors.py +++ b/sharktank/sharktank/types/tensors.py @@ -18,7 +18,8 @@ ) from copy import deepcopy from collections.abc import Collection -from numbers import Integral +from numbers import Integral, Number +import numpy as np from abc import ABC, abstractmethod from dataclasses import dataclass @@ -254,8 +255,15 @@ def transform_globals( return self._clone_with_globals(prev_globals) def to( - self, *, device: Optional[Union[str, torch.device]] = None + self, + *, + device: Optional[Union[str, torch.device]] = None, ) -> "InferenceTensor": + # TODO: reconcile with ops.to(...) and torch.Tensor.to(...). + # Do we always want to clone with globals? + # This makes our type inconsistent with torch tensors. 
+ # If we use this to transform a theta we want to change the theta. + # If we want to use this in a computation we don't want to change the theta. return self.transform_globals( lambda d: {k: t.to(device=device) for k, t in d.items()} ) @@ -278,6 +286,78 @@ def T(self) -> "InferenceTensor": return permute(self, dims=dims) + @property + def dtype(self) -> torch.dtype: + raise NotImplementedError() + + def expand(self, *args: Union[List[List[int]], List[int]]) -> "AnyTensor": + from ..ops import expand + + if all(isinstance(a, int) for a in args): + shape = args + else: + assert len(args) == 1 + shape = args[0] + return expand(self, shape) + + def flatten(self, start_dim: int = 0, end_dim: int = -1) -> "AnyTensor": + from ..ops import flatten + + return flatten(self, start_dim, end_dim) + + def index_copy_( + self, dim: int, index: "AnyTensor", tensor: "AnyTensor" + ) -> "InferenceTensor": + from ..ops import index_copy_ + + return index_copy_(self, dim, index, tensor) + + def mean( + self, + dim: Union[int, List[int]], + keepdim: bool = False, + *, + dtype: torch.dtype = None, + ) -> "AnyTensor": + from ..ops import mean + + return mean(self, dim, keepdim, dtype=None) + + def pow(self, exponent: Union["AnyTensor", Number]) -> "AnyTensor": + from ..ops import elementwise + + return elementwise(torch.pow, self, exponent) + + def reshape(self, *args: Union[List[List[int]], List[int]]) -> "AnyTensor": + from ..ops import reshape + + if all(isinstance(a, int) for a in args): + shape = args + else: + assert len(args) == 1 + shape = args[0] + return reshape(self, shape) + + def transpose(self, dim0: int, dim1: int) -> "AnyTensor": + from ..ops import transpose + + return transpose(self, dim0, dim1) + + def unsqueeze(self, dim: int) -> "AnyTensor": + from ..ops import unsqueeze + + return unsqueeze(self, dim) + + def view(self, *args: Union[List[List[int]], List[int]]) -> "AnyTensor": + from ..ops import view + + if all(isinstance(a, int) for a in args): + shape = args + else: + assert len(args) == 1 + shape = args[0] + return view(self, shape) + def __add__(self, rhs): from ..ops import elementwise @@ -298,6 +378,11 @@ def __rmul__(self, lhs): # numbers on the lhs. return self.__mul__(lhs) + def __truediv__(self, rhs): + from ..ops import elementwise + + return elementwise(torch.div, self, rhs) + REGISTERED_INFERENCE_TENSOR_CLASSES: dict[str, Type[InferenceTensor]] = {} @@ -337,6 +422,10 @@ def as_torch(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """ ... + @property + def dtype(self) -> torch.dtype: + return self.as_torch().dtype + @register_inference_tensor class DefaultPrimitiveTensor(PrimitiveTensor): @@ -600,6 +689,10 @@ def name(self, name: str): for i, shard in enumerate(self.shards): shard.name = f"{name}.shard.{i}" + @property + def dtype(self) -> torch.dtype: + return self.shards[0].dtype + @register_inference_tensor class ShardedTensorBase(ShardedTensor): @@ -780,21 +873,25 @@ def __init__( assert len(ts) > 0 first_shape = ts[0].shape assert len(first_shape) > shard_dim - if shape is None: - # Compute the shape. - shape = list(first_shape) - shape[shard_dim] *= len(ts) - - # Assert the shape. 
- shard_dim_size = first_shape[shard_dim] - for t in ts[1:]: - assert ( - t.shape == first_shape - ), f"Shape mismatch for split tensors: {t.shape} vs {first_shape}" - shard_dim_size += t.shape[shard_dim] - assert ( - shard_dim_size == shape[shard_dim] - ), f"Sharding mismatch: Sharded dims do not cover the whole volume {shard_dim_size} vs {shape[shard_dim]}" + expected_shape = list(first_shape) + expected_shape[shard_dim] = sum([t.shape[shard_dim] for t in ts]) + if shape is not None: + shape = list(shape) + assert expected_shape == shape + else: + shape = expected_shape + + # Assert the shapes. + for i, t in enumerate(ts): + t_shape = list(t.shape) + assert len(shape) == len( + t_shape + ), f"Shape size mismatch tensor shard {i} with shape {t.shape}. Expected shape size {len(shape)}. Got {len(t_shape)}." + assert np.array_equal( + shape[0:shard_dim], t_shape[0:shard_dim] + ) and np.array_equal( + shape[shard_dim + 1 :], t_shape[shard_dim + 1 :] + ), f"Shape mismatch for non-split dimension for tensor shard {i} with shape {t.shape}" super().__init__(name=name, ts=ts, shape=shape, shard_dim=shard_dim) diff --git a/sharktank/sharktank/utils/__init__.py b/sharktank/sharktank/utils/__init__.py new file mode 100644 index 000000000..3651913ca --- /dev/null +++ b/sharktank/sharktank/utils/__init__.py @@ -0,0 +1,7 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .misc import * diff --git a/sharktank/sharktank/utils/cli.py b/sharktank/sharktank/utils/cli.py index 3a13b1d45..396c74363 100644 --- a/sharktank/sharktank/utils/cli.py +++ b/sharktank/sharktank/utils/cli.py @@ -106,6 +106,8 @@ def get_input_dataset(args) -> Dataset: if "irpa" in data_files: return Dataset.load(data_files["irpa"], file_type="irpa") + raise ValueError(f'Dataset format unsupported. Must be "gguf" or "irpa".') + def get_tokenizer(args) -> tokenizer.InferenceTokenizer: """Gets a tokenizer based on arguments. diff --git a/sharktank/sharktank/utils/hf_datasets.py b/sharktank/sharktank/utils/hf_datasets.py index e49623a8d..94bc2c631 100644 --- a/sharktank/sharktank/utils/hf_datasets.py +++ b/sharktank/sharktank/utils/hf_datasets.py @@ -83,6 +83,23 @@ def alias_dataset(from_name: str, to_name: str): # Dataset definitions ################################################################################ +Dataset( + "SanctumAI/Meta-Llama-3.1-8B-Instruct-GGUF", + ( + RemoteFile( + "gguf", + "SanctumAI/Meta-Llama-3.1-8B-Instruct-GGUF", + "meta-llama-3.1-8b-instruct.f16.gguf", + ), + RemoteFile( + "tokenizer_config.json", + "NousResearch/Meta-Llama-3-8B", + "tokenizer_config.json", + extra_filenames=["tokenizer.json"], + ), + ), +).alias_to("llama3_8B_fp16") + Dataset( "QuantFactory/Llama-3-8B_q4_1_gguf", ( diff --git a/sharktank/sharktank/utils/math.py b/sharktank/sharktank/utils/math.py index df47b5ae6..3f32ac952 100644 --- a/sharktank/sharktank/utils/math.py +++ b/sharktank/sharktank/utils/math.py @@ -4,6 +4,12 @@ # See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +from numbers import Number + def ceildiv(a: int | float, b: int | float) -> int | float: return -(a // -b) + + +def round_up_to_multiple_of(x: Number, multiple: Number) -> Number: + return x + (-x % multiple) diff --git a/sharktank/sharktank/utils/misc.py b/sharktank/sharktank/utils/misc.py new file mode 100644 index 000000000..3687c3c05 --- /dev/null +++ b/sharktank/sharktank/utils/misc.py @@ -0,0 +1,16 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Any, List + + +def longest_equal_range(l1: List[Any], l2: List[Any]) -> int: + """Find the longest range that is the same from the start of both lists. + Returns the greatest `i` such that `l1[0:i] == l2[0:i]`.""" + for i, (a, b) in enumerate(zip(l1, l2)): + if a != b: + return i + return len(zip(l1, l2)) diff --git a/sharktank/tests/models/llama/sharded_llama_test.py b/sharktank/tests/models/llama/sharded_llama_test.py new file mode 100644 index 000000000..82f0ef789 --- /dev/null +++ b/sharktank/tests/models/llama/sharded_llama_test.py @@ -0,0 +1,115 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import unittest +from sharktank.models.llama.llama import LlamaModelConfig, PagedLlamaModelV1 +import sharktank.ops as ops +from sharktank.models.llama.testing import make_random_llama_theta +from sharktank.models.llama.sharding import shard_theta +from sharktank.layers.configs import LlamaHParams +from sharktank.utils.math import round_up_to_multiple_of +import torch +from copy import deepcopy + + +class AttentionBlockTest(unittest.TestCase): + def testToyModelCompareToUnsharded(self): + """Run a sharded variant of a toy model size and compare it against the + unsharded variant.""" + torch.random.manual_seed(123456) + torch.set_default_dtype(torch.float32) + batch_size = 3 + attention_head_count_kv = 4 + attention_head_count = attention_head_count_kv * 5 + vocabulary_size = 19 + rope_dimension_count = 7 * 2 + attn_head_dim = rope_dimension_count + block_seq_stride = 13 + cache_page_count = 11 + config = LlamaModelConfig( + hp=LlamaHParams( + # context_length=17, + context_length=block_seq_stride * 2, + embedding_length=attention_head_count * attn_head_dim, + block_count=3, + feed_forward_length=23, + rope_dimension_count=rope_dimension_count, + rope_freq_base=500000.0, + attention_head_count=attention_head_count, + attn_head_dim=attn_head_dim, + attention_layer_norm_rms_epsilon=0.01, + attention_head_count_kv=attention_head_count_kv, + expert_count=0, + expert_used_count=0, + ), + block_seq_stride=block_seq_stride, + activation_dtype=torch.float32, + attention_dtype=torch.float32, + ) + theta = make_random_llama_theta( + config=config, + vocab_size=vocabulary_size, + ) + + model = PagedLlamaModelV1(theta, config) + seq_lens = [ + torch.randint(high=config.hp.context_length + 1, size=[1])[0].item() + for _ in range(batch_size - 1) + ] + seq_lens.append(config.hp.context_length) + seq_lens = torch.tensor(seq_lens, dtype=torch.int32) + cache_state = model.cache.paged.allocate(page_count=cache_page_count) + cache_state = [torch.rand_like(cache_state[0])] + batch_seq_len = round_up_to_multiple_of( + 
config.hp.context_length, model.cache.pad_sequence_stride + ) + token_ids = torch.randint( + low=0, + high=vocabulary_size, + size=[batch_size, batch_seq_len], + dtype=torch.int32, + ) + attention_mask = model.attention_mask( + model.input_mask(seq_lens, config.hp.context_length) + ) + seq_block_ids = torch.arange( + batch_size * batch_seq_len // config.block_seq_stride + ).view(batch_size, -1) + + sharded_config = deepcopy(config) + sharded_config.tensor_parallelism_size = 2 + sharded_theta = shard_theta(theta, sharded_config) + sharded_model = PagedLlamaModelV1(sharded_theta, sharded_config) + sharded_cache_state = sharded_model.cache.paged.allocate( + page_count=cache_page_count + ) + sharded_cache_state = [ops.reshard_like(cache_state[0], sharded_cache_state[0])] + + expected_prefill_result = model.prefill( + token_ids, + attention_mask=attention_mask, + seq_block_ids=seq_block_ids, + cache_state=cache_state, + ) + sharded_prefill_result = sharded_model.prefill( + token_ids, + attention_mask=attention_mask, + seq_block_ids=seq_block_ids, + cache_state=sharded_cache_state, + ) + actual_prefill_result = ops.unshard(sharded_prefill_result) + torch.testing.assert_close(actual_prefill_result, expected_prefill_result) + + # tokens_ids_2 = torch.tensor( + # model.extract_tokens_from_logits(expected_prefill_result, self.seq_lens) + # ).unsqueeze(1) + # seq_lens.add_(1) + # expected_decode_result = model.decode( + # tokens_ids_2, + # attention_mask=attention_mask, + # seq_block_ids=seq_block_ids, + # cache_state=cache_state, + # ) diff --git a/sharktank/tests/ops/sharded_test.py b/sharktank/tests/ops/sharded_test.py index 89a37870c..cf4c6b2cf 100644 --- a/sharktank/tests/ops/sharded_test.py +++ b/sharktank/tests/ops/sharded_test.py @@ -688,6 +688,21 @@ def compute(input, ffn_gate_weight, ffn_down_weight, ffn_up_weight): ) torch.testing.assert_close(Z_sharded, Z_ref) + def testSameSplitLhsAndRhsBatchDim(self): + a = torch.rand(3, 4, 5, 6) + b = torch.rand(3, 4, 6, 7) + shard_count = 2 + shard_dim = 1 + expected_result = torch.matmul(a, b) + sharded_a = ops.reshard_split(a, dim=shard_dim, count=shard_count) + sharded_b = ops.reshard_split(b, dim=shard_dim, count=shard_count) + sharded_result = ops.matmul(sharded_a, sharded_b) + assert isinstance(sharded_result, SplitPrimitiveTensor) + assert sharded_result.shard_count == shard_count + assert sharded_result.shard_dim == shard_dim + actual_result = unbox_tensor(ops.unshard(sharded_result)) + torch.testing.assert_close(actual_result, expected_result) + class ReplicateTest(unittest.TestCase): def testReplicateReplicated(self):