From e017b56c0c5e167b36f03902e16220650849fb20 Mon Sep 17 00:00:00 2001
From: Xilun Wu <12968408+XilunWu@users.noreply.github.com>
Date: Fri, 31 May 2024 17:19:53 +0000
Subject: [PATCH] [dtensor] local_map UX change: keep func signature and be compatible with Tensor input (#126924)

**Summary**
This PR makes two changes to `local_map`:
1. It regulates how the user can access the `DeviceMesh` inside the `func` argument of `local_map`: `local_map` now strictly follows the `func` signature and never implicitly passes any argument to `func`. If the user wants to use a `DeviceMesh` inside `func`, the mesh must be explicitly passed to `func` as an argument by the user. For example,
```
def user_function(device_mesh, /, *args, **kwargs):
    USER CODE HERE

local_func = local_map(func=user_function, ...)
dtensor_out = local_func(device_mesh, dtensor_input, ...)
```
Before this PR, user code looked like:
```
def user_function(device_mesh, /, *args, **kwargs):
    USER CODE HERE

local_func = local_map(func=user_function, ...)
dtensor_out = local_func(dtensor_input, ...)  # local_map passes mesh implicitly for user
```
2. `local_map` now supports mixing `torch.Tensor` and `DTensor` arguments:
    - Pure `torch.Tensor` case: no `DTensor` argument is passed in; all tensor arguments are `torch.Tensor`. The `in_placements` check and the unwrapping step are bypassed, and the output is returned directly instead of being wrapped into `DTensor`.
    - Pure `DTensor` case: no `torch.Tensor` argument is passed in; all tensor arguments are `DTensor`. The default rule applies: check `in_placements`, unwrap the arguments, pass them into `func`, and wrap the `torch.Tensor` output into `DTensor` if `out_placements` is not `None`.
    - Mixed case: some arguments are `torch.Tensor` while others are `DTensor`. The `in_placements` check and the unwrapping are performed only on the `DTensor` arguments; output processing is the same as in the pure `DTensor` case. See the sketch right after this list for an example.
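A minimal sketch of the mixed case under the new calling convention. This is illustrative only: the function `scaled_mm`, the shapes, the CPU/gloo device type, the fixed seed, and the assumption that the world size evenly divides the sharded dimension are all made up for the example; the import path follows the module touched by this patch.
```
import torch
import torch.distributed as dist
from torch.distributed._tensor import distribute_tensor, init_device_mesh, Shard
from torch.distributed._tensor.experimental.local_map import local_map


def scaled_mm(A, B, scale):
    # plain local computation; no device_mesh parameter since no collective is used
    return torch.mm(A, B) * scale


# assumes torch.distributed is already initialized, e.g. launched via torchrun with gloo/cpu
world_size = dist.get_world_size()
device_mesh = init_device_mesh(device_type="cpu", mesh_shape=(world_size,))
row_wise = [Shard(0)]

torch.manual_seed(0)  # same seed so every rank constructs identical global tensors
A = torch.randn(8, 4)
B = torch.randn(4, 4)
A_dt = distribute_tensor(A, device_mesh, row_wise)  # DTensor argument
# B stays a plain torch.Tensor and `scale` is a non-tensor kwarg

scaled_mm_local = local_map(
    scaled_mm,
    out_placements=row_wise,         # wrap the torch.Tensor output back into a DTensor
    in_placements=(row_wise, None),  # placements are only checked for the DTensor arg
    device_mesh=device_mesh,
)
out_dt = scaled_mm_local(A_dt, B, scale=2.0)  # mixed DTensor / torch.Tensor / scalar inputs
assert torch.allclose(out_dt.full_tensor(), torch.mm(A, B) * 2.0)
```
Because `scaled_mm` performs no collective, it takes no `device_mesh` parameter; this matches the new rule that `local_map` never injects the mesh into `func`.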
**Test** `pytest test/distributed/_tensor/experimental/test_local_map.py` Pull Request resolved: https://github.com/pytorch/pytorch/pull/126924 Approved by: https://github.com/wanchaol --- .../_tensor/experimental/test_local_map.py | 123 +++++++++++++--- .../_tensor/experimental/local_map.py | 138 +++++++++++------- 2 files changed, 191 insertions(+), 70 deletions(-) diff --git a/test/distributed/_tensor/experimental/test_local_map.py b/test/distributed/_tensor/experimental/test_local_map.py index 1035df2f5f7d8..b483194d6c3a8 100644 --- a/test/distributed/_tensor/experimental/test_local_map.py +++ b/test/distributed/_tensor/experimental/test_local_map.py @@ -5,6 +5,7 @@ import torch.distributed._functional_collectives as funcol from torch.distributed._tensor import ( distribute_tensor, + DTensor, init_device_mesh, Replicate, Shard, @@ -18,23 +19,30 @@ ) -def equal_forward(device_mesh, X, Y): +funcol_py = torch.ops.c10d_functional + + +def equal_allgather_forward(device_mesh, X, Y): eq = torch.tensor([torch.equal(X, Y)], device=X.device) eq_gather = funcol.all_gather_tensor(eq, 0, device_mesh) return torch.all(eq_gather).item() -def mm_forward(device_mesh, W, X): - return torch.mm(W, X) +def mm_all_gather_forward(device_mesh, A, B): + local_mm_result = torch.mm(A, B) + return funcol.all_gather_tensor(local_mm_result, 0, device_mesh).wait() + +def mm_forward(A, B): # no device mesh needed since we don't do collective + return torch.mm(A, B) -def mm_allreduce_forward(device_mesh, W, X): - partial_sum_tensor = torch.mm(W, X) - reduced_tensor = funcol.all_reduce(partial_sum_tensor, "sum", device_mesh).wait() - return reduced_tensor +def mm_allreduce_forward(device_mesh, A, B): + partial_sum_tensor = torch.mm(A, B) + return funcol.all_reduce(partial_sum_tensor, "sum", device_mesh).wait() -def mul_forward(device_mesh, X, scalar): + +def mul_forward(X, scalar): # no device mesh needed since we don't do collective return torch.mul(X, scalar) @@ -58,6 +66,7 @@ def test_local_map_correctness(self): row_wise = [Shard(0)] # row-wise sharding placements on 1-d mesh col_wise = [Shard(1)] # col-wise sharding placements on 1-d mesh + replicate = [Replicate()] W_dt = distribute_tensor( W, device_mesh, col_wise ) # col-wisely sharded W tensor @@ -70,12 +79,12 @@ def test_local_map_correctness(self): # DTensors' `_local_tensor`. 
local_mm_allreduce_forward = local_map( mm_allreduce_forward, - out_placements=[Replicate()], - in_placements=(col_wise, row_wise), + out_placements=replicate, + in_placements=(None, col_wise, row_wise), device_mesh=device_mesh, ) with comm_mode: - Y_dt = local_mm_allreduce_forward(W_dt, X_dt) + Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt) # output redistribution to Replicate self.assertEqual(comm_mode.get_total_counts(), 1) @@ -88,6 +97,7 @@ def test_local_map_correctness(self): # check for `out_placements` @with_comms def test_local_map_out_placements(self): + # Test 1: wrap out into DTensor w/ `out_placements` device_mesh = init_device_mesh( device_type=self.device_type, mesh_shape=(self.world_size,) ) @@ -99,14 +109,40 @@ def test_local_map_out_placements(self): row_wise = [Shard(0)] X_dt = distribute_tensor(X, device_mesh, row_wise) Y_dt = distribute_tensor(Y, device_mesh, row_wise) - local_equal_forward = local_map(equal_forward, out_placements=None) + local_equal_allgather_forward = local_map( + equal_allgather_forward, + out_placements=None, + ) with comm_mode: - equal_dt = local_equal_forward(X_dt, Y_dt) # a bool + equal_dt = local_equal_allgather_forward(device_mesh, X_dt, Y_dt) # a bool self.assertEqual(comm_mode.get_total_counts(), 1) self.assertTrue(not equal_dt) self.assertTrue(not (X.equal(Y))) + # Test 2: directly return out if no argument is DTensor + # matmul in DDP + replicate = [Replicate()] + X = torch.randn( + 4 // self.world_size, 4, device=self.device_type, requires_grad=False + ) + W = torch.randn(4, 4, device=self.device_type, requires_grad=False) + local_mm_all_gather_forward = local_map( + mm_all_gather_forward, + out_placements=row_wise, + in_placements=(None, row_wise, replicate), + ) + with comm_mode: + Y = local_mm_all_gather_forward(device_mesh, X, W) + + self.assertEqual(comm_mode.get_total_counts(), 1) + self.assertEqual( + comm_mode.get_comm_counts()[funcol_py.all_gather_into_tensor], 1 + ) + X_replicate = funcol.all_gather_tensor(X, 0, device_mesh).wait() + Y_replicate = torch.mm(X_replicate, W) + self.assertEqual(Y, Y_replicate) # Y is a torch.Tensor + # check for `in_placements` handling @with_comms def test_local_map_in_placements(self): @@ -173,6 +209,54 @@ def test_local_map_in_placements(self): self.assertTrue(placement.is_shard(dim=0)) self.assertEqual(Y_dt.full_tensor(), Y) + # Test 4: `None` placements for Tensor input argument + X = torch.randn(16, 8, device=self.device_type, requires_grad=False) + W = torch.randn(8, 12, device=self.device_type, requires_grad=False) + X_dt = distribute_tensor( + X, device_mesh, row_wise + ) # row-wisely sharded X tensor + W_dt = distribute_tensor(W, device_mesh, replicate) # replicate W tensor + local_mm_forward = local_map( + mm_forward, + out_placements=None, + in_placements=(None, None), + device_mesh=device_mesh, + ) + with comm_mode: + Y_dt_local = local_mm_forward(X_dt.to_local(), W_dt.to_local()) + + self.assertEqual(comm_mode.get_total_counts(), 0) + self.assertEqual( + DTensor.from_local(Y_dt_local, device_mesh, row_wise).full_tensor(), + torch.mm(X, W), + ) + + # Test 5: Some placements for Tensor input argument + local_mm_forward = local_map( + mm_forward, + out_placements=None, + in_placements=(replicate, row_wise), + device_mesh=device_mesh, + ) + with comm_mode: + Y_dt_local = local_mm_forward(X_dt.to_local(), W_dt.to_local()) + + self.assertEqual(comm_mode.get_total_counts(), 0) + self.assertEqual( + DTensor.from_local(Y_dt_local, device_mesh, row_wise).full_tensor(), + torch.mm(X, 
W), + ) + + # Test 6: expect error - `None` placements for DTensor input argument + local_mm_forward = local_map( + mm_forward, + out_placements=row_wise, + in_placements=(row_wise, None), + device_mesh=device_mesh, + ) + with self.assertRaisesRegex(AssertionError, "expects placements"): + Y_dt = local_mm_forward(X_dt, W_dt) + # check for `redistribute_inputs` handling @with_comms def test_local_map_redistribute(self): @@ -188,6 +272,7 @@ def test_local_map_redistribute(self): row_wise = [Shard(0)] # row-wise sharding placements on 1-d mesh col_wise = [Shard(1)] # col-wise sharding placements on 1-d mesh + replicate = [Replicate()] W_dt = distribute_tensor( W, device_mesh, row_wise ) # row-wisely sharded W tensor which will be redistributed @@ -198,13 +283,13 @@ def test_local_map_redistribute(self): # Test 1: allow input redistribution local_mm_allreduce_forward = local_map( mm_allreduce_forward, - out_placements=[Replicate()], - in_placements=(col_wise, row_wise), + out_placements=replicate, + in_placements=(None, col_wise, row_wise), device_mesh=device_mesh, redistribute_inputs=True, ) with comm_mode: - Y_dt = local_mm_allreduce_forward(W_dt, X_dt) + Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt) # 2 for input redistribution and 1 for output self.assertEqual(comm_mode.get_total_counts(), 3) @@ -215,13 +300,13 @@ def test_local_map_redistribute(self): # Test 2: no input redistribution is allowed local_mm_allreduce_forward = local_map( mm_allreduce_forward, - out_placements=[Replicate()], - in_placements=(col_wise, row_wise), + out_placements=replicate, + in_placements=(None, col_wise, row_wise), device_mesh=device_mesh, redistribute_inputs=False, ) with self.assertRaisesRegex(ValueError, "set redistribute_inputs=True"): - Y_dt = local_mm_allreduce_forward(W_dt, X_dt) + Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt) if __name__ == "__main__": diff --git a/torch/distributed/_tensor/experimental/local_map.py b/torch/distributed/_tensor/experimental/local_map.py index 002ff5542a119..2bf12871cc367 100644 --- a/torch/distributed/_tensor/experimental/local_map.py +++ b/torch/distributed/_tensor/experimental/local_map.py @@ -2,6 +2,7 @@ from typing import Callable, Optional, Sequence, Tuple, Union import torch +from torch.distributed._functional_collectives import AsyncCollectiveTensor from torch.distributed._tensor import DeviceMesh, DTensor from torch.distributed._tensor.placement_types import Placement @@ -12,7 +13,7 @@ PlacementType = Optional[Sequence[Placement]] -InputPlacements = Union[PlacementType, Tuple[PlacementType, ...]] +InputPlacements = Optional[Tuple[PlacementType, ...]] OutputPlacements = Union[PlacementType, Tuple[PlacementType, ...]] @@ -32,24 +33,36 @@ def local_map( func (Callable): the function to be applied on each local shard of :class:`DTensor`s. out_placements (Union[`PlacementType`, Tuple[`PlacementType`, ...]]): - the desired placements of the output :class:`DTensor`s. If the `output` of - `func` is a Python collection, the `out_placements` will be a Tuple of - `PlacementType` values 1:1 mapping to the flattened `output`. For - :class:`Tensor` output, the corresponding `PlacementType` will be its + the desired placements of the :class:`DTensor`s in `func`'s flattened output. + If the flattened `output` is a single value, the `out_placements` should be + of type `PlacementType`. Otherwise if the flattened `output` has multiple + values, the `out_placements` should be a tuple of `PlacementType` values 1:1 + mapping to the flattened `output`. 
+ Besides, for :class:`Tensor` output, we use `PlacementType` as its placements (a `Tuple[Placement]` value). For non-:class:`Tensor` output, - the `PlacementType` will be `None`. - in_placements (Union[`PlacementType`, Tuple[`PlacementType`, ...]], optional): - the required placements of the input :class:`DTensor`s. If not specified, - the input :class:`DTensor` will not be redistributed before passing its local - tensor to `func`. Similarly to `out_placements`, `in_placements` should keep - a 1:1 mapping to the flattened input of `func`. If a redistribution is - required according to `in_placements` and `redistribute_inputs` is `False`, - an exception will be raised. + the `PlacementType` should be `None`. + Note that the only exception is when no :class:`DTensor` argument is passed + in. In this case, even if `out_placements` is not `None`, the result function + should ignore the desired placements because the application is not on + :class:`DTensors`. + in_placements (Tuple[`PlacementType`, ...], optional): + the required placements of the :class:`DTensor`s in `func`'s flattened input. + If `in_placements` is specified, `local_map` would examine whether the + placements of each :class:`DTensor` argument is the same as the required + placements or not. If the placements are not the same and + `redistribute_inputs` is `False`, an exception will be raised. Otherwise if + `redistribute_inputs` is `True`, the argument will be first redistributed to + the required sharding placements before passing its local tensor to `func`. + The only exception is when required placements are not `None` and the + argument is a :class:`torch.Tensor`. In this case, the placements examination + will be skipped and the argument will be directly passed to `func`. + If `in_placements` is `None`, no placements examination will be performed. + Default: `None` device_mesh (:class:`DeviceMesh`, optional): the device mesh that all the :class:`DTensor`s are placed on. If not specified, this will be inferred from the input :class:`DTensor`s' device mesh. `local_map` requires every :class:`DTensor`s to be placed on the same - device mesh. + device mesh. Default: `None`. redistribute_inputs (bool, optional): the bool value indicating whether to reshard the input :class:`DTensor`s when their placements are different from the required input placements. If this @@ -93,9 +106,9 @@ def local_map( >>> device_mesh=device_mesh, >>> ) >>> - >>> W_dt = distribute_tensor(W, device_mesh, col_wise) # col-wisely sharded W tensor - >>> X_dt = distribute_tensor(X, device_mesh, row_wise) # row-wisely sharded X tensor - >>> Y_dt = local_mm_allreduce_forward(W_dt, X_dt) # apply local_mm_allreduce_forward to DTensors + >>> W_dt = distribute_tensor(W, device_mesh, (col_wise)) # col-wisely sharded W tensor + >>> X_dt = distribute_tensor(X, device_mesh, (row_wise)) # row-wisely sharded X tensor + >>> Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt) # apply local_mm_allreduce_forward to DTensors NOTE: This API is currently experimental and subject to change """ @@ -103,10 +116,16 @@ def local_map( def wrapped(*args, **kwargs): # process input args flat_args, args_spec = pytree.tree_flatten(args) + if in_placements is not None: + assert len(in_placements) == len(flat_args), ( + f"in_placements length {len(in_placements)} does not match the number " + f"of input args {len(flat_args)}!" 
+ ) # we assume every DTensor object is placed on the same device mesh flat_local_args = [] nonlocal device_mesh # access var device_mesh from the outer scope + seen_dtensor_arg = False for idx, arg in enumerate(flat_args): if isinstance(arg, DTensor): # TODO: the current code doesn't consider the uneven sharding case @@ -115,17 +134,16 @@ def wrapped(*args, **kwargs): if device_mesh is None: # infer device mesh from the DTensor arg device_mesh = arg.device_mesh + # this function is applied to at least one DTensor argument + seen_dtensor_arg = True + assert arg.device_mesh == device_mesh, ( - f"arg {arg} in local_map has a mismatched device mesh:" - f"{arg} has device mesh {arg.device_mesh} while" + f"arg {arg} in local_map has a mismatched device mesh: " + f"{arg} has device mesh {arg.device_mesh} while " f"the expected device mesh is {device_mesh}!" ) if in_placements is not None: - spec = ( - in_placements[idx] - if isinstance(in_placements, tuple) - else in_placements - ) + spec = in_placements[idx] assert ( spec is not None ), f"DTensor input {arg} expects placements but received {spec}!" @@ -139,44 +157,62 @@ def wrapped(*args, **kwargs): arg = arg.redistribute(device_mesh, spec) else: raise ValueError( - f"arg {arg} in local_map has a mismatched placements:" - f"arg placements is {arg.placements} but the input" - f"placements is {spec}!" - "If redistribute_inputs is wanted, set redistribute_inputs=True to local_map." + f"arg {arg} in local_map has a mismatched placements: " + f"arg placements is {arg.placements} but the input " + f"placements is {spec}! " + "If redistribute_inputs is wanted, set " + "redistribute_inputs=True to local_map." ) - flat_local_args.append(arg.to_local()) + local_arg = arg.to_local() + if isinstance(local_arg, AsyncCollectiveTensor): + local_arg = local_arg.wait() + + flat_local_args.append(local_arg) else: + # Non-Tensor input must have None in `in_placements` + if in_placements is not None and not isinstance(arg, torch.Tensor): + spec = in_placements[idx] + assert spec is None, ( + f"Non-Tensor input {arg} expects None placements " + f"but received {spec}!" + ) + flat_local_args.append(arg) local_args = pytree.tree_unflatten(flat_local_args, args_spec) - out = func(device_mesh, *local_args, **kwargs) + out = func(*local_args, **kwargs) - # process output - flat_out, out_spec = pytree.tree_flatten(out) - flat_dist_out = [] - for idx, out in enumerate(flat_out): - spec = ( - out_placements[idx] - if isinstance(out_placements, tuple) - else out_placements - ) - if isinstance(out, torch.Tensor): - assert not isinstance( - out, DTensor - ), f"torch.Tensor output expected but received {type(out)}: {out}" + if seen_dtensor_arg: + # process output + flat_out, out_spec = pytree.tree_flatten(out) - flat_dist_out.append( - DTensor.from_local(out, device_mesh, spec, run_check=False) + flat_dist_out = [] + for idx, out in enumerate(flat_out): + spec = ( + out_placements[idx] + if isinstance(out_placements, tuple) + else out_placements ) - else: - assert ( - spec is None - ), f"Non-tensor output {out} expects None placements but received {spec}!" - flat_dist_out.append(out) + if isinstance(out, torch.Tensor): + assert not isinstance( + out, DTensor + ), f"torch.Tensor output expected but received {type(out)}: {out}" + + flat_dist_out.append( + DTensor.from_local(out, device_mesh, spec, run_check=False) + ) + else: + assert ( + spec is None + ), f"Non-tensor output {out} expects None placements but received {spec}!" 
+ + flat_dist_out.append(out) - return pytree.tree_unflatten(flat_dist_out, out_spec) + return pytree.tree_unflatten(flat_dist_out, out_spec) + else: + return out return wrapped
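As a usage note for the pure `torch.Tensor` path that the docstring and the `seen_dtensor_arg` branch above describe: when no argument is a `DTensor`, the wrapped function skips all placement handling and returns `func`'s output unwrapped. A hedged sketch follows; the helper name mirrors the test file, the shapes are illustrative, and the import path is the module touched by this patch.
```
import torch
from torch.distributed._tensor.experimental.local_map import local_map


def mm_forward(A, B):  # mirrors the test helper: no device mesh, no collective
    return torch.mm(A, B)


local_mm = local_map(
    mm_forward,
    out_placements=None,         # no DTensor inputs, so the output stays a plain torch.Tensor
    in_placements=(None, None),  # one entry per flattened positional argument
)

A = torch.randn(4, 3)
B = torch.randn(3, 5)
out = local_mm(A, B)             # plain torch.Tensor in, plain torch.Tensor out
assert isinstance(out, torch.Tensor) and out.shape == (4, 5)
```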