forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[dtensor] local_map UX change: keep func signature and be compatible …
…with Tensor input (pytorch#126924) **Summary** This PR has 2 parts of change in `local_map`: 1. regulates the way user can access `DeviceMesh` inside the `func` argument of `local_map`. This means `local_map` will strictly follow the `func` signature without implicitly passing any argument to `func`. If user wants to use `DeviceMesh` inside `func`, this mesh must be explicitly passed to `func` as an argument by user. For example, ``` def user_function(device_mesh, /, *args, **kwargs): USER CODE HERE local_func = local_map(func=user_function, ...) dtensor_out = local_func(device_mesh, dtensor_input, ...) ``` Before this PR, user code was like: ``` def user_function(device_mesh, /, *args, **kwargs): USER CODE HERE local_func = local_map(func=user_function, ...) dtensor_out = local_func(dtensor_input, ...) # local_map passes mesh implicitly for user ``` 2. `local_map` now supports mix use of `torch.Tensor` and `DTensor` in argument: - Pure torch.Tensor case: no `DTensor` argument is passed in, all tensor arguments are `torch.Tensor`. Bypass the `in_placements` check and unwrapping steps. The output will not be wrapped into `DTensor` but directly returned. - Pure DTensor case: no `torch.Tensor` argument is passed in, all tensor arguments are `DTensor`. This follows the default rule: `in_placements` check, unwrapping arguments, pass into `func`, wrapping the `torch.Tensor` output into `DTensor` if the `out_placements` is not `None`. - Mix of the above two: some arguments are `torch.Tensor` while some are `DTensor`. Only perform `in_placements` check and unwrapping on `DTensor` arguments. For output processing, it's the same as Pure DTensor case. **Test** `pytest test/distributed/_tensor/experimental/test_local_map.py` Pull Request resolved: pytorch#126924 Approved by: https://github.com/wanchaol
- Loading branch information
1 parent
2d1ad0c
commit e017b56
Showing
2 changed files
with
191 additions
and
70 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.