From c963a72a3322e08825a96b904d2b8bf6f7dad633 Mon Sep 17 00:00:00 2001
From: Myle Ott
Date: Wed, 10 Feb 2021 17:55:32 -0500
Subject: [PATCH] Add fairscale.nn.misc.checkpoint_activations (#376)

* Add fairscale.utils.containers

Co-authored-by: Min Xu <24926999+min-xu-ai@users.noreply.github.com>

* Add fairscale.nn.misc.checkpoint_activations

Co-authored-by: Sam Shleifer

Co-authored-by: Min Xu <24926999+min-xu-ai@users.noreply.github.com>
Co-authored-by: Sam Shleifer
---
 fairscale/nn/misc/checkpoint_activations.py  | 158 +++++++++++++++++++
 fairscale/utils/containers.py                | 119 ++++++++++++++
 stubs/torch/utils/checkpoint.pyi             |   3 +-
 tests/nn/misc/test_checkpoint_activations.py |  80 ++++++++++
 tests/utils/test_containers.py               | 125 +++++++++++++++
 5 files changed, 484 insertions(+), 1 deletion(-)
 create mode 100644 fairscale/nn/misc/checkpoint_activations.py
 create mode 100644 fairscale/utils/containers.py
 create mode 100644 tests/nn/misc/test_checkpoint_activations.py
 create mode 100644 tests/utils/test_containers.py

diff --git a/fairscale/nn/misc/checkpoint_activations.py b/fairscale/nn/misc/checkpoint_activations.py
new file mode 100644
index 000000000..44f29c612
--- /dev/null
+++ b/fairscale/nn/misc/checkpoint_activations.py
@@ -0,0 +1,158 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+import functools
+from typing import Any, Dict, Optional, Tuple
+
+import torch
+from torch import Tensor
+import torch.nn as nn
+import torch.utils.checkpoint as checkpoint
+
+from fairscale.utils.containers import pack_kwargs, split_non_tensors, unpack_kwargs, unpack_non_tensors
+
+
+def checkpoint_wrapper(module: nn.Module, offload_to_cpu: bool = False) -> nn.Module:
+    """
+    A friendlier wrapper for performing activation checkpointing.
+
+    Compared to the PyTorch version, this version:
+    - wraps an nn.Module, so that all subsequent calls will use checkpointing
+    - handles keyword arguments in the forward
+    - handles non-Tensor outputs from the forward
+    - supports offloading activations to CPU
+
+    Usage::
+
+        checkpointed_module = checkpoint_wrapper(my_module, offload_to_cpu=True)
+        a, b = checkpointed_module(x, y=3, z=torch.Tensor([1]))
+
+    Args:
+        module (nn.Module): module to wrap
+        offload_to_cpu (bool, optional): whether to offload activations to CPU
+    """
+    module.forward = functools.partial(_checkpointed_forward, module.forward, offload_to_cpu)  # type: ignore
+    return module
+
+
+def _checkpointed_forward(original_forward: Any, offload_to_cpu: bool, *args: Any, **kwargs: Any) -> Any:
+    # Autograd Functions in PyTorch work best with positional args, since
+    # the backward must return gradients (or None) for every input argument.
+    # We can flatten keyword arguments to make this easier.
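+    # For example, pack_kwargs(1, 2, a=3, b=4) yields kwarg_keys=("a", "b")
+    # and flat_args=(1, 2, 3, 4); unpack_kwargs reverses this inside the
+    # autograd Function.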
+    kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)
+    parent_ctx_dict: Dict[str, Any] = {"offload": offload_to_cpu}
+    output = CheckpointFunction.apply(original_forward, parent_ctx_dict, kwarg_keys, *flat_args)
+    if isinstance(output, torch.Tensor):
+        return output
+    else:
+        packed_non_tensor_outputs = parent_ctx_dict["packed_non_tensor_outputs"]
+        if packed_non_tensor_outputs:
+            output = unpack_non_tensors(output, packed_non_tensor_outputs)
+        return output
+
+
+def get_rng_state() -> Dict[str, Any]:
+    state = {"torch_rng_state": torch.get_rng_state()}
+    if torch.cuda.is_available():
+        state["cuda_rng_state"] = torch.cuda.get_rng_state()
+    return state
+
+
+def set_rng_state(state: Dict[str, Any]) -> None:
+    torch.set_rng_state(state["torch_rng_state"])
+    if torch.cuda.is_available():
+        torch.cuda.set_rng_state(state["cuda_rng_state"])
+
+
+class CheckpointFunction(torch.autograd.Function):
+    """Similar to the torch version, but supports non-Tensor outputs.
+
+    The caller is expected to provide a dict (*parent_ctx_dict*) that will hold
+    the non-Tensor outputs. These should be combined with the Tensor *outputs*
+    by calling :func:`unpack_non_tensors`.
+    """
+
+    @staticmethod
+    def forward(  # type: ignore
+        ctx: Any,
+        run_function: Any,
+        parent_ctx_dict: Dict[str, Any],
+        kwarg_keys: Tuple[str, ...],
+        *args: Any,
+        **kwargs: Any
+    ) -> Any:
+        if torch.is_grad_enabled():  # grad may be disabled, e.g., during validation
+            checkpoint.check_backward_validity(args)
+
+        ctx.run_function = run_function
+        ctx.kwarg_keys = kwarg_keys
+        ctx.fwd_rng_state = get_rng_state()
+
+        tensor_inputs, packed_non_tensor_inputs = split_non_tensors(args)
+        if parent_ctx_dict["offload"]:
+            ctx.fwd_device = tuple(x.device for x in tensor_inputs)
+            ctx.grad_requirements = tuple(x.requires_grad for x in tensor_inputs)
+            tensor_inputs = tuple(x.cpu() for x in tensor_inputs)
+        else:
+            ctx.fwd_device, ctx.grad_requirements = None, None
+
+        ctx.save_for_backward(*tensor_inputs)
+        ctx.packed_non_tensor_inputs = packed_non_tensor_inputs
+
+        with torch.no_grad():
+            unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args)
+            outputs = run_function(*unpacked_args, **unpacked_kwargs)
+
+        if isinstance(outputs, torch.Tensor):
+            return outputs
+        else:
+            # Autograd Functions don't like non-Tensor outputs. We can split the
+            # non-Tensor and Tensor outputs, returning the former by reference
+            # through *parent_ctx_dict* and returning the latter directly.
+            outputs, packed_non_tensor_outputs = split_non_tensors(outputs)
+            parent_ctx_dict["packed_non_tensor_outputs"] = packed_non_tensor_outputs
+            return outputs
+
+    @staticmethod
+    def backward(ctx: Any, *args: Any) -> Tuple[Optional[Tensor], ...]:
+        if not torch.autograd._is_checkpoint_valid():
+            raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")
+
+        tensor_inputs: Tuple = ctx.saved_tensors
+        tensor_inputs = checkpoint.detach_variable(tensor_inputs)
+        if ctx.fwd_device is not None:
+            tensor_inputs = tuple(t.to(ctx.fwd_device[i]) for i, t in enumerate(tensor_inputs))
+            for i, need_grad in enumerate(ctx.grad_requirements):
+                tensor_inputs[i].requires_grad = need_grad
+        inputs = unpack_non_tensors(tensor_inputs, ctx.packed_non_tensor_inputs)
+
+        # Store the current states.
+        bwd_rng_state = get_rng_state()
+
+        # Set the state to what it was before the forward pass.
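+        # Dropout and other stochastic ops must see the same RNG state as in
+        # the original forward so that the recomputed activations match.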
+        set_rng_state(ctx.fwd_rng_state)
+
+        with torch.enable_grad():
+            unpacked_args, unpacked_kwargs = unpack_kwargs(ctx.kwarg_keys, inputs)
+            outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs)
+            tensor_outputs, _ = split_non_tensors(outputs)
+
+        # Set the state back to what it was at the start of this function.
+        set_rng_state(bwd_rng_state)
+
+        # Run backward() with only Tensors that require grad
+        outputs_with_grad = []
+        args_with_grad = []
+        for i in range(len(tensor_outputs)):
+            if tensor_outputs[i].requires_grad:
+                outputs_with_grad.append(tensor_outputs[i])
+                args_with_grad.append(args[i])
+        if len(outputs_with_grad) == 0:
+            raise RuntimeError("None of the outputs have requires_grad=True, this checkpoint() is not necessary")
+
+        torch.autograd.backward(outputs_with_grad, args_with_grad)
+
+        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in inputs)
+        return (None, None, None) + grads
diff --git a/fairscale/utils/containers.py b/fairscale/utils/containers.py
new file mode 100644
index 000000000..dbad26337
--- /dev/null
+++ b/fairscale/utils/containers.py
@@ -0,0 +1,119 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
+
+import torch
+
+"""Useful functions for dealing with tensors mixed with other Python container types."""
+
+
+def apply_to_tensors(fn: Callable, container: Union[torch.Tensor, Dict, List, Tuple, Set]) -> Any:
+    """Recursively apply fn to all tensors in the four supported container types."""
+
+    def _apply(x: Union[torch.Tensor, Dict, List, Tuple, Set]) -> Any:
+        if torch.is_tensor(x):
+            return fn(x)
+        elif isinstance(x, dict):
+            return {key: _apply(value) for key, value in x.items()}
+        elif isinstance(x, list):
+            return [_apply(v) for v in x]
+        elif isinstance(x, tuple):
+            return tuple(_apply(v) for v in x)
+        elif isinstance(x, set):
+            return {_apply(v) for v in x}
+        else:
+            return x
+
+    return _apply(container)
+
+
+def pack_kwargs(*args: Any, **kwargs: Any) -> Tuple[Tuple[str, ...], Tuple[Any, ...]]:
+    """
+    Turn an argument list into separate key list and value list (unpack_kwargs does the opposite).
+
+    Usage::
+
+        kwarg_keys, flat_args = pack_kwargs(1, 2, a=3, b=4)
+        assert kwarg_keys == ("a", "b")
+        assert flat_args == (1, 2, 3, 4)
+        args, kwargs = unpack_kwargs(kwarg_keys, flat_args)
+        assert args == (1, 2)
+        assert kwargs == {"a": 3, "b": 4}
+    """
+    kwarg_keys: List[str] = []
+    flat_args: List[Any] = list(args)
+    for k, v in kwargs.items():
+        kwarg_keys.append(k)
+        flat_args.append(v)
+    return tuple(kwarg_keys), tuple(flat_args)
+
+
+def unpack_kwargs(kwarg_keys: Tuple[str, ...], flat_args: Tuple[Any, ...]) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
+    """See pack_kwargs."""
+    assert len(kwarg_keys) <= len(flat_args), f"too many keys {len(kwarg_keys)} vs. {len(flat_args)}"
+    if len(kwarg_keys) == 0:
+        return flat_args, {}
+    args = flat_args[: -len(kwarg_keys)]
+    kwargs = {k: v for k, v in zip(kwarg_keys, flat_args[-len(kwarg_keys) :])}
+    return args, kwargs
+
+
+def split_non_tensors(
+    mixed: Union[torch.Tensor, Tuple[Any, ...]]
+) -> Tuple[Tuple[torch.Tensor, ...], Optional[Dict[str, List[Any]]]]:
+    """
+    Split a tuple into a tuple of tensors and the remaining objects, with
+    bookkeeping information for later reconstruction.
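+    Non-tensor values are recorded, together with a boolean mask of the
+    tensor positions, in a dict that unpack_non_tensors uses to rebuild
+    the original tuple.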
+
+    Usage::
+
+        x = torch.Tensor([1])
+        y = torch.Tensor([2])
+        tensors, packed_non_tensors = split_non_tensors((x, y, None, 3))
+        assert tensors == (x, y)
+        assert packed_non_tensors == {
+            "is_tensor": [True, True, False, False],
+            "objects": [None, 3],
+        }
+        recon = unpack_non_tensors(tensors, packed_non_tensors)
+        assert recon == (x, y, None, 3)
+    """
+    if isinstance(mixed, torch.Tensor):
+        return (mixed,), None
+    tensors: List[torch.Tensor] = []
+    packed_non_tensors: Dict[str, List[Any]] = {"is_tensor": [], "objects": []}
+    for o in mixed:
+        if isinstance(o, torch.Tensor):
+            packed_non_tensors["is_tensor"].append(True)
+            tensors.append(o)
+        else:
+            packed_non_tensors["is_tensor"].append(False)
+            packed_non_tensors["objects"].append(o)
+    return tuple(tensors), packed_non_tensors
+
+
+def unpack_non_tensors(
+    tensors: Tuple[torch.Tensor, ...], packed_non_tensors: Optional[Dict[str, List[Any]]]
+) -> Tuple[Any, ...]:
+    """See split_non_tensors."""
+    if packed_non_tensors is None:
+        return tensors
+    assert isinstance(packed_non_tensors, dict), type(packed_non_tensors)
+    mixed: List[Any] = []
+    is_tensor_list = packed_non_tensors["is_tensor"]
+    objects = packed_non_tensors["objects"]
+    assert len(tensors) + len(objects) == len(is_tensor_list), (
+        f"len(tensors) {len(tensors)} len(objects) {len(objects)} len(is_tensor_list) {len(is_tensor_list)}"
+    )
+    obj_i = tnsr_i = 0
+    for is_tensor in is_tensor_list:
+        if is_tensor:
+            mixed.append(tensors[tnsr_i])
+            tnsr_i += 1
+        else:
+            mixed.append(objects[obj_i])
+            obj_i += 1
+    return tuple(mixed)
diff --git a/stubs/torch/utils/checkpoint.pyi b/stubs/torch/utils/checkpoint.pyi
index f37a23ddd..003be48ae 100644
--- a/stubs/torch/utils/checkpoint.pyi
+++ b/stubs/torch/utils/checkpoint.pyi
@@ -1,8 +1,9 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
-from typing import Tuple
+from typing import Any, Iterable, Tuple
 from .. import Tensor
 from torch.nn.modules.module import Module
 def detach_variable(inputs: Tuple[Tensor,...]) -> Tuple[Tensor,...]: ...
 def checkpoint(function: Module, *args, **kwargs): ...
+def check_backward_validity(inputs: Iterable[Any]): ...
diff --git a/tests/nn/misc/test_checkpoint_activations.py b/tests/nn/misc/test_checkpoint_activations.py
new file mode 100644
index 000000000..84fdc13d5
--- /dev/null
+++ b/tests/nn/misc/test_checkpoint_activations.py
@@ -0,0 +1,80 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Test fairscale.nn.misc.checkpoint_activations
+"""
+
+import unittest
+
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+
+from fairscale.nn.misc.checkpoint_activations import checkpoint_wrapper
+
+
+class Model(nn.Module):
+    def __init__(self, use_pytorch_checkpoint=False, use_fairseq_checkpoint=False, **kwargs):
+        super().__init__()
+        torch.manual_seed(0)
+        self.use_pytorch_checkpoint = use_pytorch_checkpoint
+        self.ffn = nn.Sequential(
+            nn.Linear(32, 128),
+            # add a Dropout layer to test RNG save/restore
+            nn.Dropout(p=0.5),
+            nn.Linear(128, 32),
+        )
+        if use_fairseq_checkpoint:
+            self.ffn = checkpoint_wrapper(self.ffn, **kwargs)
+        self.out = nn.Linear(32, 1)
+
+    def forward(self, x):
+        if self.use_pytorch_checkpoint:
+            x = checkpoint(self.ffn, x)
+        else:
+            x = self.ffn(x)
+        return self.out(x)
+
+
+class TestComparisonToPyTorch(unittest.TestCase):
+    def _test_checkpoint_wrapper(self, device, log_memory_usage=False):
+        def get_loss_and_gnorm(model):
+            torch.manual_seed(1)
+            input = torch.rand(2, 16, 32).requires_grad_(True).to(device)
+            model.zero_grad()
+            loss = model(input).sum()
+            loss.backward()
+            gnorm = torch.norm(torch.stack([torch.norm(p.grad.detach()) for p in model.parameters()]))
+            return {"loss": loss, "gnorm": gnorm}
+
+        model = Model().to(device)
+        no_cpt = get_loss_and_gnorm(model)
+
+        model = Model(use_pytorch_checkpoint=True).to(device)
+        pyt_cpt = get_loss_and_gnorm(model)
+        torch.testing.assert_allclose(no_cpt["loss"], pyt_cpt["loss"])
+        torch.testing.assert_allclose(no_cpt["gnorm"], pyt_cpt["gnorm"])
+
+        model = Model(use_fairseq_checkpoint=True).to(device)
+        fairseq_cpt = get_loss_and_gnorm(model)
+        torch.testing.assert_allclose(no_cpt["loss"], fairseq_cpt["loss"])
+        torch.testing.assert_allclose(no_cpt["gnorm"], fairseq_cpt["gnorm"])
+
+        model = Model(use_fairseq_checkpoint=True, offload_to_cpu=True).to(device)
+        fairseq_cpt_offload = get_loss_and_gnorm(model)
+        torch.testing.assert_allclose(no_cpt["loss"], fairseq_cpt_offload["loss"])
+        torch.testing.assert_allclose(no_cpt["gnorm"], fairseq_cpt_offload["gnorm"])
+
+    def test_checkpoint_wrapper_cpu(self):
+        self._test_checkpoint_wrapper(device=torch.device("cpu"))
+
+    @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
+    def test_checkpoint_wrapper_cuda(self):
+        self._test_checkpoint_wrapper(device=torch.device("cuda"))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/utils/test_containers.py b/tests/utils/test_containers.py
new file mode 100644
index 000000000..d304478cb
--- /dev/null
+++ b/tests/utils/test_containers.py
@@ -0,0 +1,125 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pylint: disable=missing-module-docstring
+# pylint: disable=missing-class-docstring
+# pylint: disable=missing-function-docstring
+
+"""Test utility classes from containers.py.
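+Covers pack_kwargs/unpack_kwargs round-trips,
+split_non_tensors/unpack_non_tensors round-trips,
+and apply_to_tensors over nested containers.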
""" + +import random + +import pytest +import torch + +from fairscale.utils.containers import ( + apply_to_tensors, + pack_kwargs, + split_non_tensors, + unpack_kwargs, + unpack_non_tensors, +) + + +@pytest.mark.parametrize("devices", [["cpu"], ["cuda"], ["cpu", "cuda"]]) +def test_apply_to_tensors(devices): + """Test apply_to_tensors for both cpu & gpu""" + if "cuda" in devices and not torch.cuda.is_available() or torch.cuda.device_count() < 1: + pytest.skip("Skipped due to lack of GPU") + expected = 0 + + def get_a_tensor(): + """Return a random tensor on random device.""" + dev = random.choice(devices) + shape = random.choice(((1), (2, 3), (4, 5, 6), (7, 8, 9, 10))) + t = torch.rand(shape).to(dev) + nonlocal expected + expected += t.numel() + return t + + # create a mixed bag of data. + data = [1, "str"] + data.append({"key1": get_a_tensor(), "key2": {1: get_a_tensor()}, "key3": 3}) + data.insert(0, set(["x", get_a_tensor(), get_a_tensor()])) + data.append(([1], get_a_tensor(), (1), [get_a_tensor()], set((1, 2)))) + + total = 0 + + def fn(t, x=[[total]]): + nonlocal total + total += t.numel() + return t + + apply_to_tensors(fn, data) + assert total == expected, f"{total} vs. {expected}" + + +def test_pack_unpack(): + """Test pack_kwargs and unpack_kwargs.""" + kwarg_keys, flat_args = pack_kwargs(1, 2, 3, 4) + assert kwarg_keys == tuple() + assert flat_args == (1, 2, 3, 4) + + kwarg_keys, flat_args = pack_kwargs(a=1, b={2: "2"}, c={3}, d=[4], e=(5,)) + assert kwarg_keys == ("a", "b", "c", "d", "e") + assert flat_args == (1, {2: "2"}, {3}, [4], (5,)) + + kwarg_keys, flat_args = pack_kwargs(1, 2, a=3, b=4) + assert kwarg_keys == ("a", "b") + assert flat_args == (1, 2, 3, 4) + + args, kwargs = unpack_kwargs(kwarg_keys, flat_args) + assert args == (1, 2) + assert kwargs == {"a": 3, "b": 4} + + args, kwargs = unpack_kwargs([], flat_args) + assert kwargs == {} + assert args == (1, 2, 3, 4) + + args, kwargs = unpack_kwargs(["a", "b", "c", "d"], flat_args) + assert kwargs == {"a": 1, "b": 2, "c": 3, "d": 4} + assert args == tuple() + + with pytest.raises(AssertionError): + # too many keys should assert. + args, kwargs = unpack_kwargs(["a", "b", "c", "d", "e"], flat_args) + + +def test_split_unpack(): + """Test split_non_tensors and unpack_non_tensors.""" + x = torch.Tensor([1]) + y = torch.Tensor([2]) + + tensors, packed_non_tensors = split_non_tensors((x, y, None, 3)) + assert tensors == (x, y) + assert packed_non_tensors == { + "is_tensor": [True, True, False, False], + "objects": [None, 3], + } + recon = unpack_non_tensors(tensors, packed_non_tensors) + assert recon == (x, y, None, 3) + + tensors, packed_non_tensors = split_non_tensors((None, 3, x, y)) + recon = unpack_non_tensors(tensors, packed_non_tensors) + assert recon == (None, 3, x, y) + + tensors, packed_non_tensors = split_non_tensors((None, 3)) + recon = unpack_non_tensors(tensors, packed_non_tensors) + assert recon == (None, 3) + + tensors, packed_non_tensors = split_non_tensors((x, y)) + recon = unpack_non_tensors(tensors, packed_non_tensors) + assert recon == (x, y) + + recon = unpack_non_tensors(tensors, None) + assert recon == (x, y) + + with pytest.raises(AssertionError): + # assert the second arg should be a dict. + recon = unpack_non_tensors(tensors, set()) + + with pytest.raises(AssertionError): + # assert the content of the second arg should be sane. + recon = unpack_non_tensors(tensors, {"is_tensor": [], "objects": []})