Add fairscale.nn.misc.checkpoint_activations (#376)
* Add fairscale.utils.containers
* Add fairscale.nn.misc.checkpoint_activations

Co-authored-by: Min Xu <[email protected]>
Co-authored-by: Sam Shleifer <[email protected]>
1 parent e92e85c · commit c963a72

Showing 5 changed files with 484 additions and 1 deletion.
fairscale/nn/misc/checkpoint_activations.py (new file, +158 lines)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import functools
from typing import Any, Dict, Optional, Tuple

import torch
from torch import Tensor
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

from fairscale.utils.containers import pack_kwargs, split_non_tensors, unpack_kwargs, unpack_non_tensors


def checkpoint_wrapper(module: nn.Module, offload_to_cpu: bool = False) -> nn.Module:
    """
    A friendlier wrapper for performing activation checkpointing.

    Compared to the PyTorch version, this version:
    - wraps an nn.Module, so that all subsequent calls will use checkpointing
    - handles keyword arguments in the forward
    - handles non-Tensor outputs from the forward
    - supports offloading activations to CPU

    Usage::

        checkpointed_module = checkpoint_wrapper(my_module, offload_to_cpu=True)
        a, b = checkpointed_module(x, y=3, z=torch.Tensor([1]))

    Args:
        module (nn.Module): module to wrap
        offload_to_cpu (Optional, bool): whether to offload activations to CPU
    """
    module.forward = functools.partial(_checkpointed_forward, module.forward, offload_to_cpu)  # type: ignore
    return module


def _checkpointed_forward(original_forward: Any, offload_to_cpu: bool, *args: Any, **kwargs: Any) -> Any:
    # Autograd Functions in PyTorch work best with positional args, since
    # the backward must return gradients (or None) for every input argument.
    # We can flatten keyword arguments to make this easier.
    kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)
    parent_ctx_dict: Dict[str, Any] = {"offload": offload_to_cpu}
    output = CheckpointFunction.apply(original_forward, parent_ctx_dict, kwarg_keys, *flat_args)
    if isinstance(output, torch.Tensor):
        return output
    else:
        packed_non_tensor_outputs = parent_ctx_dict["packed_non_tensor_outputs"]
        if packed_non_tensor_outputs:
            output = unpack_non_tensors(output, packed_non_tensor_outputs)
        return output


def get_rng_state() -> Dict[str, Any]:
    state = {"torch_rng_state": torch.get_rng_state()}
    if torch.cuda.is_available():
        state["cuda_rng_state"] = torch.cuda.get_rng_state()
    return state


def set_rng_state(state: Dict[str, Any]) -> None:
    torch.set_rng_state(state["torch_rng_state"])
    if torch.cuda.is_available():
        torch.cuda.set_rng_state(state["cuda_rng_state"])
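# Illustrative sketch, not part of the commit: the two helpers above let
# CheckpointFunction replay the forward pass with identical RNG draws, so
# stochastic layers like Dropout produce the same mask during recomputation.
# The save/rewind pattern they enable looks like this:
#
#     state = get_rng_state()                    # capture CPU (and CUDA) RNG state
#     a = nn.functional.dropout(torch.ones(8), p=0.5, training=True)
#     set_rng_state(state)                       # rewind the RNG
#     b = nn.functional.dropout(torch.ones(8), p=0.5, training=True)
#     assert torch.equal(a, b)                   # identical dropout mask both times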
class CheckpointFunction(torch.autograd.Function):
    """Similar to the torch version, but supports non-Tensor outputs.

    The caller is expected to provide a dict (*parent_ctx_dict*) that will hold
    the non-Tensor outputs. These should be combined with the Tensor *outputs*
    by calling :func:`unpack_non_tensors`.
    """

    @staticmethod
    def forward(  # type: ignore
        ctx: Any,
        run_function: Any,
        parent_ctx_dict: Dict[str, Any],
        kwarg_keys: Tuple[str, ...],
        *args: Any,
        **kwargs: Any
    ) -> Any:
        if torch.is_grad_enabled():  # grad may be disabled, e.g., during validation
            checkpoint.check_backward_validity(args)

        ctx.run_function = run_function
        ctx.kwarg_keys = kwarg_keys
        ctx.fwd_rng_state = get_rng_state()

        tensor_inputs, packed_non_tensor_inputs = split_non_tensors(args)
        if parent_ctx_dict["offload"]:
            ctx.fwd_device = tuple(x.device for x in tensor_inputs)
            ctx.grad_requirements = tuple(x.requires_grad for x in tensor_inputs)
            tensor_inputs = tuple(x.cpu() for x in tensor_inputs)
        else:
            ctx.fwd_device, ctx.grad_requirements = None, None

        ctx.save_for_backward(*tensor_inputs)
        ctx.packed_non_tensor_inputs = packed_non_tensor_inputs

        with torch.no_grad():
            unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args)
            outputs = run_function(*unpacked_args, **unpacked_kwargs)

        if isinstance(outputs, torch.Tensor):
            return outputs
        else:
            # Autograd Functions don't like non-Tensor outputs. We can split the
            # non-Tensor and Tensor outputs, returning the former by reference
            # through *parent_ctx_dict* and returning the latter directly.
            outputs, packed_non_tensor_outputs = split_non_tensors(outputs)
            parent_ctx_dict["packed_non_tensor_outputs"] = packed_non_tensor_outputs
            return outputs

    @staticmethod
    def backward(ctx: Any, *args: Any) -> Tuple[Optional[Tensor], ...]:
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")

        tensor_inputs: Tuple = ctx.saved_tensors
        tensor_inputs = checkpoint.detach_variable(tensor_inputs)
        if ctx.fwd_device is not None:
            tensor_inputs = tuple(t.to(ctx.fwd_device[i]) for i, t in enumerate(tensor_inputs))
            for i, need_grad in enumerate(ctx.grad_requirements):
                tensor_inputs[i].requires_grad = need_grad
        inputs = unpack_non_tensors(tensor_inputs, ctx.packed_non_tensor_inputs)

        # Store the current states.
        bwd_rng_state = get_rng_state()

        # Set the states to what they were before the forward pass.
        set_rng_state(ctx.fwd_rng_state)

        with torch.enable_grad():
            unpacked_args, unpacked_kwargs = unpack_kwargs(ctx.kwarg_keys, inputs)
            outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs)
            tensor_outputs, _ = split_non_tensors(outputs)

        # Set the states back to what they were at the start of this function.
        set_rng_state(bwd_rng_state)

        # Run backward() with only Tensors that require grad
        outputs_with_grad = []
        args_with_grad = []
        for i in range(len(tensor_outputs)):
            if tensor_outputs[i].requires_grad:
                outputs_with_grad.append(tensor_outputs[i])
                args_with_grad.append(args[i])
        if len(outputs_with_grad) == 0:
            raise RuntimeError("None of the outputs have requires_grad=True, this checkpoint() is not necessary")

        torch.autograd.backward(outputs_with_grad, args_with_grad)

        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in inputs)
        return (None, None, None) + grads
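To make the wrapper's behavior concrete, here is a minimal usage sketch. The `TwoOutputs` module and every name in it are hypothetical, written for illustration rather than taken from the commit; it exercises the three features the docstring advertises (wrapped module, keyword argument, non-Tensor output).

```python
import torch
import torch.nn as nn

from fairscale.nn.misc.checkpoint_activations import checkpoint_wrapper


class TwoOutputs(nn.Module):
    """Toy module with a keyword argument and a non-Tensor output."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x, scale=1.0):
        h = torch.relu(self.linear(x)) * scale
        return h, "aux-info"  # the second output is not a Tensor


wrapped = checkpoint_wrapper(TwoOutputs(), offload_to_cpu=False)
x = torch.randn(4, 8, requires_grad=True)
out, aux = wrapped(x, scale=2.0)  # the kwarg and the str output both round-trip
out.sum().backward()              # activations are recomputed here
print(aux, x.grad.shape)
```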
fairscale/utils/containers.py (new file, +119 lines)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

import torch

"""Useful functions for dealing with tensors nested in other Python container types."""


def apply_to_tensors(fn: Callable, container: Union[torch.Tensor, Dict, List, Tuple, Set]) -> Any:
    """Recursively apply fn to all tensors in the four supported container types (dict, list, tuple, set)."""

    def _apply(x: Union[torch.Tensor, Dict, List, Tuple, Set]) -> Any:
        if torch.is_tensor(x):
            return fn(x)
        elif isinstance(x, dict):
            return {key: _apply(value) for key, value in x.items()}
        elif isinstance(x, list):
            return [_apply(x) for x in x]
        elif isinstance(x, tuple):
            return tuple(_apply(x) for x in x)
        elif isinstance(x, set):
            return {_apply(x) for x in x}
        else:
            return x

    return _apply(container)
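# Illustrative sketch, not part of the commit: apply_to_tensors transforms
# every tensor in a nested structure while passing other objects through
# untouched, e.g. moving a mixed container to CPU:
#
#     nested = {"a": [torch.ones(1), 3], "b": (torch.zeros(2), "tag")}
#     on_cpu = apply_to_tensors(lambda t: t.cpu(), nested)
#     # both tensors are now on CPU; the int 3 and the str "tag" are unchanged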
def pack_kwargs(*args: Any, **kwargs: Any) -> Tuple[Tuple[str, ...], Tuple[Any, ...]]:
    """
    Turn an argument list into a separate key list and value list (unpack_kwargs does the opposite).

    Usage::

        kwarg_keys, flat_args = pack_kwargs(1, 2, a=3, b=4)
        assert kwarg_keys == ("a", "b")
        assert flat_args == (1, 2, 3, 4)
        args, kwargs = unpack_kwargs(kwarg_keys, flat_args)
        assert args == (1, 2)
        assert kwargs == {"a": 3, "b": 4}
    """
    kwarg_keys: List[str] = []
    flat_args: List[Any] = list(args)
    for k, v in kwargs.items():
        kwarg_keys.append(k)
        flat_args.append(v)
    return tuple(kwarg_keys), tuple(flat_args)


def unpack_kwargs(kwarg_keys: Tuple[str, ...], flat_args: Tuple[Any, ...]) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """See pack_kwargs."""
    assert len(kwarg_keys) <= len(flat_args), f"too many keys {len(kwarg_keys)} vs. {len(flat_args)}"
    if len(kwarg_keys) == 0:
        return flat_args, {}
    args = flat_args[: -len(kwarg_keys)]
    kwargs = {k: v for k, v in zip(kwarg_keys, flat_args[-len(kwarg_keys):])}
    return args, kwargs


def split_non_tensors(
    mixed: Union[torch.Tensor, Tuple[Any, ...]]
) -> Tuple[Tuple[torch.Tensor, ...], Optional[Dict[str, List[Any]]]]:
    """
    Split a tuple into a tuple of tensors and the remaining objects, along
    with the information needed to reconstruct the original tuple later.

    Usage::

        x = torch.Tensor([1])
        y = torch.Tensor([2])
        tensors, packed_non_tensors = split_non_tensors((x, y, None, 3))
        assert tensors == (x, y)
        assert packed_non_tensors == {
            "is_tensor": [True, True, False, False],
            "objects": [None, 3],
        }
        recon = unpack_non_tensors(tensors, packed_non_tensors)
        assert recon == (x, y, None, 3)
    """
    if isinstance(mixed, torch.Tensor):
        return (mixed,), None
    tensors: List[torch.Tensor] = []
    packed_non_tensors: Dict[str, List[Any]] = {"is_tensor": [], "objects": []}
    for o in mixed:
        if isinstance(o, torch.Tensor):
            packed_non_tensors["is_tensor"].append(True)
            tensors.append(o)
        else:
            packed_non_tensors["is_tensor"].append(False)
            packed_non_tensors["objects"].append(o)
    return tuple(tensors), packed_non_tensors


def unpack_non_tensors(
    tensors: Tuple[torch.Tensor, ...], packed_non_tensors: Optional[Dict[str, List[Any]]]
) -> Tuple[Any, ...]:
    """See split_non_tensors."""
    if packed_non_tensors is None:
        return tensors
    assert isinstance(packed_non_tensors, dict), type(packed_non_tensors)
    mixed: List[Any] = []
    is_tensor_list = packed_non_tensors["is_tensor"]
    objects = packed_non_tensors["objects"]
    assert len(tensors) + len(objects) == len(is_tensor_list), (
        f"len(tensors) {len(tensors)} len(objects) {len(objects)} len(is_tensor_list) {len(is_tensor_list)}"
    )
    obj_i = tnsr_i = 0
    for is_tensor in is_tensor_list:
        if is_tensor:
            mixed.append(tensors[tnsr_i])
            tnsr_i += 1
        else:
            mixed.append(objects[obj_i])
            obj_i += 1
    return tuple(mixed)
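Taken together, these four helpers implement the flatten/unflatten round trip that `CheckpointFunction` relies on. The sketch below assembles them for illustration (it is not part of the commit): kwargs are flattened into positional arguments, tensors are separated from other objects, and both steps are reversed on the other side.

```python
import torch

from fairscale.utils.containers import pack_kwargs, split_non_tensors, unpack_kwargs, unpack_non_tensors

x = torch.ones(2)

# Flatten positional and keyword arguments into a single positional tuple,
# since autograd Functions must see every input positionally.
kwarg_keys, flat_args = pack_kwargs(x, 5, scale=2.0)

# Separate the tensors (which autograd tracks) from the other objects.
tensors, packed = split_non_tensors(flat_args)
assert tensors == (x,)

# Reverse both transformations to reconstruct the original call.
restored = unpack_non_tensors(tensors, packed)
args, kwargs = unpack_kwargs(kwarg_keys, restored)
assert args == (x, 5) and kwargs == {"scale": 2.0}
```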
Type stub for torch.utils.checkpoint (modified, +2/-1 lines)
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

-from typing import Tuple
+from typing import Any, Iterable, Tuple
 from .. import Tensor
 from torch.nn.modules.module import Module

 def detach_variable(inputs: Tuple[Tensor, ...]) -> Tuple[Tensor, ...]: ...
 def checkpoint(function: Module, *args, **kwargs): ...
+def check_backward_validity(inputs: Iterable[Any]): ...
Test for fairscale.nn.misc.checkpoint_activations (new file, +80 lines)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

"""
Test fairscale.nn.misc.checkpoint_activations
"""

import unittest

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from fairscale.nn.misc.checkpoint_activations import checkpoint_wrapper


class Model(nn.Module):
    def __init__(self, use_pytorch_checkpoint=False, use_fairseq_checkpoint=False, **kwargs):
        super().__init__()
        torch.manual_seed(0)
        self.use_pytorch_checkpoint = use_pytorch_checkpoint
        self.ffn = nn.Sequential(
            nn.Linear(32, 128),
            # add a Dropout layer to test RNG save/restore
            nn.Dropout(p=0.5),
            nn.Linear(128, 32),
        )
        if use_fairseq_checkpoint:
            self.ffn = checkpoint_wrapper(self.ffn, **kwargs)
        self.out = nn.Linear(32, 1)

    def forward(self, x):
        if self.use_pytorch_checkpoint:
            x = checkpoint(self.ffn, x)
        else:
            x = self.ffn(x)
        return self.out(x)


class TestComparisonToPyTorch(unittest.TestCase):
    def _test_checkpoint_wrapper(self, device, log_memory_usage=False):
        def get_loss_and_gnorm(model):
            torch.manual_seed(1)
            input = torch.rand(2, 16, 32).requires_grad_(True).to(device)
            model.zero_grad()
            loss = model(input).sum()
            loss.backward()
            gnorm = torch.norm(torch.stack([torch.norm(p.grad.detach()) for p in model.parameters()]))
            return {"loss": loss, "gnorm": gnorm}

        model = Model().to(device)
        no_cpt = get_loss_and_gnorm(model)

        model = Model(use_pytorch_checkpoint=True).to(device)
        pyt_cpt = get_loss_and_gnorm(model)
        torch.testing.assert_allclose(no_cpt["loss"], pyt_cpt["loss"])
        torch.testing.assert_allclose(no_cpt["gnorm"], pyt_cpt["gnorm"])

        model = Model(use_fairseq_checkpoint=True).to(device)
        fairseq_cpt = get_loss_and_gnorm(model)
        torch.testing.assert_allclose(no_cpt["loss"], fairseq_cpt["loss"])
        torch.testing.assert_allclose(no_cpt["gnorm"], fairseq_cpt["gnorm"])

        model = Model(use_fairseq_checkpoint=True, offload_to_cpu=True).to(device)
        fairseq_cpt_offload = get_loss_and_gnorm(model)
        torch.testing.assert_allclose(no_cpt["loss"], fairseq_cpt_offload["loss"])
        torch.testing.assert_allclose(no_cpt["gnorm"], fairseq_cpt_offload["gnorm"])

    def test_checkpoint_wrapper_cpu(self):
        self._test_checkpoint_wrapper(device=torch.device("cpu"))

    @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
    def test_checkpoint_wrapper_cuda(self):
        self._test_checkpoint_wrapper(device=torch.device("cuda"))


if __name__ == "__main__":
    unittest.main()
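The tests above check correctness (identical loss and gradient norm with and without checkpointing) but not the memory savings themselves. A rough way to observe the savings on a GPU is to compare peak allocation with and without the wrapper; the snippet below is an illustrative sketch under assumed layer sizes, not part of the test suite, and the numbers will vary by model and hardware.

```python
import torch
import torch.nn as nn

from fairscale.nn.misc.checkpoint_activations import checkpoint_wrapper


def peak_mem_bytes(wrap: bool) -> int:
    torch.cuda.reset_peak_memory_stats()
    ffn = nn.Sequential(*(nn.Linear(1024, 1024) for _ in range(8))).cuda()
    if wrap:
        ffn = checkpoint_wrapper(ffn)
    out = ffn(torch.randn(64, 1024, device="cuda", requires_grad=True))
    out.sum().backward()  # with wrap=True, intermediate activations are recomputed
    return torch.cuda.max_memory_allocated()


if torch.cuda.is_available():
    print("baseline:     ", peak_mem_bytes(wrap=False))
    print("checkpointed: ", peak_mem_bytes(wrap=True))
```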