[dtensor] local_map UX change: keep func signature and be compatible with Tensor input (pytorch#126924)

**Summary**
This PR makes two changes to `local_map`:

1. It regulates how the user can access `DeviceMesh` inside the `func` argument of `local_map`: `local_map` now strictly follows the `func` signature and never implicitly passes any argument to `func`. If the user wants to use a `DeviceMesh` inside `func`, the mesh must be explicitly passed to `func` as an argument. For example,

```
def user_function(device_mesh, /, *args, **kwargs):
    ...  # user code here; may use device_mesh

local_func = local_map(func=user_function, ...)
dtensor_out = local_func(device_mesh, dtensor_input, ...)
```

Before this PR, user code looked like this:
```
def user_function(device_mesh, /, *args, **kwargs):
    ...  # user code here; local_map injected device_mesh implicitly

local_func = local_map(func=user_function, ...)
dtensor_out = local_func(dtensor_input, ...)  # local_map implicitly passed the mesh for the user
```
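
For a concrete, runnable version of the new calling convention, here is a minimal sketch built around the `mm_allreduce_forward` helper from the updated tests below. It assumes a 1-d mesh over 4 GPUs, an already-initialized process group, and that `local_map` is importable from `torch.distributed._tensor.experimental`; the tensor shapes are made up for illustration:

```
import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed._tensor import (
    distribute_tensor,
    init_device_mesh,
    Replicate,
    Shard,
)
from torch.distributed._tensor.experimental import local_map


def mm_allreduce_forward(device_mesh, A, B):
    # compute a partial matmul on the local shards, then all-reduce across the mesh
    partial_sum_tensor = torch.mm(A, B)
    return funcol.all_reduce(partial_sum_tensor, "sum", device_mesh).wait()


device_mesh = init_device_mesh("cuda", mesh_shape=(4,))  # hypothetical 4-GPU mesh
torch.manual_seed(0)  # same full tensors on every rank
W_dt = distribute_tensor(torch.randn(8, 8), device_mesh, [Shard(1)])  # col-wise
X_dt = distribute_tensor(torch.randn(8, 8), device_mesh, [Shard(0)])  # row-wise

local_mm_allreduce_forward = local_map(
    mm_allreduce_forward,
    out_placements=[Replicate()],
    in_placements=(None, [Shard(1)], [Shard(0)]),  # None for the mesh argument
    device_mesh=device_mesh,
)
Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt)  # mesh passed explicitly
```

Note the `None` entry in `in_placements`, which marks the mesh argument as a non-tensor input to be passed through untouched.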

2. `local_map` now supports mixing `torch.Tensor` and `DTensor` arguments:

- Pure `torch.Tensor` case: no `DTensor` argument is passed in; all tensor arguments are plain `torch.Tensor`s. The `in_placements` check and the unwrapping step are bypassed, and the output is returned directly instead of being wrapped into `DTensor`.
- Pure `DTensor` case: no `torch.Tensor` argument is passed in; all tensor arguments are `DTensor`s. This follows the default flow: check `in_placements`, unwrap the arguments, pass them into `func`, and wrap the `torch.Tensor` output into `DTensor` if `out_placements` is not `None`.
- Mixed case: some arguments are `torch.Tensor` while others are `DTensor`. The `in_placements` check and unwrapping are performed only on the `DTensor` arguments. Output processing is the same as in the pure `DTensor` case; see the sketch after this list.
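
To illustrate the mixed case, here is a minimal sketch of a row-parallel matmul where the left operand is a `DTensor` and the right operand is a plain `torch.Tensor`. The `row_parallel_mm` helper, the shapes, and the 4-GPU mesh are hypothetical, with the same import assumptions as the sketch above:

```
import torch
from torch.distributed._tensor import distribute_tensor, init_device_mesh, Shard
from torch.distributed._tensor.experimental import local_map


def row_parallel_mm(W, X):  # plain-tensor compute, no collectives, no mesh argument
    return torch.mm(W, X)


device_mesh = init_device_mesh("cuda", mesh_shape=(4,))  # hypothetical 4-GPU mesh
torch.manual_seed(0)  # same full tensors on every rank
row_wise = [Shard(0)]
W_dt = distribute_tensor(torch.randn(8, 8), device_mesh, row_wise)  # DTensor
X = torch.randn(8, 8)  # plain torch.Tensor, skips the placement check

local_row_parallel_mm = local_map(
    row_parallel_mm,
    out_placements=row_wise,  # not None, so the output is wrapped into DTensor
    in_placements=(row_wise, None),  # checked for W_dt only; None for the Tensor
    device_mesh=device_mesh,
)
Y_dt = local_row_parallel_mm(W_dt, X)  # mixed DTensor + torch.Tensor inputs
```

Only `W_dt` goes through the `in_placements` check and unwrapping; since `out_placements` is not `None`, the local result is wrapped back into a `DTensor`.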

**Test**
`pytest test/distributed/_tensor/experimental/test_local_map.py`

Pull Request resolved: pytorch#126924
Approved by: https://github.com/wanchaol
XilunWu authored and pytorchmergebot committed Jun 3, 2024
1 parent 2d1ad0c commit e017b56
Showing 2 changed files with 191 additions and 70 deletions.
123 changes: 104 additions & 19 deletions test/distributed/_tensor/experimental/test_local_map.py
```diff
@@ -5,6 +5,7 @@
 import torch.distributed._functional_collectives as funcol
 from torch.distributed._tensor import (
     distribute_tensor,
+    DTensor,
     init_device_mesh,
     Replicate,
     Shard,
@@ -18,23 +19,30 @@
 )
 
 
-def equal_forward(device_mesh, X, Y):
+funcol_py = torch.ops.c10d_functional
+
+
+def equal_allgather_forward(device_mesh, X, Y):
     eq = torch.tensor([torch.equal(X, Y)], device=X.device)
     eq_gather = funcol.all_gather_tensor(eq, 0, device_mesh)
     return torch.all(eq_gather).item()
 
 
-def mm_forward(device_mesh, W, X):
-    return torch.mm(W, X)
+def mm_all_gather_forward(device_mesh, A, B):
+    local_mm_result = torch.mm(A, B)
+    return funcol.all_gather_tensor(local_mm_result, 0, device_mesh).wait()
+
+
+def mm_forward(A, B):  # no device mesh needed since we don't do collective
+    return torch.mm(A, B)
 
 
-def mm_allreduce_forward(device_mesh, W, X):
-    partial_sum_tensor = torch.mm(W, X)
-    reduced_tensor = funcol.all_reduce(partial_sum_tensor, "sum", device_mesh).wait()
-    return reduced_tensor
+def mm_allreduce_forward(device_mesh, A, B):
+    partial_sum_tensor = torch.mm(A, B)
+    return funcol.all_reduce(partial_sum_tensor, "sum", device_mesh).wait()
 
 
-def mul_forward(device_mesh, X, scalar):
+def mul_forward(X, scalar):  # no device mesh needed since we don't do collective
     return torch.mul(X, scalar)
 
@@ -58,6 +66,7 @@ def test_local_map_correctness(self):
 
         row_wise = [Shard(0)]  # row-wise sharding placements on 1-d mesh
         col_wise = [Shard(1)]  # col-wise sharding placements on 1-d mesh
+        replicate = [Replicate()]
         W_dt = distribute_tensor(
             W, device_mesh, col_wise
         )  # col-wisely sharded W tensor
@@ -70,12 +79,12 @@
         # DTensors' `_local_tensor`.
         local_mm_allreduce_forward = local_map(
             mm_allreduce_forward,
-            out_placements=[Replicate()],
-            in_placements=(col_wise, row_wise),
+            out_placements=replicate,
+            in_placements=(None, col_wise, row_wise),
             device_mesh=device_mesh,
         )
         with comm_mode:
-            Y_dt = local_mm_allreduce_forward(W_dt, X_dt)
+            Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt)
 
         # output redistribution to Replicate
         self.assertEqual(comm_mode.get_total_counts(), 1)
@@ -88,6 +97,7 @@
     # check for `out_placements`
     @with_comms
     def test_local_map_out_placements(self):
+        # Test 1: wrap out into DTensor w/ `out_placements`
         device_mesh = init_device_mesh(
             device_type=self.device_type, mesh_shape=(self.world_size,)
         )
@@ -99,14 +109,40 @@ def test_local_map_out_placements(self):
         row_wise = [Shard(0)]
         X_dt = distribute_tensor(X, device_mesh, row_wise)
         Y_dt = distribute_tensor(Y, device_mesh, row_wise)
-        local_equal_forward = local_map(equal_forward, out_placements=None)
+        local_equal_allgather_forward = local_map(
+            equal_allgather_forward,
+            out_placements=None,
+        )
         with comm_mode:
-            equal_dt = local_equal_forward(X_dt, Y_dt)  # a bool
+            equal_dt = local_equal_allgather_forward(device_mesh, X_dt, Y_dt)  # a bool
 
         self.assertEqual(comm_mode.get_total_counts(), 1)
         self.assertTrue(not equal_dt)
         self.assertTrue(not (X.equal(Y)))
 
+        # Test 2: directly return out if no argument is DTensor
+        # matmul in DDP
+        replicate = [Replicate()]
+        X = torch.randn(
+            4 // self.world_size, 4, device=self.device_type, requires_grad=False
+        )
+        W = torch.randn(4, 4, device=self.device_type, requires_grad=False)
+        local_mm_all_gather_forward = local_map(
+            mm_all_gather_forward,
+            out_placements=row_wise,
+            in_placements=(None, row_wise, replicate),
+        )
+        with comm_mode:
+            Y = local_mm_all_gather_forward(device_mesh, X, W)
+
+        self.assertEqual(comm_mode.get_total_counts(), 1)
+        self.assertEqual(
+            comm_mode.get_comm_counts()[funcol_py.all_gather_into_tensor], 1
+        )
+        X_replicate = funcol.all_gather_tensor(X, 0, device_mesh).wait()
+        Y_replicate = torch.mm(X_replicate, W)
+        self.assertEqual(Y, Y_replicate)  # Y is a torch.Tensor
+
     # check for `in_placements` handling
     @with_comms
     def test_local_map_in_placements(self):
@@ -173,6 +209,54 @@ def test_local_map_in_placements(self):
             self.assertTrue(placement.is_shard(dim=0))
         self.assertEqual(Y_dt.full_tensor(), Y)
 
+        # Test 4: `None` placements for Tensor input argument
+        X = torch.randn(16, 8, device=self.device_type, requires_grad=False)
+        W = torch.randn(8, 12, device=self.device_type, requires_grad=False)
+        X_dt = distribute_tensor(
+            X, device_mesh, row_wise
+        )  # row-wisely sharded X tensor
+        W_dt = distribute_tensor(W, device_mesh, replicate)  # replicate W tensor
+        local_mm_forward = local_map(
+            mm_forward,
+            out_placements=None,
+            in_placements=(None, None),
+            device_mesh=device_mesh,
+        )
+        with comm_mode:
+            Y_dt_local = local_mm_forward(X_dt.to_local(), W_dt.to_local())
+
+        self.assertEqual(comm_mode.get_total_counts(), 0)
+        self.assertEqual(
+            DTensor.from_local(Y_dt_local, device_mesh, row_wise).full_tensor(),
+            torch.mm(X, W),
+        )
+
+        # Test 5: Some placements for Tensor input argument
+        local_mm_forward = local_map(
+            mm_forward,
+            out_placements=None,
+            in_placements=(replicate, row_wise),
+            device_mesh=device_mesh,
+        )
+        with comm_mode:
+            Y_dt_local = local_mm_forward(X_dt.to_local(), W_dt.to_local())
+
+        self.assertEqual(comm_mode.get_total_counts(), 0)
+        self.assertEqual(
+            DTensor.from_local(Y_dt_local, device_mesh, row_wise).full_tensor(),
+            torch.mm(X, W),
+        )
+
+        # Test 6: expect error - `None` placements for DTensor input argument
+        local_mm_forward = local_map(
+            mm_forward,
+            out_placements=row_wise,
+            in_placements=(row_wise, None),
+            device_mesh=device_mesh,
+        )
+        with self.assertRaisesRegex(AssertionError, "expects placements"):
+            Y_dt = local_mm_forward(X_dt, W_dt)
+
     # check for `redistribute_inputs` handling
     @with_comms
     def test_local_map_redistribute(self):
@@ -188,6 +272,7 @@ def test_local_map_redistribute(self):
 
         row_wise = [Shard(0)]  # row-wise sharding placements on 1-d mesh
         col_wise = [Shard(1)]  # col-wise sharding placements on 1-d mesh
+        replicate = [Replicate()]
         W_dt = distribute_tensor(
             W, device_mesh, row_wise
         )  # row-wisely sharded W tensor which will be redistributed
@@ -198,13 +283,13 @@
         # Test 1: allow input redistribution
         local_mm_allreduce_forward = local_map(
             mm_allreduce_forward,
-            out_placements=[Replicate()],
-            in_placements=(col_wise, row_wise),
+            out_placements=replicate,
+            in_placements=(None, col_wise, row_wise),
             device_mesh=device_mesh,
             redistribute_inputs=True,
         )
         with comm_mode:
-            Y_dt = local_mm_allreduce_forward(W_dt, X_dt)
+            Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt)
 
         # 2 for input redistribution and 1 for output
         self.assertEqual(comm_mode.get_total_counts(), 3)
@@ -215,13 +300,13 @@
         # Test 2: no input redistribution is allowed
         local_mm_allreduce_forward = local_map(
             mm_allreduce_forward,
-            out_placements=[Replicate()],
-            in_placements=(col_wise, row_wise),
+            out_placements=replicate,
+            in_placements=(None, col_wise, row_wise),
             device_mesh=device_mesh,
             redistribute_inputs=False,
         )
         with self.assertRaisesRegex(ValueError, "set redistribute_inputs=True"):
-            Y_dt = local_mm_allreduce_forward(W_dt, X_dt)
+            Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt)
 
 
 if __name__ == "__main__":
```