basic sharding support for quant tensors
IanNod committed Oct 25, 2024
1 parent 6f3f8c7 commit 2fb5eeb
Showing 7 changed files with 227 additions and 49 deletions.
4 changes: 2 additions & 2 deletions sharktank/sharktank/ops/default_impls.py
@@ -438,8 +438,8 @@ def to_default(tensor: Tensor, *args, **kwargs):
    return unbox_tensor(tensor).to(*args, **kwargs)


@transfer_to_logical_device.override(Tensor)
def transfer_to_logical_device_default(tensor: Tensor, ordinal: int):
@transfer_to_logical_device.override(AllOfType(AnyTensor, QuantizedTensor))
def transfer_to_logical_device_default(tensor, ordinal: int):
    return iree.turbine.ops.iree.transfer_to_logical_device(
        f"{ordinal}", unbox_tensor(tensor)
    )
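With the override widened from Tensor to AllOfType(AnyTensor, QuantizedTensor), plain and quantized tensors take the same transfer path. A minimal usage sketch (it assumes sharktank.ops re-exports the signature; quantized-tensor construction is shown with the split examples further down):

    import torch
    from sharktank import ops

    # Places the tensor on logical device 1 via the default override above;
    # a quantized tensor passed here now dispatches to the same function.
    t = torch.rand([4, 16], dtype=torch.float32)
    t_on_dev1 = ops.transfer_to_logical_device(t, 1)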
83 changes: 83 additions & 0 deletions sharktank/sharktank/ops/sharded_impls.py
@@ -15,6 +15,8 @@
    AnyTensor,
    DefaultPrimitiveTensor,
    InferenceTensor,
    QuantizedTensor,
    PlanarQuantizedTensor,
    PrimitiveTensor,
    ReplicatedTensor,
    ShardedTensor,
@@ -28,6 +30,8 @@
from .signatures import *
from .shape import broadcast_dims, broadcast_dim, unbroadcast_dim
from ..utils import longest_equal_range
from ..utils.math import ceildiv
from sharktank.types.tensors import REGISTERED_LAYOUT_CLASSES


@all_gather.override(SplitPrimitiveTensor)
@@ -1264,3 +1268,82 @@ def view_split(tensor: SplitPrimitiveTensor, shape: List[int]) -> SplitPrimitive
    res = SplitPrimitiveTensor(shard_dim=shard_dim, ts=shards)
    assert math.prod(res.shape) == math.prod(tensor.shape)
    return res


@split.override(QuantizedTensor)
def split_QuantizedTensor(tensor: QuantizedTensor, split_size_or_sections, dim):
    # Split the backing planes along `dim` and rebuild one quantized tensor per
    # chunk, using the registered layout class to reassemble each layout.
    tensors = []
    unpacked = tensor.unpack()
    num_outputs = ceildiv(unpacked._qs.shape[dim], split_size_or_sections)
    # Copy the shape so the source layout's shape list is not mutated in place.
    new_shape = list(unpacked._shape)
    new_shape[dim] = split_size_or_sections
    new_qs = torch.split(unpacked._qs, split_size_or_sections, dim)
    if unpacked._d.ndim > 0:
        new_d = torch.split(unpacked._d, split_size_or_sections, dim)
    if unpacked.serialized_name() == "SuperBlockOffsetScaled_4_6_Layout":
        new_dmin = torch.split(unpacked._dmin, split_size_or_sections, dim)
        new_sb_scales_high = torch.split(
            unpacked._sb_scales_high, split_size_or_sections, dim
        )
        new_sb_scales_low = torch.split(
            unpacked._sb_scales_low, split_size_or_sections, dim
        )
        new_sb_mins_high = torch.split(
            unpacked._sb_mins_high, split_size_or_sections, dim
        )
        new_sb_mins_low = torch.split(
            unpacked._sb_mins_low, split_size_or_sections, dim
        )
        for i in range(num_outputs):
            layout_clazz = REGISTERED_LAYOUT_CLASSES[unpacked.serialized_name()]
            new_layout = layout_clazz(
                shape=new_shape,
                d=new_d[i],
                dmin=new_dmin[i],
                sb_scales_high=new_sb_scales_high[i],
                sb_scales_low=new_sb_scales_low[i],
                sb_mins_high=new_sb_mins_high[i],
                sb_mins_low=new_sb_mins_low[i],
                qs=new_qs[i],
            )
            new_tensor_layout = new_layout.create(
                new_layout.shape, new_layout.metadata, new_layout.planes
            )
            new_tensor = tensor.__class__(
                shape=new_shape, layout=new_tensor_layout, name=tensor._name + str(i)
            )
            tensors.append(new_tensor)
    else:
        if split_size_or_sections > unpacked._qs.shape[dim]:
            raise ValueError("split size greater than tensor dim")

        if unpacked._m is not None:
            if unpacked._m.ndim > 0:
                new_m = torch.split(unpacked._m, split_size_or_sections, dim)
        for i in range(num_outputs):
            layout_clazz = REGISTERED_LAYOUT_CLASSES[unpacked.serialized_name()]
            if unpacked._m is not None:
                if unpacked._d.ndim > 0:
                    new_layout = layout_clazz(
                        shape=new_shape, d=new_d[i], qs=new_qs[i], m=new_m[i]
                    )
                else:
                    new_layout = layout_clazz(
                        shape=new_shape, d=unpacked._d, qs=new_qs[i], m=unpacked._m
                    )
            else:
                if unpacked._d.ndim > 0:
                    new_layout = layout_clazz(
                        shape=new_shape, d=new_d[i], qs=new_qs[i]
                    )
                else:
                    new_layout = layout_clazz(
                        shape=new_shape, d=unpacked._d, qs=new_qs[i]
                    )
            new_tensor_layout = new_layout.create(
                new_layout.shape, new_layout.metadata, new_layout.planes
            )
            new_tensor = tensor.__class__(
                shape=new_shape, layout=new_tensor_layout, name=tensor._name + str(i)
            )
            tensors.append(new_tensor)
    return tensors
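For reference, a hedged usage sketch of the new split override. The quantized tensor is built the same way as the BlockScaledI4Layout test this commit removes from ops_test.py; the import paths are assumptions, not verified against the package layout:

    import torch
    from sharktank import ops
    from sharktank.types import PlanarQuantizedTensor, BlockScaledI4Layout  # paths assumed

    d = torch.rand([3200, 100, 1], dtype=torch.float32) / 256.0
    qs = (torch.rand([3200, 100, 16]) * 255.0).to(torch.uint8)
    m = torch.rand([3200, 100, 1], dtype=torch.float32) + 16.0
    pqt = PlanarQuantizedTensor(
        shape=[3200, 3200],
        layout=BlockScaledI4Layout([3200, 3200], d, qs, m=m, signed=False),
    )

    # ceildiv(3200, 1600) == 2, so two quantized tensors come back, each holding
    # the matching chunk of the d / qs / m planes.
    chunks = ops.split(pqt, 1600, 0)
    assert len(chunks) == 2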
51 changes: 44 additions & 7 deletions sharktank/sharktank/ops/signatures.py
@@ -11,7 +11,15 @@
import torch
import numbers
from torch import Tensor, dtype
from ..types import AnyTensor, ShardedTensor, Theta, sharding, InferenceTensor
from ..types import (
    AnyTensor,
    ShardedTensor,
    Theta,
    sharding,
    InferenceTensor,
    QuantizedTensor,
    PlanarQuantizedTensor,
)
from numbers import Number

from ._registry import *
@@ -59,6 +67,7 @@
    "unshard",
    "unsqueeze",
    "view",
    "split",
]

IntOrSequenceInt = Union[int, Sequence[int]]
@@ -101,8 +110,9 @@ def _all_reduce_trampoline(d: SignatureDispatcher, tensor: AnyTensor):


@overridable
def cat(tensors: Tuple[AnyTensor, ...] | List[AnyTensor], dim: int = 0) -> AnyTensor:
    ...
def cat(
    tensors: Tuple[AnyTensor, ...] | List[AnyTensor], dim: int = 0
) -> AnyTensor: ...


@cat.trampoline
@@ -919,8 +929,7 @@ def _sharded_cat_trampoline(d: SignatureDispatcher, maybe_sharded: AnyTensor):


@overridable
def sharded_sum(maybe_sharded: AnyTensor):
    ...
def sharded_sum(maybe_sharded: AnyTensor): ...


@sharded_sum.trampoline
@@ -976,14 +985,18 @@ def _to_trampoline(d: SignatureDispatcher, tensor: AnyTensor, *args, **kwargs):


@overridable
def transfer_to_logical_device(tensor: AnyTensor, ordinal: int) -> AnyTensor:
def transfer_to_logical_device(
    tensor: Union[AnyTensor, QuantizedTensor, PlanarQuantizedTensor], ordinal: int
) -> Union[AnyTensor, QuantizedTensor, PlanarQuantizedTensor]:
    """Transfer the tensor to a device with ordinal `ordinal`."""
    ...


@transfer_to_logical_device.trampoline
def _transfer_to_logical_device_trampoline(
    d: SignatureDispatcher, tensor: AnyTensor, ordinal: int
    d: SignatureDispatcher,
    tensor: Union[AnyTensor, QuantizedTensor, PlanarQuantizedTensor],
    ordinal: int,
):
    tensors = (tensor,)
    for override in d.find_overrides(tensors):
@@ -1085,3 +1098,27 @@ def _view_trampoline(
            return override, result
    else:
        d.fail(tensors)


@overridable
def split(
    tensor: QuantizedTensor, split_size_or_sections: List[int], dim: int
) -> List[QuantizedTensor]:
    """See torch.Tensor.split"""
    ...


@split.trampoline
def _split_trampoline(
    d: SignatureDispatcher,
    tensor: QuantizedTensor,
    split_size_or_sections: List[int],
    dim: int,
) -> List[QuantizedTensor]:
    tensors = (tensor,)
    for override in d.find_overrides(tensors):
        result = override(tensor, split_size_or_sections, dim)
        if result is not NotImplemented:
            return override, result
    else:
        d.fail(tensors)
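The overridable/trampoline pair above is the extension point: the trampoline walks find_overrides and skips any override that returns NotImplemented. An illustrative registration, not part of sharktank, assuming the signature is importable the same way sharded_impls.py obtains it:

    from sharktank.ops.signatures import split
    from sharktank.types import PlanarQuantizedTensor


    @split.override(PlanarQuantizedTensor)
    def split_planar_quantized_noop(
        tensor: PlanarQuantizedTensor, split_size_or_sections, dim
    ):
        # Returning NotImplemented lets the trampoline try the next matching
        # override (e.g. split_QuantizedTensor registered in sharded_impls.py).
        return NotImplemented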
20 changes: 11 additions & 9 deletions sharktank/sharktank/types/sharding.py
@@ -60,15 +60,17 @@ def __init__(self, *args, **kwargs):
        for k, v in d.items():
            d[k] = tree.map_nodes(
                tree=v,
                f=lambda x: x
                if isinstance(
                    x,
                    (
                        TensorSharding,
                        ThetaSharding,
                    ),
                )
                else ThetaSharding(x),
                f=lambda x: (
                    x
                    if isinstance(
                        x,
                        (
                            TensorSharding,
                            ThetaSharding,
                        ),
                    )
                    else ThetaSharding(x)
                ),
            )
        super().__init__(d)

20 changes: 10 additions & 10 deletions sharktank/sharktank/types/tensors.py
@@ -62,8 +62,7 @@

class QuantizedLayout(ABC):
    @abstractmethod
    def dequant(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
        ...
    def dequant(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor: ...

    @classmethod
    @abstractmethod
@@ -78,8 +77,7 @@ def create(
        shape: list[int],
        metadata: Optional[dict[str, MetaDataValueType]],
        planes: dict[str, torch.Tensor],
    ) -> "QuantizedLayout":
        ...
    ) -> "QuantizedLayout": ...

    @property
    @abstractmethod
@@ -559,8 +557,7 @@ def __init__(
        self.layout_type = layout_type

    @abstractmethod
    def unpack(self) -> QuantizedLayoutT:
        ...
    def unpack(self) -> QuantizedLayoutT: ...

    def to_planar(self) -> "PlanarQuantizedTensor":
        """Converts this QuantizedTensor to a generic planar form.
@@ -581,6 +578,11 @@ def add_to_archive(self, builder: ShardedArchiveBuilder) -> InferenceTensorMetad
        """
        return self.to_planar().add_to_archive(builder)

    def split(self, split_size_or_sections: list[int], dim: int) -> "list[QuantizedTensor]":
        from ..ops import split

        return split(self, split_size_or_sections, dim)


@register_inference_tensor
class PlanarQuantizedTensor(QuantizedTensor):
@@ -717,8 +719,7 @@ def __init__(

    @property
    @abstractmethod
    def shard_count(self) -> int:
        ...
    def shard_count(self) -> int: ...

    @property
    @abstractmethod
@@ -941,9 +942,8 @@ def __init__(
        will be split along dimension `shard_dim` into `shard_count`
        number of pieces.
        """
        if isinstance(ts, torch.Tensor):
        if isinstance(ts, torch.Tensor) or isinstance(ts, InferenceTensor):
            from ..ops import transfer_to_logical_device

            assert shard_count is not None
            ts = ts.split(ceildiv(ts.shape[shard_dim], shard_count), dim=shard_dim)
            ts = [transfer_to_logical_device(t, i) for i, t in enumerate(ts)]
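Accepting InferenceTensor in that isinstance check is what ties the pieces together: SplitPrimitiveTensor can now be handed a quantized weight, call its new split() (which dispatches to split_QuantizedTensor), and transfer each shard to its logical device. A sketch of the intended flow, with import paths and layout arguments assumed from the removed ops_test.py test:

    import torch
    from sharktank.types import (
        BlockScaledI4Layout,  # import path assumed
        PlanarQuantizedTensor,
        SplitPrimitiveTensor,
    )

    d = torch.rand([3200, 100, 1], dtype=torch.float32) / 256.0
    qs = (torch.rand([3200, 100, 16]) * 255.0).to(torch.uint8)
    m = torch.rand([3200, 100, 1], dtype=torch.float32) + 16.0
    weight = PlanarQuantizedTensor(
        shape=[3200, 3200],
        layout=BlockScaledI4Layout([3200, 3200], d, qs, m=m, signed=False),
    )

    # Two shards along dim 0; shard i is transferred to logical device i.
    sharded_weight = SplitPrimitiveTensor(ts=weight, shard_dim=0, shard_count=2)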
20 changes: 0 additions & 20 deletions sharktank/tests/ops/ops_test.py
@@ -194,26 +194,6 @@ def testTorchImplTransposedQuantizedRHS_BlockScaledLayout(self):
            ops.custom_impls.matmul_generic_tensor_block_scaled,
        )

    def testTorchImplTransposedQuantizedRHS_BlockScaledOffsetI4(self):
        ops._registry._test_enable_last_op_dispatch(True)
        a_dtype = torch.float32
        d_dtype = torch.float32
        ref_dtype = torch.float32
        a = torch.rand([4, 16, 3200], dtype=a_dtype) / 256.0
        d = torch.rand([3200, 100, 1], dtype=d_dtype) / 256.0
        qs = (torch.rand([3200, 100, 16], dtype=ref_dtype) * 255.0).to(torch.uint8)
        m = torch.rand([3200, 100, 1], dtype=d_dtype) + 16.0
        rhs_pqt = PlanarQuantizedTensor(
            shape=[3200, 3200],
            layout=BlockScaledI4Layout([3200, 3200], d, qs, m=m, signed=False),
        )
        result = ops.matmul(a, rhs_pqt, transpose_rhs=True)
        # Just verifying dispatch. Numerics are tested at the kernel level.
        self.assertIs(
            ops._registry._test_get_last_op_dispatch(),
            ops.custom_impls.matmul_generic_tensor_block_scaled_i4,
        )

    # TODO: mmt_super_block_scaled_offset_q4_unsigned
