Restore RotaryEmbedding use of complex numbers (#517)
It appears the change away from complex numbers triggered a downstream
IREE failure. In-flight work that uses `flow.tensor.bitcast` to restore the
complex-number formulation appears to fix the issue. Fixing forward, since we
wanted to revert part of that change anyway.
rsuderman authored Nov 14, 2024
1 parent bfc7738 commit e381e87
Showing 10 changed files with 263 additions and 79 deletions.
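
For context before the per-file diffs: the restored formulation treats each adjacent (even, odd) pair of head-dimension values as one complex number and applies the rotary embedding as an elementwise multiply by e^{i*theta}, which is algebraically the same rotation as the explicit cos/sin arithmetic being removed. A minimal sketch in plain PyTorch (illustrative shapes and names, not sharktank code):

import torch

# Illustrative shapes: (batch, seq_len, heads, head_dim); head_dim must be even.
bs, sl, heads, dim = 1, 4, 2, 8
xt = torch.randn(bs, sl, heads, dim)

# RoPE frequencies and angles for positions 0..sl-1.
base = 10000.0
freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))   # (dim // 2,)
angles = torch.outer(torch.arange(sl).float(), freqs)             # (sl, dim // 2)
table = torch.complex(torch.cos(angles), torch.sin(angles))       # e^{i * theta}

# Complex path: pair adjacent reals, rotate by an elementwise complex multiply.
xt_c = torch.view_as_complex(xt.unflatten(-1, (-1, 2)))           # (bs, sl, heads, dim // 2)
rotated = torch.view_as_real(xt_c * table[None, :, None, :]).flatten(-2, -1)

# Equivalent explicit cos/sin rotation, for comparison.
xr, xi = xt.unflatten(-1, (-1, 2)).unbind(-1)
cos, sin = torch.cos(angles)[None, :, None, :], torch.sin(angles)[None, :, None, :]
expected = torch.stack((xr * cos - xi * sin, xr * sin + xi * cos), dim=-1).flatten(-2, -1)

assert torch.allclose(rotated, expected, atol=1e-5)

When compiled through IREE, the view_as_complex/view_as_real steps lower to `flow.tensor.bitcast` via the new kernels added below.
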
8 changes: 8 additions & 0 deletions sharktank/sharktank/examples/paged_llm_v1.py
@@ -196,6 +196,14 @@ def decode(self):
trace_tensor("decode.start_positions", start_positions)
trace_tensor("decode.seq_block_ids", seq_block_ids_tensor)
trace_tensor("decode.attention_mask", decode_attention_mask)

if model.config.tensor_parallelism_size != 1:
tp = model.config.tensor_parallelism_size
self.next_tokens = replicate(self.next_tokens, tp)
start_positions = replicate(start_positions, tp)
seq_block_ids_tensor = replicate(seq_block_ids_tensor, tp)
decode_attention_mask = replicate(decode_attention_mask, tp)

logits = model.decode(
self.next_tokens,
attention_mask=decode_attention_mask,
1 change: 1 addition & 0 deletions sharktank/sharktank/kernels/__init__.py
@@ -14,3 +14,4 @@
from .conv_2d_nchw_fchw import *
from .pooling_nchw_sum import *
from .base import *
from .bitcast import *
138 changes: 138 additions & 0 deletions sharktank/sharktank/kernels/bitcast.py
@@ -0,0 +1,138 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from sharktank.kernels.base import *

import torch

from iree.turbine.support.ir_imports import (
ComplexType,
F16Type,
F32Type,
RankedTensorType,
ShapedType,
Value,
flow_d,
tensor_d,
)

from iree.turbine.runtime.op_reg import (
CustomOp,
KernelBuilder,
KernelSelection,
)

__all__ = [
"bitcast_to_complex",
"bitcast_to_real",
]

_ftype_to_ctype_table = {
torch.float16: torch.complex32,
torch.float32: torch.complex64,
}

_ctype_to_ftype_table = {
torch.complex32: torch.float16,
torch.complex64: torch.float32,
}

_type_to_irtype_table = {
torch.float16: lambda: F16Type.get(),
torch.float32: lambda: F32Type.get(),
torch.complex32: lambda: ComplexType.get(F16Type.get()),
torch.complex64: lambda: ComplexType.get(F32Type.get()),
}


@CustomOp.register(library=LIBRARY)
class bitcast_to_complex(CustomOp):

signature = "bitcast_to_complex(Tensor q) -> (Tensor)"

def select(self, ksel: KernelSelection):
ta = ksel.arg_tensor(0)

torch._check(ta.t.dtype in _ftype_to_ctype_table)
torch._check(isinstance(ta.t.shape[-1], int))

new_shape = [i for i in ta.t.shape]
new_shape[-1] = new_shape[-1] // 2

ctype = _ftype_to_ctype_table[ta.t.dtype]
ret = ksel.return_new_tensor(new_shape, dtype=ctype)
specialize_all_known_dims(ta)
specialize_all_known_dims(ret)

def eager_execute(self, tensor):
return torch.view_as_complex(tensor.unflatten(-1, (-1, 2)))

def generate(self, ksel: KernelSelection, kb: KernelBuilder):
t = kb.arg_bindings[0]
result_desc = ksel.result_descs[0]
result_shape = [
d if isinstance(d, int) else RankedTensorType.get_dynamic_size()
for d in result_desc.t.shape
]

dynamic_dims: list[Value] = []
_append_dynamic_dims(kb, dynamic_dims, t)

c64 = _type_to_irtype_table[result_desc.t.dtype]()
rtt = RankedTensorType.get(result_shape, c64)
result = flow_d.TensorBitCastOp(rtt, t, dynamic_dims, dynamic_dims).result
kb.yield_results(result)


@CustomOp.register(library=LIBRARY)
class bitcast_to_real(CustomOp):

signature = "bitcast_to_real(Tensor q) -> (Tensor)"

def select(self, ksel: KernelSelection):
ta = ksel.arg_tensor(0)

torch._check(ta.t.dtype in _ctype_to_ftype_table)
torch._check(isinstance(ta.t.shape[-1], int))

new_shape = [i for i in ta.t.shape]
new_shape[-1] = new_shape[-1] * 2

ftype = _ctype_to_ftype_table[ta.t.dtype]
ret = ksel.return_new_tensor(new_shape, dtype=ftype)
specialize_all_known_dims(ta)
specialize_all_known_dims(ret)

def eager_execute(self, tensor):
return torch.view_as_real(tensor).flatten(-2, -1)

def generate(self, ksel: KernelSelection, kb: KernelBuilder):
t = kb.arg_bindings[0]
result_desc = ksel.result_descs[0]
result_shape = [
d if isinstance(d, int) else RankedTensorType.get_dynamic_size()
for d in result_desc.t.shape
]

dynamic_dims: list[Value] = []
_append_dynamic_dims(kb, dynamic_dims, t)

ftype = _type_to_irtype_table[result_desc.t.dtype]()
rtt = RankedTensorType.get(result_shape, ftype)
result = flow_d.TensorBitCastOp(rtt, t, dynamic_dims, dynamic_dims).result
kb.yield_results(result)


################################################################################
# Emission utilities
################################################################################


def _append_dynamic_dims(kb: KernelBuilder, dynamic_dims: list[Value], tensor: Value):
rtt = RankedTensorType(tensor.type)
for i in range(rtt.rank):
if rtt.is_dynamic_dim(i):
dynamic_dims.append(tensor_d.dim(tensor, kb.constant_index(i)))
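
A hedged usage sketch for the two kernels above. In eager mode the calls run `eager_execute`; under compilation `generate` emits `flow.tensor.bitcast` instead. Shapes and the float32 dtype below are illustrative:

import torch
from sharktank.kernels import bitcast_to_complex, bitcast_to_real

x = torch.randn(2, 16, dtype=torch.float32)

# float32 [..., 2n] -> complex64 [..., n], per _ftype_to_ctype_table.
xc = bitcast_to_complex(x)
assert xc.dtype == torch.complex64 and xc.shape == (2, 8)
assert torch.equal(xc, torch.view_as_complex(x.unflatten(-1, (-1, 2))))

# complex64 [..., n] -> float32 [..., 2n]; round-trips to the original layout.
xr = bitcast_to_real(xc)
assert xr.dtype == torch.float32 and xr.shape == (2, 16)
assert torch.equal(xr, x)
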
81 changes: 28 additions & 53 deletions sharktank/sharktank/layers/rotary_embedding.py
@@ -138,54 +138,29 @@ def create_ordering_tensor(dim):

if self.use_hf:
xt = xt[..., create_interleaved_tensor(xt.shape[-1])]
xt_ = xt.unflatten(-1, (-1, 2))
_, sl, _, dim, _ = xt_.shape
xt_ = xt
_, sl, _, _ = xt_.shape

# Offset the table based on starting position.
if self.use_table:
freqs_cis = rotary_embed_table[:, start_index : start_index + sl, :]
freqs_cis = rotary_embed_table[start_index : start_index + sl, :]
freqs_cis = freqs_cis[None, 0:sl, None, :]
else:
freqs_cis = torch.arange(start_index, start_index + sl, device=xt.device)
freqs_cis = self._compute_rotary_embed_table(freqs_cis)
freqs_cis = torch.arange(sl, device=xt.device) + start_index
freqs_cis = self._compute_rotary_embed_table(freqs_cis)[None, :, None, :]

assert freqs_cis.shape[-1] == dim
assert (
freqs_cis.shape[1] >= sl
), f"Sequence length longer than embedding table ({sl} vs {freqs_cis.shape[0]})"

broadcast_freqs_cis = freqs_cis[:, None, 0:sl, None, :]

cos = broadcast_freqs_cis[0]
sin = broadcast_freqs_cis[1]
xt_r = xt_[..., 0]
xt_i = xt_[..., 1]

xt_out_r = xt_r * cos - xt_i * sin
xt_out_i = xt_i * cos + xt_r * sin

xt_out = torch.concatenate((xt_out_r, xt_out_i), dim=-1)
xt_ = ops.view_as_complex(xt_)
xt_ = xt_ * freqs_cis
xt_out = ops.view_as_real(xt_)

if self.use_hf:
xt_out = xt_out[..., create_ordering_tensor(xt_out.shape[-1])]
return xt_out.type_as(xt)

return xt_out.type_as(xt)

def complex_multiply(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
"""Function for elementwise-multiplication of two complex torch tensors.
Functionally similar to a*b, but numerically accurate for HuggingFace
LLaMa implementation.
Args:
a: First torch tensor operand
b: Second torch tensor operand
Returns:
Tensor of same size to a, b whose elements is product of corresponding
elements in a, b
"""
return torch.complex(
a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real
)
return ops.to(xt_out, xt.dtype)

def compute_batch_mask(
self, start_positions: Union[torch.Tensor, ReplicatedTensor], batch_seq_len: int
@@ -207,11 +182,18 @@ def compute_batch_mask(
self.trace_tensor("rope.positions_seq", positions_seq)

if self.use_table:
freqs_cis = self.rotary_embed_table[:, positions_seq]
freqs_cis = self.rotary_embed_table[positions_seq]
else:
shape = positions_seq.shape
freqs_cis = self._compute_rotary_embed_table(positions_seq.flatten())
freqs_cis = freqs_cis.unflatten(1, shape)
if isinstance(positions_seq, ReplicatedTensor):
ts = [
self._compute_rotary_embed_table(s.flatten()).unflatten(0, shape)
for s in positions_seq.shards
]
freqs_cis = ReplicatedTensor(ts=ts)
else:
freqs_cis = self._compute_rotary_embed_table(positions_seq.flatten())
freqs_cis = freqs_cis.unflatten(0, shape)

# Unsqueeze a unit dim for attention heads.
broadcast_freqs_cis = freqs_cis.unsqueeze(2)
@@ -247,30 +229,23 @@ def apply_batched_mask_unsharded(self, *, xt: torch.Tensor, mask: torch.Tensor):
"""
# xq_, xk_ shape: bs, sl, _, dim
# freqs_cis shape: max_sl, dim
cos = mask[0]
sin = mask[1]

xt_ = xt.unflatten(-1, (-1, 2))
xt_r = xt_[..., 0]
xt_i = xt_[..., 1]
xt_ = ops.view_as_complex(xt)
xt_ = xt_ * mask
xt_out = ops.view_as_real(xt_)

xt_out_r = xt_r * cos - xt_i * sin
xt_out_i = xt_r * sin + xt_i * cos
xt_out = torch.concatenate((xt_out_r, xt_out_i), dim=-1)
return xt_out.type_as(xt)

def _compute_rotary_embed_table(self, t):
dim = self.rope_dimension_count
freqs = 1.0 / (
self.rope_freq_base
** (torch.arange(0, dim, 2, device=t.device)[: (dim // 2)].float() / dim)
self.rope_freq_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
)
freqs = torch.outer(t, freqs).float()

cos = torch.cos(freqs).unsqueeze(0)
sin = torch.sin(freqs).unsqueeze(0)

return torch.concatenate((cos, sin), dim=0)
cos = torch.cos(freqs)
sin = torch.sin(freqs)
complex = torch.complex(cos, sin)
return complex

def _create_rotary_embed_table(self):
t = torch.arange(self.max_seqlen, device=self.device)
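
The table layout change above in brief: the old `_compute_rotary_embed_table` stacked cos and sin along a leading dimension of size 2 (hence the `[:, positions_seq]` indexing), while the restored version stores one complex e^{i*theta} per (position, frequency) entry that is indexed by position directly. A small standalone sketch of the new layout, with illustrative parameters:

import torch

def compute_rotary_table(positions: torch.Tensor, dim: int, base: float = 10000.0):
    # One complex e^{i*theta} per (position, frequency) entry, mirroring the
    # restored _compute_rotary_embed_table; `dim` and `base` stand in for
    # rope_dimension_count and rope_freq_base.
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
    angles = torch.outer(positions.float(), freqs)
    return torch.complex(torch.cos(angles), torch.sin(angles))

table = compute_rotary_table(torch.arange(8), dim=4)     # (8, 2), complex64
# Positions now index the leading dimension directly, as in
# `rotary_embed_table[positions_seq]` in compute_batch_mask.
positions_seq = torch.tensor([[3, 4, 5]])                # (bs=1, batch_seq_len=3)
freqs_cis = table[positions_seq]                         # (1, 3, 2), complex64
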
23 changes: 0 additions & 23 deletions sharktank/sharktank/models/llama/llama.py
@@ -186,29 +186,6 @@ def decode(
self._assert_device(start_positions)
self._assert_device(*cache_state, dtype=self.activation_dtype)

if self.config.tensor_parallelism_size > 1:
if not isinstance(tokens, ReplicatedTensor):
tokens = ops.replicate(
tokens, count=self.config.tensor_parallelism_size
)
if not isinstance(attention_mask, ReplicatedTensor):
attention_mask = ops.replicate(
attention_mask, count=self.config.tensor_parallelism_size
)
if not isinstance(start_positions, ReplicatedTensor):
start_positions = ops.replicate(
start_positions, count=self.config.tensor_parallelism_size
)
if not isinstance(seq_block_ids, ReplicatedTensor):
seq_block_ids = ops.replicate(
seq_block_ids, count=self.config.tensor_parallelism_size
)
# If the user provided unsharded arguments they probably want
# an unsharded result as well.
unshard_result = True
else:
unshard_result = False

bs, _ = tokens.shape
# Precompute a position based mask for computing rope embeddings
# as it is the same for all blocks.
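
With this block removed, callers must replicate decode inputs themselves when tensor parallelism is enabled, which is what the paged_llm_v1.py hunk earlier in this commit does. A caller-side sketch, assuming a configured `model` and unsharded `tokens`, `attention_mask`, `start_positions`, `seq_block_ids`, and `cache_state` from surrounding code (names are taken from the hunks above; the keyword names and `ops` import path are assumptions):

from sharktank import ops  # import path assumed; the removed block uses ops.replicate

tp = model.config.tensor_parallelism_size
if tp != 1:
    # Mirror the replication that used to live inside decode().
    tokens = ops.replicate(tokens, count=tp)
    attention_mask = ops.replicate(attention_mask, count=tp)
    start_positions = ops.replicate(start_positions, count=tp)
    seq_block_ids = ops.replicate(seq_block_ids, count=tp)

logits = model.decode(
    tokens,
    attention_mask=attention_mask,
    start_positions=start_positions,
    seq_block_ids=seq_block_ids,
    cache_state=cache_state,
)
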
19 changes: 17 additions & 2 deletions sharktank/sharktank/ops/custom_impls.py
@@ -5,21 +5,24 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import torch

from torch import Tensor, dtype
from typing import Union

import torch.nn.functional as F

from ..kernels import (
einsum_2args_q4,
mmt_block_scaled_offset_q4_unsigned,
mmt_block_scaled_q8,
mmtfp,
mmt_super_block_scaled_offset_q4_unsigned,
bitcast_to_complex,
bitcast_to_real,
)

from ..types import (
BlockScaledLayout,
BlockScaledI4Layout,
InferenceTensor,
PrimitiveTensor,
QuantizedTensor,
SuperBlockOffsetScaled_4_6_Layout,
@@ -123,3 +126,15 @@ def matmul_generic_tensor_super_block_offset_scaled_4_6_i4(
sb_mins_low,
rhs_unpacked.qs_bit_packed,
)


@view_as_complex.override(Union[Tensor, PrimitiveTensor])
def view_as_complex(t):
t = unbox_tensor(t)
return bitcast_to_complex(t)


@view_as_real.override(Union[Tensor, PrimitiveTensor])
def view_as_real(t):
t = unbox_tensor(t)
return bitcast_to_real(t)
10 changes: 10 additions & 0 deletions sharktank/sharktank/ops/default_impls.py
@@ -503,3 +503,13 @@ def view_QuantizedTensor(tensor: QuantizedTensor, shape):
new_m = unpacked.m.view(shape[:-1] + [shape[-1] // 32, 1])
layout = BlockScaledI4Layout(shape=shape, d=new_d, qs=new_qs, m=new_m)
return PlanarQuantizedTensor(shape=shape, layout=layout)


@view_as_complex.override(Tensor)
def view_as_complex_default(tensor: Union[Tensor, PrimitiveTensor]) -> Tensor:
return torch.view_as_complex(unbox_tensor(tensor))


@view_as_real.override(Tensor)
def view_as_real_default(tensor: Union[Tensor, PrimitiveTensor]) -> Tensor:
return torch.view_as_real(unbox_tensor(tensor))
24 changes: 24 additions & 0 deletions sharktank/sharktank/ops/sharded_impls.py
@@ -1303,3 +1303,27 @@ def view_split(tensor: SplitPrimitiveTensor, shape: List[int]) -> SplitPrimitiveTensor:
res = SplitPrimitiveTensor(shard_dim=shard_dim, ts=shards)
assert math.prod(res.shape) == math.prod(tensor.shape)
return res


@view_as_complex.override(SplitPrimitiveTensor)
def view_as_complex_split(tensor: SplitPrimitiveTensor) -> SplitPrimitiveTensor:
shards = [view_as_complex(shard) for shard in tensor.shards]
return SplitPrimitiveTensor(ts=shards, shard_dim=tensor.shard_dim)


@view_as_complex.override(ReplicatedTensor)
def view_as_complex_rep(tensor: ReplicatedTensor) -> ReplicatedTensor:
shards = [view_as_complex(shard) for shard in tensor.shards]
return ReplicatedTensor(ts=shards)


@view_as_real.override(SplitPrimitiveTensor)
def view_as_real_split(tensor: SplitPrimitiveTensor) -> SplitPrimitiveTensor:
shards = [view_as_real(shard) for shard in tensor.shards]
return SplitPrimitiveTensor(ts=shards, shard_dim=tensor.shard_dim)


@view_as_real.override(ReplicatedTensor)
def view_as_real_rep(tensor: ReplicatedTensor) -> ReplicatedTensor:
shards = [view_as_real(shard) for shard in tensor.shards]
return ReplicatedTensor(ts=shards)