WIP Fix in-place tensor mutation when sharding #290

Draft · wants to merge 1 commit into base: main
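The diff adds an `__iadd__` override, an `out=` argument on the elementwise binary paths, an in-place `transfer_to_logical_device_` op, and an `insert_device_assignment` flag on sharded tensor construction. Below is a minimal usage sketch of the in-place behavior this draft appears to target; it is not part of the PR, the import path is inferred from the file names in the diff, and the exact API may change while the change is WIP:

```python
# Sketch only: exercises the in-place add path this draft wires up.
import torch

from sharktank.types.tensors import SplitPrimitiveTensor

# Two shards split along dim 1. Passing insert_device_assignment=False (the new
# constructor flag) keeps the given shard tensors as-is instead of routing them
# through transfer_to_logical_device, which would insert copies.
shards = [torch.zeros(2, 3), torch.zeros(2, 3)]
split = SplitPrimitiveTensor(ts=shards, shard_dim=1, insert_device_assignment=False)

# The new __iadd__ dispatches to ops.elementwise(torch.add, split, 1.0, out=split),
# so each shard is updated via torch.add(shard, 1.0, out=shard) instead of being
# reallocated.
split += 1.0

# If the in-place path works as intended, the original shard tensors were mutated.
assert torch.equal(shards[0], torch.ones(2, 3))
```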
8 changes: 6 additions & 2 deletions sharktank/sharktank/layers/kv_cache.py
@@ -207,7 +207,9 @@ def unflatten_page_table(
)
for shard in page_slab.shards
]
return SplitPrimitiveTensor(ts=shards, shard_dim=4)
return SplitPrimitiveTensor(
ts=shards, shard_dim=4, insert_device_assignment=False
)

def shard_state(
self, state: List[torch.Tensor]
@@ -236,7 +238,9 @@ def shard_state(
shards = [
ops.flatten(shard, start_dim=1) for shard in sharded_page_table.shards
]
flat_sharded_page_table = SplitPrimitiveTensor(ts=shards, shard_dim=1)
flat_sharded_page_table = SplitPrimitiveTensor(
ts=shards, shard_dim=1, insert_device_assignment=False
)
return [flat_sharded_page_table]

@property
15 changes: 13 additions & 2 deletions sharktank/sharktank/ops/default_impls.py
@@ -86,11 +86,15 @@ def elementwise_unary(operator, x, *args, **kwargs):
IsOfType(Tensor, PrimitiveTensor), IsOfType(Tensor, PrimitiveTensor, Number)
)
)
def elementwise_binary(operator, x, y, *args, **kwargs):
def elementwise_binary(
operator, x, y, out: Optional[Tensor | PrimitiveTensor] = None, *args, **kwargs
):
x = unbox_tensor(x)
if isinstance(y, PrimitiveTensor):
y = unbox_tensor(y)
return operator(x, y, *args, **kwargs)
if isinstance(out, PrimitiveTensor):
out = unbox_tensor(out)
return operator(x, y, *args, out=out, **kwargs)


@elementwise.override(
@@ -398,6 +402,13 @@ def transfer_to_logical_device_default(tensor: Tensor, ordinal: int):
)


@transfer_to_logical_device_.override(Tensor)
def transfer_to_logical_device__default(tensor: Tensor, ordinal: int):
iree.turbine.ops.iree.transfer_to_logical_device_(
f"{ordinal}", unbox_tensor(tensor)
)


@transpose.override(Tensor)
def transpose_default(
tensor: Union[Tensor, PrimitiveTensor], dim0: int, dim1: int
26 changes: 22 additions & 4 deletions sharktank/sharktank/ops/sharded_impls.py
@@ -300,11 +300,29 @@ def split_elementwise_binary(

@elementwise.override(SplitPrimitiveTensor, Number)
def elementwise_binary_split_lhs_scalar_rhs(
operator, x: SplitPrimitiveTensor, y: Number, *args, **kwargs
operator,
x: SplitPrimitiveTensor,
y: Number,
out: SplitPrimitiveTensor = None,
*args,
**kwargs,
):
pt_xs = [unbox_tensor(pt) for pt in x.shards]
partials = [operator(pt_x, y, *args, **kwargs) for pt_x in pt_xs]
return SplitPrimitiveTensor(shard_dim=x.shard_dim, shape=x.shape, ts=partials)
x_shards = [unbox_tensor(pt) for pt in x.shards]
out_shards = (
[None] * len(x.shards)
if out is None
else [unbox_tensor(shard) for shard in out.shards]
)
partials = [
operator(x_shard, y, out=out_shard, *args, **kwargs)
for x_shard, out_shard in zip(x_shards, out_shards)
]
return SplitPrimitiveTensor(
shard_dim=x.shard_dim,
shape=x.shape,
ts=partials,
insert_device_assignment=out is None,
)


@elementwise.override(SplitPrimitiveTensor, Tensor)
24 changes: 23 additions & 1 deletion sharktank/sharktank/ops/signatures.py
@@ -54,6 +54,7 @@
"softmax",
"to",
"transfer_to_logical_device",
"transfer_to_logical_device_",
"transpose",
"unflatten",
"unshard",
@@ -210,7 +211,7 @@ def elementwise(operator, *args, **kwargs) -> AnyTensor:
def _elementwise_trampoline(d: SignatureDispatcher, operator, *args, **kwargs):
tensors = []
for a in args:
if isinstance(a, (Tensor, InferenceTensor)):
if isinstance(a, (Tensor, InferenceTensor, Number)):
tensors.append(a)
else:
break
@@ -994,6 +995,27 @@ def _transfer_to_logical_device_trampoline(
d.fail(tensors)


@overridable
def transfer_to_logical_device_(tensor: AnyTensor, ordinal: int) -> None:
"""In-place variant of transfer_to_logical_device.
Used to annotate function arguments.
"""
...


@transfer_to_logical_device_.trampoline
def _transfer_to_logical_device__trampoline(
d: SignatureDispatcher, tensor: AnyTensor, ordinal: int
):
tensors = (tensor,)
for override in d.find_overrides(tensors):
result = override(tensor, ordinal)
if result is not NotImplemented:
return override, result
else:
d.fail(tensors)


@overridable
def transpose(tensor: AnyTensor, dim0: int, dim1: int) -> AnyTensor:
"""See torch.transpose"""
31 changes: 28 additions & 3 deletions sharktank/sharktank/types/tensors.py
@@ -395,6 +395,11 @@ def __radd__(self, lhs):
# numbers on the lhs.
return self.__add__(lhs)

def __iadd__(self, rhs):
from ..ops import elementwise

return elementwise(torch.add, self, rhs, out=self)

def __mod__(self, rhs):
from ..ops import elementwise

@@ -758,6 +763,7 @@ def __init__(
ts: list[torch.Tensor],
name: str = UnnamedTensorName,
shape: Optional[list[int]],
insert_device_assignment: bool = True,
):
from ..ops import transfer_to_logical_device

@@ -767,7 +773,9 @@
self._shards: tuple[DefaultPrimitiveTensor] = tuple(
DefaultPrimitiveTensor(
name=f"{name}.shard.{i}",
data=transfer_to_logical_device(t, i),
data=transfer_to_logical_device(t, i)
if insert_device_assignment
else unbox_tensor(t),
)
for i, t in enumerate(ts)
)
@@ -930,6 +938,7 @@ def __init__(
shard_count: None | int = None,
name: str = UnnamedTensorName,
shape: Optional[list[int]] = None,
insert_device_assignment: bool = True,
):
"""
If `ts` is a list of tensors, it is interpreted as the shards.
@@ -966,7 +975,13 @@
s == t for i, (s, t) in enumerate(zip(shape, t_shape)) if i != shard_dim
), f"Shape mismatch for non-split dimension for tensor shard {i} with shape {t.shape}"

super().__init__(name=name, ts=ts, shape=shape, shard_dim=shard_dim)
super().__init__(
name=name,
ts=ts,
shape=shape,
shard_dim=shard_dim,
insert_device_assignment=insert_device_assignment,
)

def _is_slicing_split_dim(self, key):
if isinstance(
@@ -1309,6 +1324,7 @@ def flatten_with_keys_default_primitive_tensor(t: DefaultPrimitiveTensor):
flatten_fn=flatten_default_primitive_tensor,
unflatten_fn=unflatten_defult_primitive_tensor,
flatten_with_keys_fn=flatten_with_keys_default_primitive_tensor,
serialized_type_name=f"{DefaultPrimitiveTensor.__module__}.{DefaultPrimitiveTensor.__name__}",
)


@@ -1321,8 +1337,16 @@
def unflatten_split_primitive_tensor(
values: Iterable[Any], ctx: torch.utils._pytree.Context
) -> SplitPrimitiveTensor:
from ..ops import transfer_to_logical_device_

shards = list(values)
for i, tensor in enumerate(shards):
transfer_to_logical_device_(tensor, i)
return SplitPrimitiveTensor(
shard_dim=ctx["shard_dim"], ts=list(values), name=ctx["name"]
shard_dim=ctx["shard_dim"],
ts=shards,
name=ctx["name"],
insert_device_assignment=False,
)


@@ -1336,6 +1360,7 @@ def flatten_with_keys_split_primitive_tensor(t: SplitPrimitiveTensor):
flatten_fn=flatten_split_primitive_tensor,
unflatten_fn=unflatten_split_primitive_tensor,
flatten_with_keys_fn=flatten_with_keys_split_primitive_tensor,
serialized_type_name=f"{SplitPrimitiveTensor.__module__}.{SplitPrimitiveTensor.__name__}",
)


6 changes: 5 additions & 1 deletion sharktank/sharktank/utils/iree.py
@@ -98,7 +98,7 @@ def run_iree_module_function(
if trace_path_prefix is not None:
for i, arg in enumerate(args):
np.save(
f"{trace_path_prefix}{function_name}_arg_post_call{i}.npy",
f"{trace_path_prefix}{function_name}_arg{i}_post_call.npy",
arg.to_host(),
)
for i, arg in enumerate(results):
@@ -187,3 +187,7 @@ def call_torch_module_function(
result.to("cpu").numpy(),
)
return res


def iree_to_torch(*tensors: iree.runtime.DeviceArray) -> List[torch.Tensor]:
return [torch.tensor(tensor.to_host()) for tensor in tensors]
53 changes: 40 additions & 13 deletions sharktank/tests/models/llama/sharded_llama_test.py
@@ -24,6 +24,7 @@
run_iree_module_function,
prepare_iree_module_function_args,
call_torch_module_function,
iree_to_torch,
)
import tempfile
import torch
@@ -34,10 +35,6 @@
import os


def iree_to_torch(*tensors: iree.runtime.DeviceArray) -> List[torch.Tensor]:
return [torch.tensor(tensor.to_host()) for tensor in tensors]


@pytest.mark.usefixtures("caching", "path_prefix")
class ShardedLlamaTest(unittest.TestCase):
def setUp(self):
@@ -219,12 +216,12 @@ def testCompareToySizedModelToUnsharded(self):
actual_decode_cache_state, expected_decode_cache_state, atol=1e-4, rtol=1e-4
)

@unittest.skip(
(
"Before this does not crash at all we need "
"https://github.com/iree-org/iree/pull/18663 merged."
)
)
# @unittest.skip(
# (
# "Before this does not crash at all we need "
# "https://github.com/iree-org/iree/pull/18663 merged."
# )
# )
def testExportAndRunToySizedModelWithIree(self):
"""Test exporting to MLIR and compiling with IREE the sharded Llama model.
Test numerical accuracy of the IREE module against PyTorch."""
@@ -250,19 +247,49 @@ def runTestExportAndRunToySizedModelWithIree(
sharded_dataset = Dataset.load(sharded_parameters_path, mmap=False)
iree_driver = "local-task"

model = PagedLlamaModelV1(self.theta, self.config)
self.theta.rename_tensors_to_paths()
dataset = Dataset({}, self.theta)
parameters_path = f"{path_prefix}unsharded-parameters.irpa"
dataset.save(parameters_path)
dataset = Dataset.load(parameters_path, mmap=False)
model = PagedLlamaModelV1(dataset.root_theta, self.config)
sharded_model = PagedLlamaModelV1(
sharded_dataset.root_theta, self.sharded_config
)
(
_,
prefill_args,
sharded_prefill_args,
) = self.make_equal_unsharded_and_sharded_prefill_args(model, sharded_model)
(
_,
decode_args,
sharded_decode_args,
) = self.make_equal_unsharded_and_sharded_decode_args(model, sharded_model)

################################################################
fxb = FxProgramsBuilder(model)

@fxb.export_program(
name="prefill",
args=tuple(),
kwargs=prefill_args,
strict=False,
)
def _(model, *args, **kwargs) -> torch.Tensor:
return model.prefill(*args, **kwargs)

@fxb.export_program(
name="decode",
args=tuple(),
kwargs=decode_args,
strict=False,
)
def _(model, *args, **kwargs) -> torch.Tensor:
return model.decode(*args, **kwargs)

output = export(fxb)
output.save_mlir(f"{path_prefix}program-unsharded.mlir")
################################################################

iree_module_path = f"{path_prefix}program.vmfb"
if not self.caching or not os.path.exists(iree_module_path):
# Export and compile the IREE module.
8 changes: 4 additions & 4 deletions sharktank/tests/ops/sharded_test.py
@@ -786,9 +786,9 @@ def testReplicateUnsharded(self):
expected_result = ReplicatedTensor(ts=tensor, shard_count=shard_count)
assert expected_result.is_deep_equal(actual_result)

# Test not a copy.
# Test that it is a copy.
tensor[...] = torch.rand_like(tensor)
assert all(ops.equal(tensor, shard) for shard in actual_result.shards)
assert all(not ops.equal(tensor, shard) for shard in actual_result.shards)


class ReshapeTest(unittest.TestCase):
@@ -851,10 +851,10 @@ def testReshardUnsharded(self):
)
assert expected_result.is_deep_equal(actual_result)

# Test not a copy.
# Test that it is a copy.
tensor[...] = torch.rand_like(tensor)
result_split2 = ops.reshard_split(tensor, dim=shard_dim, count=shard_count)
assert ops.equal(actual_result, result_split2)
assert not ops.equal(actual_result, result_split2)

def testReshardSharded(self):
tensor = torch.rand(4, 5, 6, dtype=torch.float32)