basic sharding support for quant tensors #337

Open
wants to merge 1 commit into base: main
4 changes: 2 additions & 2 deletions sharktank/sharktank/ops/default_impls.py
@@ -438,8 +438,8 @@ def to_default(tensor: Tensor, *args, **kwargs):
return unbox_tensor(tensor).to(*args, **kwargs)


@transfer_to_logical_device.override(Tensor)
def transfer_to_logical_device_default(tensor: Tensor, ordinal: int):
@transfer_to_logical_device.override(AllOfType(AnyTensor, QuantizedTensor))
def transfer_to_logical_device_default(tensor, ordinal: int):
return iree.turbine.ops.iree.transfer_to_logical_device(
f"{ordinal}", unbox_tensor(tensor)
)
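For reference, a minimal sketch of the usage this widened override is meant to enable: a quantized tensor dispatching through the same default transfer path as a plain torch.Tensor. The TensorScaledLayout construction mirrors the sharded tests added below; the import paths and the eager-mode behavior of the underlying IREE transfer op are assumptions, not something this diff guarantees.

import torch
from sharktank import ops
from sharktank.types import PlanarQuantizedTensor, TensorScaledLayout

# Small per-tensor-scaled quantized tensor (same layout family as the tests below).
d = torch.tensor(0.5, dtype=torch.float32)
qs = (torch.rand([8, 8]) * 32.0).to(torch.int8)
qt = PlanarQuantizedTensor(
    shape=[8, 8], layout=TensorScaledLayout(shape=[8, 8], d=d, qs=qs)
)

# With AllOfType(AnyTensor, QuantizedTensor), this call now resolves to the
# same default implementation that plain tensors use.
moved = ops.transfer_to_logical_device(qt, 1)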
83 changes: 83 additions & 0 deletions sharktank/sharktank/ops/sharded_impls.py
@@ -15,6 +15,8 @@
AnyTensor,
DefaultPrimitiveTensor,
InferenceTensor,
QuantizedTensor,
PlanarQuantizedTensor,
PrimitiveTensor,
ReplicatedTensor,
ShardedTensor,
@@ -28,6 +30,8 @@
from .signatures import *
from .shape import broadcast_dims, broadcast_dim, unbroadcast_dim
from ..utils import longest_equal_range
from ..utils.math import ceildiv
from sharktank.types.tensors import REGISTERED_LAYOUT_CLASSES


@all_gather.override(SplitPrimitiveTensor)
@@ -1264,3 +1268,82 @@ def view_split(tensor: SplitPrimitiveTensor, shape: List[int]) -> SplitPrimitive
res = SplitPrimitiveTensor(shard_dim=shard_dim, ts=shards)
assert math.prod(res.shape) == math.prod(tensor.shape)
return res


@split.override(QuantizedTensor)
def split_QuantizedTensor(tensor: QuantizedTensor, split_size_or_sections, dim):
tensors = []
unpacked = tensor.unpack()
num_outputs = ceildiv(unpacked._qs.shape[dim], split_size_or_sections)
new_shape = list(unpacked._shape)
new_shape[dim] = split_size_or_sections
new_qs = torch.split(unpacked._qs, split_size_or_sections, dim)
if unpacked._d.ndim > 0:
new_d = torch.split(unpacked._d, split_size_or_sections, dim)
if unpacked.serialized_name() == "SuperBlockOffsetScaled_4_6_Layout":
new_dmin = torch.split(unpacked._dmin, split_size_or_sections, dim)
new_sb_scales_high = torch.split(
unpacked._sb_scales_high, split_size_or_sections, dim
)
new_sb_scales_low = torch.split(
unpacked._sb_scales_low, split_size_or_sections, dim
)
new_sb_mins_high = torch.split(
unpacked._sb_mins_high, split_size_or_sections, dim
)
new_sb_mins_low = torch.split(
unpacked._sb_mins_low, split_size_or_sections, dim
)
for i in range(num_outputs):
layout_clazz = REGISTERED_LAYOUT_CLASSES[unpacked.serialized_name()]
new_layout = layout_clazz(
shape=new_shape,
d=new_d[i],
dmin=new_dmin[i],
sb_scales_high=new_sb_scales_high[i],
sb_scales_low=new_sb_scales_low[i],
sb_mins_high=new_sb_mins_high[i],
sb_mins_low=new_sb_mins_low[i],
qs=new_qs[i],
)
new_tensor_layout = new_layout.create(
new_layout.shape, new_layout.metadata, new_layout.planes
)
new_tensor = tensor.__class__(
shape=new_shape, layout=new_tensor_layout, name=tensor._name + str(i)
)
tensors.append(new_tensor)
else:
if split_size_or_sections > unpacked._qs.shape[dim]:
raise ValueError("split size greater than tensor dim")

if unpacked._m is not None:
if unpacked._m.ndim > 0:
new_m = torch.split(unpacked._m, split_size_or_sections, dim)
for i in range(num_outputs):
layout_clazz = REGISTERED_LAYOUT_CLASSES[unpacked.serialized_name()]
if unpacked._m is not None:
if unpacked._d.ndim > 0:
new_layout = layout_clazz(
shape=new_shape, d=new_d[i], qs=new_qs[i], m=new_m[i]
)
else:
new_layout = layout_clazz(
shape=new_shape, d=unpacked._d, qs=new_qs[i], m=unpacked._m
)
else:
if unpacked._d.ndim > 0:
new_layout = layout_clazz(shape=new_shape, d=new_d[i], qs=new_qs[i])
else:
new_layout = layout_clazz(
shape=new_shape, d=unpacked._d, qs=new_qs[i]
)
new_tensor_layout = new_layout.create(
new_layout.shape, new_layout.metadata, new_layout.planes
)
new_tensor = tensor.__class__(
shape=new_shape, layout=new_tensor_layout, name=tensor._name + str(i)
)
tensors.append(new_tensor)
return tensors
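As a usage sketch of the new override (same TensorScaledLayout construction as the sharded tests below; import paths are assumptions): splitting along a dimension returns one quantized tensor per chunk, with the qs plane sliced and a 0-dim scale carried through unchanged.

import torch
from sharktank import ops
from sharktank.types import PlanarQuantizedTensor, TensorScaledLayout

d = torch.tensor(0.25, dtype=torch.float32)       # 0-dim per-tensor scale
qs = (torch.rand([16, 8]) * 32.0).to(torch.int8)  # quantized values plane
qt = PlanarQuantizedTensor(
    shape=[16, 8], layout=TensorScaledLayout(shape=[16, 8], d=d, qs=qs)
)

# Two chunks of 8 rows each along dim 0; each chunk is rebuilt as a new
# PlanarQuantizedTensor over its slice of qs, reusing the scalar scale d.
pieces = ops.split(qt, 8, 0)
assert len(pieces) == 2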
43 changes: 40 additions & 3 deletions sharktank/sharktank/ops/signatures.py
@@ -11,7 +11,15 @@
import torch
import numbers
from torch import Tensor, dtype
from ..types import AnyTensor, ShardedTensor, Theta, sharding, InferenceTensor
from ..types import (
AnyTensor,
ShardedTensor,
Theta,
sharding,
InferenceTensor,
QuantizedTensor,
PlanarQuantizedTensor,
)
from numbers import Number

from ._registry import *
@@ -59,6 +67,7 @@
"unshard",
"unsqueeze",
"view",
"split",
]

IntOrSequenceInt = Union[int, Sequence[int]]
@@ -976,14 +985,18 @@ def _to_trampoline(d: SignatureDispatcher, tensor: AnyTensor, *args, **kwargs):


@overridable
def transfer_to_logical_device(tensor: AnyTensor, ordinal: int) -> AnyTensor:
def transfer_to_logical_device(
tensor: Union[AnyTensor, QuantizedTensor, PlanarQuantizedTensor], ordinal: int
) -> Union[AnyTensor, QuantizedTensor, PlanarQuantizedTensor]:
"""Transfer the tensor to a device with ordinal `ordinal`."""
...


@transfer_to_logical_device.trampoline
def _transfer_to_logical_device_trampoline(
d: SignatureDispatcher, tensor: AnyTensor, ordinal: int
d: SignatureDispatcher,
tensor: Union[AnyTensor, QuantizedTensor, PlanarQuantizedTensor],
ordinal: int,
):
tensors = (tensor,)
for override in d.find_overrides(tensors):
@@ -1085,3 +1098,27 @@ def _view_trampoline(
return override, result
else:
d.fail(tensors)


@overridable
def split(
tensor: QuantizedTensor, split_size_or_sections: List[int], dim: int
) -> List[QuantizedTensor]:
"""See torch.Tensor.split"""
...


@split.trampoline
def _split_trampoline(
d: SignatureDispatcher,
tensor: QuantizedTensor,
split_size_or_sections: List[int],
dim: int,
) -> List[QuantizedTensor]:
tensors = (tensor,)
for override in d.find_overrides(tensors):
result = override(tensor, split_size_or_sections, dim)
if result is not NotImplemented:
return override, result
else:
d.fail(tensors)
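The new `split` signature uses the same overridable/trampoline machinery as the other ops in this file, so additional tensor types can register handlers without touching the trampoline. A purely hypothetical sketch (the ReplicatedTensor handler below is not part of this change) of what such a registration would look like:

import torch
from sharktank import ops
from sharktank.types import ReplicatedTensor

@ops.split.override(ReplicatedTensor)
def split_replicated(tensor: ReplicatedTensor, split_size_or_sections, dim):
    # Split every replica identically, then regroup chunk-wise so each output
    # is itself a ReplicatedTensor with the same shard count.
    per_shard = [
        torch.split(shard.as_torch(), split_size_or_sections, dim)
        for shard in tensor.shards
    ]
    return [ReplicatedTensor(ts=list(chunks)) for chunks in zip(*per_shard)]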
20 changes: 11 additions & 9 deletions sharktank/sharktank/types/sharding.py
@@ -60,15 +60,17 @@ def __init__(self, *args, **kwargs):
for k, v in d.items():
d[k] = tree.map_nodes(
tree=v,
f=lambda x: x
if isinstance(
x,
(
TensorSharding,
ThetaSharding,
),
)
else ThetaSharding(x),
f=lambda x: (
x
if isinstance(
x,
(
TensorSharding,
ThetaSharding,
),
)
else ThetaSharding(x)
),
)
super().__init__(d)

31 changes: 20 additions & 11 deletions sharktank/sharktank/types/tensors.py
@@ -581,6 +581,11 @@ def add_to_archive(self, builder: ShardedArchiveBuilder) -> InferenceTensorMetad
"""
return self.to_planar().add_to_archive(builder)

def split(self, split_size_or_sections: "list[int]", dim: int) -> "list[QuantizedTensor]":
from ..ops import split

return split(self, split_size_or_sections, dim)


@register_inference_tensor
class PlanarQuantizedTensor(QuantizedTensor):
@@ -764,12 +769,14 @@ def __init__(
assert shard_dim is None or (shard_dim >= 0 and len(ts[0].shape) > shard_dim)
super().__init__(name=name, shape=shape, shard_dim=shard_dim)
self._shards: tuple[DefaultPrimitiveTensor] = tuple(
DefaultPrimitiveTensor(
name=f"{name}.shard.{i}",
data=t,
(
DefaultPrimitiveTensor(
name=f"{name}.shard.{i}",
data=t,
)
if isinstance(t, torch.Tensor)
else t
)
if isinstance(t, torch.Tensor)
else t
for i, t in enumerate(ts)
)

@@ -941,7 +948,7 @@ def __init__(
will be split along dimension `shard_dim` into `shard_count`
number of pieces.
"""
if isinstance(ts, torch.Tensor):
if isinstance(ts, (torch.Tensor, InferenceTensor)):
from ..ops import transfer_to_logical_device

assert shard_count is not None
@@ -1082,12 +1089,14 @@
assert shape == list(shard.shape)

self._shards: tuple[DefaultPrimitiveTensor] = tuple(
DefaultPrimitiveTensor(
name=f"{name}.shard.{i}",
data=t,
(
DefaultPrimitiveTensor(
name=f"{name}.shard.{i}",
data=t,
)
if isinstance(t, torch.Tensor)
else t
)
if isinstance(t, torch.Tensor)
else t
for i, t in enumerate(ts)
)

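A short sketch of the two call sites this change to tensors.py enables (tensor construction as above; import paths are assumptions): the new QuantizedTensor.split method delegating to ops.split, and SplitPrimitiveTensor accepting an unsharded InferenceTensor for `ts`, which is how the sharded tests below build rhs_pqt_sharded.

import torch
from sharktank.types import (
    PlanarQuantizedTensor,
    SplitPrimitiveTensor,
    TensorScaledLayout,
)

d = torch.tensor(0.5, dtype=torch.float32)
qs = (torch.rand([8, 8]) * 32.0).to(torch.int8)
qt = PlanarQuantizedTensor(
    shape=[8, 8], layout=TensorScaledLayout(shape=[8, 8], d=d, qs=qs)
)

# Method form: forwards to ops.split through the new QuantizedTensor.split.
halves = qt.split(4, 0)

# Constructor form: an InferenceTensor is now split internally along shard_dim
# and each piece is transferred to its logical device.
sharded = SplitPrimitiveTensor(shard_dim=0, ts=qt, shard_count=2)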
20 changes: 0 additions & 20 deletions sharktank/tests/ops/ops_test.py
@@ -194,26 +194,6 @@ def testTorchImplTransposedQuantizedRHS_BlockScaledLayout(self):
ops.custom_impls.matmul_generic_tensor_block_scaled,
)

def testTorchImplTransposedQuantizedRHS_BlockScaledOffsetI4(self):
ops._registry._test_enable_last_op_dispatch(True)
a_dtype = torch.float32
d_dtype = torch.float32
ref_dtype = torch.float32
a = torch.rand([4, 16, 3200], dtype=a_dtype) / 256.0
d = torch.rand([3200, 100, 1], dtype=d_dtype) / 256.0
qs = (torch.rand([3200, 100, 16], dtype=ref_dtype) * 255.0).to(torch.uint8)
m = torch.rand([3200, 100, 1], dtype=d_dtype) + 16.0
rhs_pqt = PlanarQuantizedTensor(
shape=[3200, 3200],
layout=BlockScaledI4Layout([3200, 3200], d, qs, m=m, signed=False),
)
result = ops.matmul(a, rhs_pqt, transpose_rhs=True)
# Just verifying dispatch. Numerics are tested at the kernel level.
self.assertIs(
ops._registry._test_get_last_op_dispatch(),
ops.custom_impls.matmul_generic_tensor_block_scaled_i4,
)

# TODO: mmt_super_block_scaled_offset_q4_unsigned


78 changes: 77 additions & 1 deletion sharktank/tests/ops/sharded_test.py
@@ -27,7 +27,6 @@ def testAllGather(self):

sharded = SplitPrimitiveTensor(shard_dim=shard_dim, ts=shards)
actual_result = ops.all_gather(sharded)

for shard in actual_result.shards:
torch.testing.assert_close(shard.as_torch(), expected_result)

@@ -770,6 +769,83 @@ def testSameSplitLhsAndRhsBatchDim(self):
actual_result = unbox_tensor(ops.unshard(sharded_result))
torch.testing.assert_close(actual_result, expected_result)

def testTransposedQuantizedRHSSharded_BlockScaledOffsetI4(self):
ops._registry._test_enable_last_op_dispatch(True)
a_dtype = torch.float32
d_dtype = torch.float32
ref_dtype = torch.float32
a = torch.rand([4, 16, 3200], dtype=a_dtype) / 256.0
d = torch.rand([3200, 100, 1], dtype=d_dtype) / 256.0
qs = (torch.rand([3200, 100, 16], dtype=ref_dtype) * 255.0).to(torch.uint8)
m = torch.rand([3200, 100, 1], dtype=d_dtype) + 16.0
rhs_pqt = PlanarQuantizedTensor(
shape=[3200, 3200],
layout=BlockScaledI4Layout([3200, 3200], d, qs, m=m, signed=False),
)
expected_result = ops.matmul(a, rhs_pqt, transpose_rhs=True)

shard_count = 2
rhs_pqt_sharded = SplitPrimitiveTensor(
shard_dim=0, ts=rhs_pqt, shard_count=shard_count
)

sharded_result = ops.matmul(a, rhs_pqt_sharded, transpose_rhs=True)
actual_result = ops.sharded_cat(sharded_result)

torch.testing.assert_close(actual_result, expected_result)

def testTorchImplTransposedQuantizedRHSSharded_BlockScaledLayout(self):
ops._registry._test_enable_last_op_dispatch(True)
a_dtype = torch.float32
d_dtype = torch.float32
ref_dtype = torch.float32
a = torch.rand([4, 16, 3200], dtype=a_dtype) * 64
d = torch.rand([3200, 100, 1], dtype=d_dtype) * 64
qs = (torch.rand([3200, 100, 32], dtype=ref_dtype) * 32.0).to(torch.int8)
rhs_pqt = PlanarQuantizedTensor(
shape=[3200, 3200], layout=BlockScaledLayout([3200, 3200], d, qs)
)
expected_result = ops.matmul(a, rhs_pqt, transpose_rhs=True)

shard_count = 2
rhs_pqt_sharded = SplitPrimitiveTensor(
shard_dim=0, ts=rhs_pqt, shard_count=shard_count
)

sharded_result = ops.matmul(a, rhs_pqt_sharded, transpose_rhs=True)
actual_result = ops.sharded_cat(sharded_result)

torch.testing.assert_close(actual_result, expected_result)

def testTorchImplTransposedQuantizedRHSSharded_TensorScaledLayout(self):
ops._registry._test_enable_last_op_dispatch(True)
a_dtype = torch.float32
d_dtype = torch.float32
ref_dtype = torch.float32
a = torch.rand([4, 16, 3200], dtype=a_dtype) * 64
d = torch.tensor(5.1, dtype=d_dtype) # torch.rand([3200], dtype=d_dtype)
qs = (torch.rand([3200, 3200], dtype=ref_dtype) * 32.0).to(torch.int8)
m = torch.tensor(
16.0, dtype=d_dtype
) # torch.rand([3200], dtype=d_dtype) + 16.0
rhs_pqt = PlanarQuantizedTensor(
shape=[3200, 3200],
layout=TensorScaledLayout(shape=[3200, 3200], d=d, qs=qs, m=m),
)
print("a shape:, ", a.shape)
print("rhs_pqt.shape: ", rhs_pqt.shape)
expected_result = ops.matmul(a, rhs_pqt, transpose_rhs=True)

shard_count = 2
rhs_pqt_sharded = SplitPrimitiveTensor(
shard_dim=0, ts=rhs_pqt, shard_count=shard_count
)

sharded_result = ops.matmul(a, rhs_pqt_sharded, transpose_rhs=True)
actual_result = ops.sharded_cat(sharded_result)

torch.testing.assert_close(actual_result, expected_result)


class ReplicateTest(unittest.TestCase):
def testReplicateReplicated(self):