From 36ca12fa1ee174b67fdf7b921b9dd0a46b12cd62 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Fri, 11 Oct 2024 11:17:08 -0500 Subject: [PATCH] WIP Fix in-place tensor mutation when sharding Add option to construct sharded tensors without inserting device placement operations. Add a simple test for in-place addition t += 1 When unflattening a sharded tensor during exporting insert in-place device placements, which get translated to function argument affinity attributes. --- sharktank/sharktank/layers/kv_cache.py | 8 +- sharktank/sharktank/ops/default_impls.py | 15 +- sharktank/sharktank/ops/sharded_impls.py | 26 +++- sharktank/sharktank/ops/signatures.py | 24 +++- sharktank/sharktank/types/tensors.py | 31 +++- sharktank/sharktank/utils/iree.py | 6 +- .../tests/models/llama/sharded_llama_test.py | 53 +++++-- sharktank/tests/ops/sharded_test.py | 8 +- sharktank/tests/types/tensors_test.py | 135 ++++++++++++++++++ 9 files changed, 276 insertions(+), 30 deletions(-) diff --git a/sharktank/sharktank/layers/kv_cache.py b/sharktank/sharktank/layers/kv_cache.py index 309fe322a..e09dc600d 100644 --- a/sharktank/sharktank/layers/kv_cache.py +++ b/sharktank/sharktank/layers/kv_cache.py @@ -207,7 +207,9 @@ def unflatten_page_table( ) for shard in page_slab.shards ] - return SplitPrimitiveTensor(ts=shards, shard_dim=4) + return SplitPrimitiveTensor( + ts=shards, shard_dim=4, insert_device_assignment=False + ) def shard_state( self, state: List[torch.Tensor] @@ -236,7 +238,9 @@ def shard_state( shards = [ ops.flatten(shard, start_dim=1) for shard in sharded_page_table.shards ] - flat_sharded_page_table = SplitPrimitiveTensor(ts=shards, shard_dim=1) + flat_sharded_page_table = SplitPrimitiveTensor( + ts=shards, shard_dim=1, insert_device_assignment=False + ) return [flat_sharded_page_table] @property diff --git a/sharktank/sharktank/ops/default_impls.py b/sharktank/sharktank/ops/default_impls.py index fec30fca6..ed84ef585 100644 --- a/sharktank/sharktank/ops/default_impls.py +++ b/sharktank/sharktank/ops/default_impls.py @@ -86,11 +86,15 @@ def elementwise_unary(operator, x, *args, **kwargs): IsOfType(Tensor, PrimitiveTensor), IsOfType(Tensor, PrimitiveTensor, Number) ) ) -def elementwise_binary(operator, x, y, *args, **kwargs): +def elementwise_binary( + operator, x, y, out: Optional[Tensor | PrimitiveTensor] = None, *args, **kwargs +): x = unbox_tensor(x) if isinstance(y, PrimitiveTensor): y = unbox_tensor(y) - return operator(x, y, *args, **kwargs) + if isinstance(out, PrimitiveTensor): + out = unbox_tensor(out) + return operator(x, y, *args, out=out, **kwargs) @elementwise.override( @@ -398,6 +402,13 @@ def transfer_to_logical_device_default(tensor: Tensor, ordinal: int): ) +@transfer_to_logical_device_.override(Tensor) +def transfer_to_logical_device__default(tensor: Tensor, ordinal: int): + iree.turbine.ops.iree.transfer_to_logical_device_( + f"{ordinal}", unbox_tensor(tensor) + ) + + @transpose.override(Tensor) def transpose_default( tensor: Union[Tensor, PrimitiveTensor], dim0: int, dim1: int diff --git a/sharktank/sharktank/ops/sharded_impls.py b/sharktank/sharktank/ops/sharded_impls.py index a667669f4..6179b1943 100644 --- a/sharktank/sharktank/ops/sharded_impls.py +++ b/sharktank/sharktank/ops/sharded_impls.py @@ -300,11 +300,29 @@ def split_elementwise_binary( @elementwise.override(SplitPrimitiveTensor, Number) def elementwise_binary_split_lhs_scalar_rhs( - operator, x: SplitPrimitiveTensor, y: Number, *args, **kwargs + operator, + x: SplitPrimitiveTensor, + y: Number, + out: SplitPrimitiveTensor = None, + *args, + **kwargs, ): - pt_xs = [unbox_tensor(pt) for pt in x.shards] - partials = [operator(pt_x, y, *args, **kwargs) for pt_x in pt_xs] - return SplitPrimitiveTensor(shard_dim=x.shard_dim, shape=x.shape, ts=partials) + x_shards = [unbox_tensor(pt) for pt in x.shards] + out_shards = ( + [None] * len(x.shards) + if out is None + else [unbox_tensor(shard) for shard in out.shards] + ) + partials = [ + operator(x_shard, y, out=out_shard, *args, **kwargs) + for x_shard, out_shard in zip(x_shards, out_shards) + ] + return SplitPrimitiveTensor( + shard_dim=x.shard_dim, + shape=x.shape, + ts=partials, + insert_device_assignment=out is None, + ) @elementwise.override(SplitPrimitiveTensor, Tensor) diff --git a/sharktank/sharktank/ops/signatures.py b/sharktank/sharktank/ops/signatures.py index 89d4309ee..7a8f64c1e 100644 --- a/sharktank/sharktank/ops/signatures.py +++ b/sharktank/sharktank/ops/signatures.py @@ -54,6 +54,7 @@ "softmax", "to", "transfer_to_logical_device", + "transfer_to_logical_device_", "transpose", "unflatten", "unshard", @@ -210,7 +211,7 @@ def elementwise(operator, *args, **kwargs) -> AnyTensor: def _elementwise_trampoline(d: SignatureDispatcher, operator, *args, **kwargs): tensors = [] for a in args: - if isinstance(a, (Tensor, InferenceTensor)): + if isinstance(a, (Tensor, InferenceTensor, Number)): tensors.append(a) else: break @@ -994,6 +995,27 @@ def _transfer_to_logical_device_trampoline( d.fail(tensors) +@overridable +def transfer_to_logical_device_(tensor: AnyTensor, ordinal: int) -> None: + """In-place variant of transfer_to_logical_device. + Used to annotate function arguments. + """ + ... + + +@transfer_to_logical_device_.trampoline +def _transfer_to_logical_device__trampoline( + d: SignatureDispatcher, tensor: AnyTensor, ordinal: int +): + tensors = (tensor,) + for override in d.find_overrides(tensors): + result = override(tensor, ordinal) + if result is not NotImplemented: + return override, result + else: + d.fail(tensors) + + @overridable def transpose(tensor: AnyTensor, dim0: int, dim1: int) -> AnyTensor: """See torch.transpose""" diff --git a/sharktank/sharktank/types/tensors.py b/sharktank/sharktank/types/tensors.py index 324cc4331..44daa17c2 100644 --- a/sharktank/sharktank/types/tensors.py +++ b/sharktank/sharktank/types/tensors.py @@ -395,6 +395,11 @@ def __radd__(self, lhs): # numbers on the lhs. return self.__add__(lhs) + def __iadd__(self, rhs): + from ..ops import elementwise + + return elementwise(torch.add, self, rhs, out=self) + def __mod__(self, rhs): from ..ops import elementwise @@ -758,6 +763,7 @@ def __init__( ts: list[torch.Tensor], name: str = UnnamedTensorName, shape: Optional[list[int]], + insert_device_assignment: bool = True, ): from ..ops import transfer_to_logical_device @@ -767,7 +773,9 @@ def __init__( self._shards: tuple[DefaultPrimitiveTensor] = tuple( DefaultPrimitiveTensor( name=f"{name}.shard.{i}", - data=transfer_to_logical_device(t, i), + data=transfer_to_logical_device(t, i) + if insert_device_assignment + else unbox_tensor(t), ) for i, t in enumerate(ts) ) @@ -930,6 +938,7 @@ def __init__( shard_count: None | int = None, name: str = UnnamedTensorName, shape: Optional[list[int]] = None, + insert_device_assignment: bool = True, ): """ If `ts` is a list of tensors, it is interpreted as the shards. @@ -966,7 +975,13 @@ def __init__( s == t for i, (s, t) in enumerate(zip(shape, t_shape)) if i != shard_dim ), f"Shape mismatch for non-split dimension for tensor shard {i} with shape {t.shape}" - super().__init__(name=name, ts=ts, shape=shape, shard_dim=shard_dim) + super().__init__( + name=name, + ts=ts, + shape=shape, + shard_dim=shard_dim, + insert_device_assignment=insert_device_assignment, + ) def _is_slicing_split_dim(self, key): if isinstance( @@ -1309,6 +1324,7 @@ def flatten_with_keys_default_primitive_tensor(t: DefaultPrimitiveTensor): flatten_fn=flatten_default_primitive_tensor, unflatten_fn=unflatten_defult_primitive_tensor, flatten_with_keys_fn=flatten_with_keys_default_primitive_tensor, + serialized_type_name=f"{DefaultPrimitiveTensor.__module__}.{DefaultPrimitiveTensor.__name__}", ) @@ -1321,8 +1337,16 @@ def flatten_split_primitive_tensor( def unflatten_split_primitive_tensor( values: Iterable[Any], ctx: torch.utils._pytree.Context ) -> SplitPrimitiveTensor: + from ..ops import transfer_to_logical_device_ + + shards = list(values) + for i, tensor in enumerate(shards): + transfer_to_logical_device_(tensor, i) return SplitPrimitiveTensor( - shard_dim=ctx["shard_dim"], ts=list(values), name=ctx["name"] + shard_dim=ctx["shard_dim"], + ts=shards, + name=ctx["name"], + insert_device_assignment=False, ) @@ -1336,6 +1360,7 @@ def flatten_with_keys_split_primitive_tensor(t: SplitPrimitiveTensor): flatten_fn=flatten_split_primitive_tensor, unflatten_fn=unflatten_split_primitive_tensor, flatten_with_keys_fn=flatten_with_keys_split_primitive_tensor, + serialized_type_name=f"{SplitPrimitiveTensor.__module__}.{SplitPrimitiveTensor.__name__}", ) diff --git a/sharktank/sharktank/utils/iree.py b/sharktank/sharktank/utils/iree.py index 7c666ff62..8c933fe02 100644 --- a/sharktank/sharktank/utils/iree.py +++ b/sharktank/sharktank/utils/iree.py @@ -98,7 +98,7 @@ def run_iree_module_function( if trace_path_prefix is not None: for i, arg in enumerate(args): np.save( - f"{trace_path_prefix}{function_name}_arg_post_call{i}.npy", + f"{trace_path_prefix}{function_name}_arg{i}_post_call.npy", arg.to_host(), ) for i, arg in enumerate(results): @@ -187,3 +187,7 @@ def call_torch_module_function( result.to("cpu").numpy(), ) return res + + +def iree_to_torch(*tensors: iree.runtime.DeviceArray) -> List[torch.Tensor]: + return [torch.tensor(tensor.to_host()) for tensor in tensors] diff --git a/sharktank/tests/models/llama/sharded_llama_test.py b/sharktank/tests/models/llama/sharded_llama_test.py index 4d34dc704..c4f5b8f27 100644 --- a/sharktank/tests/models/llama/sharded_llama_test.py +++ b/sharktank/tests/models/llama/sharded_llama_test.py @@ -24,6 +24,7 @@ run_iree_module_function, prepare_iree_module_function_args, call_torch_module_function, + iree_to_torch, ) import tempfile import torch @@ -34,10 +35,6 @@ import os -def iree_to_torch(*tensors: iree.runtime.DeviceArray) -> List[torch.Tensor]: - return [torch.tensor(tensor.to_host()) for tensor in tensors] - - @pytest.mark.usefixtures("caching", "path_prefix") class ShardedLlamaTest(unittest.TestCase): def setUp(self): @@ -219,12 +216,12 @@ def testCompareToySizedModelToUnsharded(self): actual_decode_cache_state, expected_decode_cache_state, atol=1e-4, rtol=1e-4 ) - @unittest.skip( - ( - "Before this does not crash at all we need " - "https://github.com/iree-org/iree/pull/18663 merged." - ) - ) + # @unittest.skip( + # ( + # "Before this does not crash at all we need " + # "https://github.com/iree-org/iree/pull/18663 merged." + # ) + # ) def testExportAndRunToySizedModelWithIree(self): """Test exporting to MLIR and compiling with IREE the sharded Llama model. Test numerical accuracy of the IREE module against PyTorch.""" @@ -250,19 +247,49 @@ def runTestExportAndRunToySizedModelWithIree( sharded_dataset = Dataset.load(sharded_parameters_path, mmap=False) iree_driver = "local-task" - model = PagedLlamaModelV1(self.theta, self.config) + self.theta.rename_tensors_to_paths() + dataset = Dataset({}, self.theta) + parameters_path = f"{path_prefix}unsharded-parameters.irpa" + dataset.save(parameters_path) + dataset = Dataset.load(parameters_path, mmap=False) + model = PagedLlamaModelV1(dataset.root_theta, self.config) sharded_model = PagedLlamaModelV1( sharded_dataset.root_theta, self.sharded_config ) ( - _, + prefill_args, sharded_prefill_args, ) = self.make_equal_unsharded_and_sharded_prefill_args(model, sharded_model) ( - _, + decode_args, sharded_decode_args, ) = self.make_equal_unsharded_and_sharded_decode_args(model, sharded_model) + ################################################################ + fxb = FxProgramsBuilder(model) + + @fxb.export_program( + name="prefill", + args=tuple(), + kwargs=prefill_args, + strict=False, + ) + def _(model, *args, **kwargs) -> torch.Tensor: + return model.prefill(*args, **kwargs) + + @fxb.export_program( + name="decode", + args=tuple(), + kwargs=decode_args, + strict=False, + ) + def _(model, *args, **kwargs) -> torch.Tensor: + return model.decode(*args, **kwargs) + + output = export(fxb) + output.save_mlir(f"{path_prefix}program-unsharded.mlir") + ################################################################ + iree_module_path = f"{path_prefix}program.vmfb" if not self.caching or not os.path.exists(iree_module_path): # Export and compile the IREE module. diff --git a/sharktank/tests/ops/sharded_test.py b/sharktank/tests/ops/sharded_test.py index 0da22474e..c5965d5e6 100644 --- a/sharktank/tests/ops/sharded_test.py +++ b/sharktank/tests/ops/sharded_test.py @@ -786,9 +786,9 @@ def testReplicateUnsharded(self): expected_result = ReplicatedTensor(ts=tensor, shard_count=shard_count) assert expected_result.is_deep_equal(actual_result) - # Test not a copy. + # Test that is a copy. tensor[...] = torch.rand_like(tensor) - assert all(ops.equal(tensor, shard) for shard in actual_result.shards) + assert all(not ops.equal(tensor, shard) for shard in actual_result.shards) class ReshapeTest(unittest.TestCase): @@ -851,10 +851,10 @@ def testReshardUnsharded(self): ) assert expected_result.is_deep_equal(actual_result) - # Test not a copy. + # Test that is a copy. tensor[...] = torch.rand_like(tensor) result_split2 = ops.reshard_split(tensor, dim=shard_dim, count=shard_count) - assert ops.equal(actual_result, result_split2) + assert not ops.equal(actual_result, result_split2) def testReshardSharded(self): tensor = torch.rand(4, 5, 6, dtype=torch.float32) diff --git a/sharktank/tests/types/tensors_test.py b/sharktank/tests/types/tensors_test.py index 55f00cdaf..0c91f4377 100644 --- a/sharktank/tests/types/tensors_test.py +++ b/sharktank/tests/types/tensors_test.py @@ -9,9 +9,21 @@ import torch import tempfile import os +import pytest +from collections import OrderedDict from sharktank.types import * +from sharktank.utils.iree import ( + get_iree_devices, + load_iree_module, + run_iree_module_function, + prepare_iree_module_function_args, + call_torch_module_function, + iree_to_torch, +) from sharktank import ops +from copy import deepcopy +from iree.turbine.aot import FxProgramsBuilder, export def _createTestLayout(): @@ -60,6 +72,7 @@ def transform2(d): self.assertEqual(new_planes["d"].dtype, torch.float16) +@pytest.mark.usefixtures("path_prefix") class ShardedTensorTest(unittest.TestCase): def testReplicatedTensorSaveLoad(self): tensor = torch.rand([2, 3, 4], dtype=torch.float32) @@ -168,6 +181,128 @@ def testSplitTensorInsertSliceWithEllipsis(self): actual_result = ops.unshard(sharded_dst) assert ops.equal(actual_result, dst) + def testInPlaceUpdate(self): + if self.path_prefix is not None: + self.runTestInPlaceUpdate(path_prefix=self.path_prefix, dump_enabled=True) + else: + with tempfile.TemporaryDirectory() as temp_dir: + self.runTestInPlaceUpdate( + path_prefix=f"{temp_dir}/", dump_enabled=False + ) + + def runTestInPlaceUpdate(self, path_prefix: str, dump_enabled: bool): + shard_dim = 2 + shard_count = 2 + + class Module(torch.nn.Module): + def main(self, tensor: AnyTensor): + tensor += 1 + # TODO: figure out why when not returning anything fails the export + # fails. + return torch.empty([1]) + + shape = [2, 3, 4] + tensor = torch.rand(shape) + sharded_tensor = SplitPrimitiveTensor( + ts=tensor, + shard_dim=shard_dim, + shard_count=shard_count, + insert_device_assignment=False, + ) + + # Avoid aliasing with tensor. + # Torch exporting complains about mutating an aliased input. + # Doing + # sharded_tensor = deepcopy(sharded_tensor) + # is not enough. + shards = [ + torch.empty_like(unbox_tensor(shard)) for shard in sharded_tensor.shards + ] + for src_shard, dst_shard in zip(sharded_tensor.shards, shards): + dst_shard[...] = unbox_tensor(src_shard) + sharded_tensor = SplitPrimitiveTensor( + ts=shards, shard_dim=shard_dim, insert_device_assignment=False + ) + + sharded_tensor_snapshot = deepcopy(sharded_tensor) + module = Module() + module.main(sharded_tensor) + actual_result = ops.unshard(sharded_tensor) + expected_result = tensor + 1 + assert ops.equal(expected_result, actual_result) + + fxb = FxProgramsBuilder(module) + + @fxb.export_program( + args=(deepcopy(sharded_tensor),), + name="main", + strict=False, + ) + def _(model, *args, **kwargs) -> AnyTensor: + return model.main(*args, **kwargs) + + if dump_enabled: + for program_name, ep in fxb.programs.items(): + with open( + f"{path_prefix}{program_name}.torch.fx.txt", + "w", + ) as f: + print(str(ep), file=f) + + output = export(fxb) + if dump_enabled: + output.save_mlir(f"{path_prefix}program.mlir") + + iree_module_path = f"{path_prefix}program.vmfb" + output.session.set_flags( + *[f"--iree-hal-target-device=llvm-cpu[{i}]" for i in range(shard_count)] + ) + output.compile( + save_to=iree_module_path, + target_backends=None, + ) + + iree_driver = "local-task" + iree_devices = get_iree_devices( + driver=iree_driver, + device_count=shard_count, + ) + iree_module, vm_context, vm_instance = load_iree_module( + module_path=iree_module_path, + devices=iree_devices, + ) + iree_args = prepare_iree_module_function_args( + args=[deepcopy(sharded_tensor_snapshot)], devices=iree_devices + ) + run_iree_module_function( + args=iree_args, + function_name="main", + module=iree_module, + vm_context=vm_context, + driver=iree_driver, + trace_path_prefix=path_prefix if dump_enabled else None, + ) + iree_args_as_torch = iree_to_torch(*iree_args) + iree_args_sharded_tensor = SplitPrimitiveTensor( + ts=iree_args_as_torch, shard_dim=shard_dim, insert_device_assignment=False + ) + actual_iree_result = ops.unshard(iree_args_sharded_tensor) + if dump_enabled: + call_torch_module_function( + module=module, + function_name="main", + kwargs=OrderedDict( + [ + ( + "tensor", + deepcopy(sharded_tensor_snapshot), + ) + ] + ), + trace_path_prefix=f"{path_prefix}expected_", + ) + torch.testing.assert_close(actual_iree_result, expected_result) + if __name__ == "__main__": unittest.main()