diff --git a/flowtorch/bijectors/__init__.py b/flowtorch/bijectors/__init__.py index 57e3c1ad..62800a88 100644 --- a/flowtorch/bijectors/__init__.py +++ b/flowtorch/bijectors/__init__.py @@ -16,6 +16,8 @@ from flowtorch.bijectors.autoregressive import Autoregressive from flowtorch.bijectors.base import Bijector from flowtorch.bijectors.compose import Compose +from flowtorch.bijectors.coupling import ConvCouplingBijector +from flowtorch.bijectors.coupling import CouplingBijector from flowtorch.bijectors.elementwise import Elementwise from flowtorch.bijectors.elu import ELU from flowtorch.bijectors.exp import Exp @@ -28,6 +30,8 @@ from flowtorch.bijectors.softplus import Softplus from flowtorch.bijectors.spline import Spline from flowtorch.bijectors.spline_autoregressive import SplineAutoregressive +from flowtorch.bijectors.split_bijector import ReshapeBijector +from flowtorch.bijectors.split_bijector import SplitBijector from flowtorch.bijectors.tanh import Tanh from flowtorch.bijectors.volume_preserving import VolumePreserving @@ -35,6 +39,8 @@ ("Affine", Affine), ("AffineAutoregressive", AffineAutoregressive), ("AffineFixed", AffineFixed), + ("ConvCouplingBijector", ConvCouplingBijector), + ("CouplingBijector", CouplingBijector), ("ELU", ELU), ("Exp", Exp), ("LeakyReLU", LeakyReLU), @@ -55,6 +61,8 @@ ("Compose", Compose), ("Invert", Invert), ("VolumePreserving", VolumePreserving), + ("ReshapeBijector", ReshapeBijector), + ("SplitBijector", SplitBijector), ] diff --git a/flowtorch/bijectors/affine_autoregressive.py b/flowtorch/bijectors/affine_autoregressive.py index 610e5477..a855cf5d 100644 --- a/flowtorch/bijectors/affine_autoregressive.py +++ b/flowtorch/bijectors/affine_autoregressive.py @@ -16,15 +16,28 @@ def __init__( *, shape: torch.Size, context_shape: Optional[torch.Size] = None, + clamp_values: bool = False, log_scale_min_clip: float = -5.0, log_scale_max_clip: float = 3.0, sigmoid_bias: float = 2.0, + positive_map: str = "softplus", + positive_bias: Optional[float] = None, ) -> None: - super().__init__( + AffineOp.__init__( + self, + params_fn, + shape=shape, + context_shape=context_shape, + clamp_values=clamp_values, + log_scale_min_clip=log_scale_min_clip, + log_scale_max_clip=log_scale_max_clip, + sigmoid_bias=sigmoid_bias, + positive_map=positive_map, + positive_bias=positive_bias, + ) + Autoregressive.__init__( + self, params_fn, shape=shape, context_shape=context_shape, ) - self.log_scale_min_clip = log_scale_min_clip - self.log_scale_max_clip = log_scale_max_clip - self.sigmoid_bias = sigmoid_bias diff --git a/flowtorch/bijectors/affine_fixed.py b/flowtorch/bijectors/affine_fixed.py index af916519..89a717c9 100644 --- a/flowtorch/bijectors/affine_fixed.py +++ b/flowtorch/bijectors/affine_fixed.py @@ -24,7 +24,7 @@ def __init__( shape: torch.Size, context_shape: Optional[torch.Size] = None, loc: float = 0.0, - scale: float = 1.0 + scale: float = 1.0, ) -> None: super().__init__(params_fn, shape=shape, context_shape=context_shape) self.loc = loc @@ -32,9 +32,10 @@ def __init__( def _forward( self, - x: torch.Tensor, + *inputs: torch.Tensor, params: Optional[Sequence[torch.Tensor]], ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + x = inputs[0] y = self.loc + self.scale * x ladj: Optional[torch.Tensor] = None if requires_log_detJ(): @@ -42,8 +43,9 @@ def _forward( return y, ladj def _inverse( - self, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + self, *inputs: torch.Tensor, params: Optional[Sequence[torch.Tensor]] ) -> Tuple[torch.Tensor, 
Optional[torch.Tensor]]: + y = inputs[0] x = (y - self.loc) / self.scale ladj: Optional[torch.Tensor] = None if requires_log_detJ(): @@ -51,6 +53,9 @@ def _inverse( return x, ladj def _log_abs_det_jacobian( - self, x: torch.Tensor, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + self, + x: torch.Tensor, + y: torch.Tensor, + params: Optional[Sequence[torch.Tensor]], ) -> torch.Tensor: return torch.full_like(x, math.log(abs(self.scale))) diff --git a/flowtorch/bijectors/autoregressive.py b/flowtorch/bijectors/autoregressive.py index 8367b51b..06ef3fd5 100644 --- a/flowtorch/bijectors/autoregressive.py +++ b/flowtorch/bijectors/autoregressive.py @@ -7,7 +7,10 @@ import torch import torch.distributions.constraints as constraints from flowtorch.bijectors.base import Bijector -from flowtorch.bijectors.bijective_tensor import BijectiveTensor, to_bijective_tensor +from flowtorch.bijectors.bijective_tensor import ( + BijectiveTensor, + to_bijective_tensor, +) from flowtorch.bijectors.utils import is_record_flow_graph_enabled from flowtorch.parameters.dense_autoregressive import DenseAutoregressive @@ -60,7 +63,7 @@ def inverse( # TODO: Make permutation, inverse work for other event shapes log_detJ: Optional[torch.Tensor] = None for idx in cast(torch.LongTensor, permutation): - _params = self._params_fn(x_new.clone(), context=context) + _params = self._params_fn(x_new.clone(), inverse=False, context=context) x_temp, log_detJ = self._inverse(y, params=_params) x_new[..., idx] = x_temp[..., idx] # _log_detJ = out[1] @@ -78,6 +81,9 @@ def inverse( return x_new def _log_abs_det_jacobian( - self, x: torch.Tensor, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + self, + x: torch.Tensor, + y: torch.Tensor, + params: Optional[Sequence[torch.Tensor]], ) -> torch.Tensor: raise NotImplementedError diff --git a/flowtorch/bijectors/base.py b/flowtorch/bijectors/base.py index 2a3d0f01..087d1b5e 100644 --- a/flowtorch/bijectors/base.py +++ b/flowtorch/bijectors/base.py @@ -7,13 +7,17 @@ import flowtorch.parameters import torch import torch.distributions -from flowtorch.bijectors.bijective_tensor import BijectiveTensor, to_bijective_tensor +from flowtorch.bijectors.bijective_tensor import ( + BijectiveTensor, + to_bijective_tensor, +) from flowtorch.bijectors.utils import is_record_flow_graph_enabled from flowtorch.parameters import Parameters from torch.distributions import constraints ParamFnType = Callable[ - [Optional[torch.Tensor], Optional[torch.Tensor]], Optional[Sequence[torch.Tensor]] + [Optional[torch.Tensor], Optional[torch.Tensor]], + Optional[Sequence[torch.Tensor]], ] @@ -60,6 +64,9 @@ def _check_bijective_x( and x.check_context(context) ) + def _forward_pre_ops(self, x: torch.Tensor) -> Tuple[torch.Tensor, ...]: + return (x,) + def forward( self, x: torch.Tensor, @@ -71,8 +78,13 @@ def forward( assert isinstance(x, BijectiveTensor) return x.get_parent_from_bijector(self) - params = self._params_fn(x, context) if self._params_fn is not None else None - y, log_detJ = self._forward(x, params) + x_tuple = self._forward_pre_ops(x) + params = ( + self._params_fn(*x_tuple, inverse=False, context=context) + if self._params_fn is not None + else None + ) + y, log_detJ = self._forward(*x_tuple, params=params) if ( is_record_flow_graph_enabled() and not isinstance(y, BijectiveTensor) @@ -84,7 +96,7 @@ def forward( def _forward( self, - x: torch.Tensor, + *x: torch.Tensor, params: Optional[Sequence[torch.Tensor]], ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """ @@ -104,6 +116,9 @@ 
def _check_bijective_y( and y.check_context(context) ) + def _inverse_pre_ops(self, y: torch.Tensor) -> Tuple[torch.Tensor, ...]: + return (y,) + def inverse( self, y: torch.Tensor, @@ -117,8 +132,13 @@ def inverse( return y.get_parent_from_bijector(self) # TODO: What to do in this line? - params = self._params_fn(x, context) if self._params_fn is not None else None - x, log_detJ = self._inverse(y, params) + y_tuple = self._inverse_pre_ops(y) + params = ( + self._params_fn(*y_tuple, inverse=True, context=context) + if self._params_fn is not None + else None + ) + x, log_detJ = self._inverse(*y_tuple, params=params) if ( is_record_flow_graph_enabled() @@ -130,7 +150,7 @@ def inverse( def _inverse( self, - y: torch.Tensor, + *y: torch.Tensor, params: Optional[Sequence[torch.Tensor]], ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """ @@ -170,10 +190,12 @@ def log_abs_det_jacobian( if ladj is None: if is_record_flow_graph_enabled(): warnings.warn( - "Computing _log_abs_det_jacobian from values and not " "from cache." + "Computing _log_abs_det_jacobian from values and not from cache." ) params = ( - self._params_fn(x, context) if self._params_fn is not None else None + self._params_fn(x, inverse=False, context=context) + if self._params_fn is not None + else None ) return self._log_abs_det_jacobian(x, y, params) return ladj diff --git a/flowtorch/bijectors/compose.py b/flowtorch/bijectors/compose.py index 5253cacc..65ebb064 100644 --- a/flowtorch/bijectors/compose.py +++ b/flowtorch/bijectors/compose.py @@ -6,8 +6,14 @@ import torch import torch.distributions from flowtorch.bijectors.base import Bijector -from flowtorch.bijectors.bijective_tensor import BijectiveTensor, to_bijective_tensor -from flowtorch.bijectors.utils import is_record_flow_graph_enabled, requires_log_detJ +from flowtorch.bijectors.bijective_tensor import ( + BijectiveTensor, + to_bijective_tensor, +) +from flowtorch.bijectors.utils import ( + is_record_flow_graph_enabled, + requires_log_detJ, +) from torch.distributions.utils import _sum_rightmost diff --git a/flowtorch/bijectors/coupling.py b/flowtorch/bijectors/coupling.py new file mode 100644 index 00000000..5d647c03 --- /dev/null +++ b/flowtorch/bijectors/coupling.py @@ -0,0 +1,133 @@ +# Copyright (c) Meta Platforms, Inc +from copy import deepcopy +from typing import Optional, Sequence, Tuple + +import flowtorch.parameters +import torch +from flowtorch.bijectors.ops.affine import Affine as AffineOp +from flowtorch.parameters import ConvCoupling, DenseCoupling +from torch.distributions import constraints + + +_REAL3d = deepcopy(constraints.real) +_REAL3d.event_dim = 3 + +_REAL1d = deepcopy(constraints.real) +_REAL1d.event_dim = 1 + + +class CouplingBijector(AffineOp): + """ + Examples: + >>> params = DenseCoupling() + >>> bij = CouplingBijector(params) + >>> bij = bij(shape=torch.Size([32,])) + >>> for p in bij.parameters(): + ... 
p.data += torch.randn_like(p)/10 + >>> x = torch.randn(1, 32,requires_grad=True) + >>> y = bij.forward(x).detach_from_flow() + >>> x_bis = bij.inverse(y) + >>> torch.testing.assert_allclose(x, x_bis) + """ + + domain: constraints.Constraint = _REAL1d + codomain: constraints.Constraint = _REAL1d + + def __init__( + self, + params_fn: Optional[flowtorch.Lazy] = None, + *, + shape: torch.Size, + context_shape: Optional[torch.Size] = None, + clamp_values: bool = False, + log_scale_min_clip: float = -5.0, + log_scale_max_clip: float = 3.0, + sigmoid_bias: float = 2.0, + positive_map: str = "softplus", + positive_bias: Optional[float] = None, + ) -> None: + + if params_fn is None: + params_fn = DenseCoupling() # type: ignore + + AffineOp.__init__( + self, + params_fn, + shape=shape, + context_shape=context_shape, + clamp_values=clamp_values, + log_scale_min_clip=log_scale_min_clip, + log_scale_max_clip=log_scale_max_clip, + sigmoid_bias=sigmoid_bias, + positive_map=positive_map, + positive_bias=positive_bias, + ) + + def _forward( + self, *inputs: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + x = inputs[0] + assert self._params_fn is not None + + y, ldj = super()._forward(x, params=params) + return y, ldj + + def _inverse( + self, *inputs: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + y = inputs[0] + assert self._params_fn is not None + + x, ldj = super()._inverse(y, params=params) + return x, ldj + + +class ConvCouplingBijector(CouplingBijector): + """ + Examples: + >>> params = ConvCoupling() + >>> bij = ConvCouplingBijector(params) + >>> bij = bij(shape=torch.Size([3,16,16])) + >>> for p in bij.parameters(): + ... p.data += torch.randn_like(p)/10 + >>> x = torch.randn(4, 3, 16, 16) + >>> y = bij.forward(x) + >>> x_bis = bij.inverse(y.detach_from_flow()) + >>> torch.testing.assert_allclose(x, x_bis) + """ + + domain: constraints.Constraint = _REAL3d + codomain: constraints.Constraint = _REAL3d + + def __init__( + self, + params_fn: Optional[flowtorch.Lazy] = None, + *, + shape: torch.Size, + context_shape: Optional[torch.Size] = None, + clamp_values: bool = False, + log_scale_min_clip: float = -5.0, + log_scale_max_clip: float = 3.0, + sigmoid_bias: float = 2.0, + positive_map: str = "softplus", + positive_bias: Optional[float] = None, + ) -> None: + + if not len(shape) == 3: + raise ValueError(f"Expected a 3d-tensor shape, got {shape}") + + if params_fn is None: + params_fn = ConvCoupling() # type: ignore + + AffineOp.__init__( + self, + params_fn, + shape=shape, + context_shape=context_shape, + clamp_values=clamp_values, + log_scale_min_clip=log_scale_min_clip, + log_scale_max_clip=log_scale_max_clip, + sigmoid_bias=sigmoid_bias, + positive_map=positive_map, + positive_bias=positive_bias, + ) diff --git a/flowtorch/bijectors/elu.py b/flowtorch/bijectors/elu.py index ac5a3494..d05f6fea 100644 --- a/flowtorch/bijectors/elu.py +++ b/flowtorch/bijectors/elu.py @@ -15,15 +15,18 @@ class ELU(Fixed): # TODO: Setting the alpha value of ELU as __init__ argument def _forward( - self, x: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + self, *inputs: torch.Tensor, params: Optional[Sequence[torch.Tensor]] ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + x = inputs[0] y = F.elu(x) ladj = self._log_abs_det_jacobian(x, y, params) return y, ladj def _inverse( - self, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + self, *inputs: torch.Tensor, params: 
Optional[Sequence[torch.Tensor]] ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + y = inputs[0] + x = torch.max(y, torch.zeros_like(y)) + torch.min( torch.log1p(y + eps), torch.zeros_like(y) ) @@ -31,6 +34,9 @@ def _inverse( return x, ladj def _log_abs_det_jacobian( - self, x: torch.Tensor, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + self, + x: torch.Tensor, + y: torch.Tensor, + params: Optional[Sequence[torch.Tensor]], ) -> torch.Tensor: return -F.relu(-x) diff --git a/flowtorch/bijectors/exp.py b/flowtorch/bijectors/exp.py index 2856d312..a0e2e461 100644 --- a/flowtorch/bijectors/exp.py +++ b/flowtorch/bijectors/exp.py @@ -14,20 +14,26 @@ class Exp(Fixed): codomain = constraints.positive def _forward( - self, x: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + self, *inputs: torch.Tensor, params: Optional[Sequence[torch.Tensor]] ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + x = inputs[0] y = torch.exp(x) ladj = self._log_abs_det_jacobian(x, y, params) return y, ladj def _inverse( - self, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + self, *inputs: torch.Tensor, params: Optional[Sequence[torch.Tensor]] ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + y = inputs[0] + x = y.log() ladj = self._log_abs_det_jacobian(x, y, params) return x, ladj def _log_abs_det_jacobian( - self, x: torch.Tensor, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + self, + x: torch.Tensor, + y: torch.Tensor, + params: Optional[Sequence[torch.Tensor]], ) -> torch.Tensor: return x diff --git a/flowtorch/bijectors/leaky_relu.py b/flowtorch/bijectors/leaky_relu.py index 79ce58f4..bbbaa59b 100644 --- a/flowtorch/bijectors/leaky_relu.py +++ b/flowtorch/bijectors/leaky_relu.py @@ -12,21 +12,27 @@ class LeakyReLU(Fixed): # TODO: Setting the slope of Leaky ReLU as __init__ argument def _forward( - self, x: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + self, *inputs: torch.Tensor, params: Optional[Sequence[torch.Tensor]] ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + x = inputs[0] y = F.leaky_relu(x) ladj = self._log_abs_det_jacobian(x, y, params) return y, ladj def _inverse( - self, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + self, *inputs: torch.Tensor, params: Optional[Sequence[torch.Tensor]] ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + y = inputs[0] + x = F.leaky_relu(y, negative_slope=100.0) ladj = self._log_abs_det_jacobian(x, y, params) return x, ladj def _log_abs_det_jacobian( - self, x: torch.Tensor, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + self, + x: torch.Tensor, + y: torch.Tensor, + params: Optional[Sequence[torch.Tensor]], ) -> torch.Tensor: return torch.where( x >= 0.0, torch.zeros_like(x), torch.ones_like(x) * math.log(0.01) diff --git a/flowtorch/bijectors/ops/affine.py b/flowtorch/bijectors/ops/affine.py index d9cdf56f..058ce690 100644 --- a/flowtorch/bijectors/ops/affine.py +++ b/flowtorch/bijectors/ops/affine.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc -from typing import Optional, Sequence, Tuple +from typing import Callable, Dict, Optional, Sequence, Tuple import flowtorch import torch @@ -8,6 +8,17 @@ from flowtorch.ops import clamp_preserve_gradients from torch.distributions.utils import _sum_rightmost +_DEFAULT_POSITIVE_BIASES = { + "softplus": 0.5413248538970947, + "exp": 0.0, +} + +_POSITIVE_MAPS: Dict[str, Callable[[torch.Tensor], torch.Tensor]] = { + "softplus": torch.nn.functional.softplus, + "sigmoid": torch.sigmoid, + "exp": torch.exp, +} + class Affine(Bijector): 
r""" @@ -22,38 +33,70 @@ def __init__( *, shape: torch.Size, context_shape: Optional[torch.Size] = None, + clamp_values: bool = False, log_scale_min_clip: float = -5.0, log_scale_max_clip: float = 3.0, sigmoid_bias: float = 2.0, + positive_map: str = "softplus", + positive_bias: Optional[float] = None, ) -> None: super().__init__(params_fn, shape=shape, context_shape=context_shape) + self.clamp_values = clamp_values self.log_scale_min_clip = log_scale_min_clip self.log_scale_max_clip = log_scale_max_clip self.sigmoid_bias = sigmoid_bias + if positive_bias is None: + positive_bias = _DEFAULT_POSITIVE_BIASES[positive_map] + self.positive_bias = positive_bias + if positive_map not in _POSITIVE_MAPS: + raise RuntimeError(f"Unknwon positive map {positive_map}") + self._positive_map = _POSITIVE_MAPS[positive_map] + self._exp_map = self._positive_map is torch.exp and self.positive_bias == 0 + + def positive_map(self, x: torch.Tensor) -> torch.Tensor: + return self._positive_map(x + self.positive_bias) + def _forward( - self, x: torch.Tensor, params: Optional[Sequence[torch.Tensor]] - ) -> Tuple[torch.Tensor, torch.Tensor]: + self, *inputs: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + x = inputs[0] assert params is not None - mean, log_scale = params - log_scale = clamp_preserve_gradients( - log_scale, self.log_scale_min_clip, self.log_scale_max_clip - ) - scale = torch.exp(log_scale) + mean, unbounded_scale = params + if self.clamp_values: + unbounded_scale = clamp_preserve_gradients( + unbounded_scale, + self.log_scale_min_clip, + self.log_scale_max_clip, + ) + scale = self.positive_map(unbounded_scale) + log_scale = scale.log() if not self._exp_map else unbounded_scale y = scale * x + mean return y, _sum_rightmost(log_scale, self.domain.event_dim) def _inverse( - self, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] - ) -> Tuple[torch.Tensor, torch.Tensor]: - assert params is not None + self, *inputs: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + y = inputs[0] + assert ( + params is not None + ), f"{self.__class__.__name__}._inverse got no parameters" - mean, log_scale = params - log_scale = clamp_preserve_gradients( - log_scale, self.log_scale_min_clip, self.log_scale_max_clip - ) - inverse_scale = torch.exp(-log_scale) + mean, unbounded_scale = params + if self.clamp_values: + unbounded_scale = clamp_preserve_gradients( + unbounded_scale, + self.log_scale_min_clip, + self.log_scale_max_clip, + ) + + if not self._exp_map: + inverse_scale = self.positive_map(unbounded_scale).reciprocal() + log_scale = -inverse_scale.log() + else: + inverse_scale = torch.exp(-unbounded_scale) + log_scale = unbounded_scale x_new = (y - mean) * inverse_scale return x_new, _sum_rightmost(log_scale, self.domain.event_dim) @@ -65,9 +108,17 @@ def _log_abs_det_jacobian( ) -> torch.Tensor: assert params is not None - _, log_scale = params - log_scale = clamp_preserve_gradients( - log_scale, self.log_scale_min_clip, self.log_scale_max_clip + _, unbounded_scale = params + if self.clamp_values: + unbounded_scale = clamp_preserve_gradients( + unbounded_scale, + self.log_scale_min_clip, + self.log_scale_max_clip, + ) + log_scale = ( + self.positive_map(unbounded_scale).log() + if not self._exp_map + else unbounded_scale ) return _sum_rightmost(log_scale, self.domain.event_dim) diff --git a/flowtorch/bijectors/ops/spline.py b/flowtorch/bijectors/ops/spline.py index 9145cad3..e7b92e5a 100644 
--- a/flowtorch/bijectors/ops/spline.py +++ b/flowtorch/bijectors/ops/spline.py @@ -44,14 +44,16 @@ def __init__( super().__init__(params_fn, shape=shape, context_shape=context_shape) def _forward( - self, x: torch.Tensor, params: Optional[Sequence[torch.Tensor]] - ) -> Tuple[torch.Tensor, torch.Tensor]: + self, *inputs: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + x = inputs[0] y, log_detJ = self._op(x, params) return y, _sum_rightmost(log_detJ, self.domain.event_dim) def _inverse( - self, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] - ) -> Tuple[torch.Tensor, torch.Tensor]: + self, *inputs: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + y = inputs[0] x_new, log_detJ = self._op(y, params, inverse=True) return x_new, _sum_rightmost(-log_detJ, self.domain.event_dim) diff --git a/flowtorch/bijectors/permute.py b/flowtorch/bijectors/permute.py index f371c9af..dc2f87d8 100644 --- a/flowtorch/bijectors/permute.py +++ b/flowtorch/bijectors/permute.py @@ -27,8 +27,9 @@ def __init__( self.permutation = permutation def _forward( - self, x: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + self, *inputs: torch.Tensor, params: Optional[Sequence[torch.Tensor]] ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + x = inputs[0] if self.permutation is None: self.permutation = torch.randperm(x.shape[-1]) @@ -37,8 +38,10 @@ def _forward( return y, ladj def _inverse( - self, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + self, *inputs: torch.Tensor, params: Optional[Sequence[torch.Tensor]] ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + y = inputs[0] + if self.permutation is None: self.permutation = torch.randperm(y.shape[-1]) @@ -53,6 +56,8 @@ def inv_permutation(self) -> Optional[torch.Tensor]: result = torch.empty_like(self.permutation, dtype=torch.long) result[self.permutation] = torch.arange( - self.permutation.size(0), dtype=torch.long, device=self.permutation.device + self.permutation.size(0), + dtype=torch.long, + device=self.permutation.device, ) return result diff --git a/flowtorch/bijectors/power.py b/flowtorch/bijectors/power.py index 0aea101e..d755ac4f 100644 --- a/flowtorch/bijectors/power.py +++ b/flowtorch/bijectors/power.py @@ -28,17 +28,17 @@ def __init__( self.exponent = exponent def _forward( - self, x: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + self, *inputs: torch.Tensor, params: Optional[Sequence[torch.Tensor]] ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + x = inputs[0] y = x.pow(self.exponent) ladj = self._log_abs_det_jacobian(x, y, params) return y, ladj def _inverse( - self, - y: torch.Tensor, - params: Optional[Sequence[torch.Tensor]] = None, + self, *inputs: torch.Tensor, params: Optional[Sequence[torch.Tensor]] ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + y = inputs[0] x = y.pow(1 / self.exponent) ladj = self._log_abs_det_jacobian(x, y, params) return x, ladj diff --git a/flowtorch/bijectors/sigmoid.py b/flowtorch/bijectors/sigmoid.py index e9fccaec..aabcf5ee 100644 --- a/flowtorch/bijectors/sigmoid.py +++ b/flowtorch/bijectors/sigmoid.py @@ -13,15 +13,18 @@ class Sigmoid(Fixed): codomain = constraints.unit_interval def _forward( - self, x: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + self, *inputs: torch.Tensor, params: Optional[Sequence[torch.Tensor]] ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + x = inputs[0] y = clipped_sigmoid(x) ladj = self._log_abs_det_jacobian(x, y, 
params) return y, ladj def _inverse( - self, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + self, *inputs: torch.Tensor, params: Optional[Sequence[torch.Tensor]] ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + y = inputs[0] + finfo = torch.finfo(y.dtype) y = y.clamp(min=finfo.tiny, max=1.0 - finfo.eps) x = y.log() - torch.log1p(-y) @@ -29,6 +32,9 @@ def _inverse( return x, ladj def _log_abs_det_jacobian( - self, x: torch.Tensor, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + self, + x: torch.Tensor, + y: torch.Tensor, + params: Optional[Sequence[torch.Tensor]], ) -> torch.Tensor: return -F.softplus(-x) - F.softplus(x) diff --git a/flowtorch/bijectors/softplus.py b/flowtorch/bijectors/softplus.py index 5633a3dd..4ea5e249 100644 --- a/flowtorch/bijectors/softplus.py +++ b/flowtorch/bijectors/softplus.py @@ -16,20 +16,26 @@ class Softplus(Fixed): codomain = constraints.positive def _forward( - self, x: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + self, *inputs: torch.Tensor, params: Optional[Sequence[torch.Tensor]] ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + x = inputs[0] y = F.softplus(x) ladj = self._log_abs_det_jacobian(x, y, params) return y, ladj def _inverse( - self, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + self, *inputs: torch.Tensor, params: Optional[Sequence[torch.Tensor]] ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + y = inputs[0] + x = flowtorch.ops.softplus_inv(y) ladj = self._log_abs_det_jacobian(x, y, params) return x, ladj def _log_abs_det_jacobian( - self, x: torch.Tensor, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + self, + x: torch.Tensor, + y: torch.Tensor, + params: Optional[Sequence[torch.Tensor]], ) -> torch.Tensor: return -F.softplus(-x) diff --git a/flowtorch/bijectors/split_bijector.py b/flowtorch/bijectors/split_bijector.py new file mode 100644 index 00000000..6ae0ef42 --- /dev/null +++ b/flowtorch/bijectors/split_bijector.py @@ -0,0 +1,80 @@ +# Copyright (c) Meta Platforms, Inc +from typing import Tuple, Optional, Sequence + +import flowtorch +import torch +from torch.nn.functional import softplus + +from ..parameters import ZeroConv2d +from . 
import Bijector +from .bijective_tensor import BijectiveTensor +from .utils import _sum_rightmost_over_tuple + + +class ReshapeBijector(Bijector): + pass + + +class SplitBijector(ReshapeBijector): + BIAS_SOFTPLUS = 0.54 + + def __init__( + self, + params_fn: Optional[flowtorch.Lazy] = None, + *, + shape: torch.Size, + transform: Bijector, + chunk_dim: int = -3, + context_shape: Optional[torch.Size] = None, + ) -> None: + if params_fn is None: + params_fn = ZeroConv2d() # type: ignore + + super().__init__(params_fn, shape=shape, context_shape=context_shape) + self._transform = transform + self.chunk_dim = chunk_dim + + def _forward_pre_ops(self, x: torch.Tensor) -> Tuple[torch.Tensor, ...]: + x1, x2 = x.chunk(2, dim=self.chunk_dim) + return x1, x2 + + def _inverse_pre_ops(self, y: torch.Tensor) -> Tuple[torch.Tensor, ...]: + y1, y2 = y.chunk(2, dim=self.chunk_dim) + x1 = self._transform.inverse(y1) + return x1, y2 + + def _forward( + self, + *x: torch.Tensor, + params: Optional[Sequence[torch.Tensor]], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + x1, x2 = x + assert params is not None + loc, scale = params + scale = softplus(scale + self.BIAS_SOFTPLUS) + y1 = self._transform.forward(x1) + y2 = (x2 - loc) / scale.clamp_min(1e-5) + ldj = self._transform.log_abs_det_jacobian(x1, y1) + ldj1, ldj2 = _sum_rightmost_over_tuple(ldj, -scale.log()) + return torch.cat([y1, y2], self.chunk_dim), ldj1 + ldj2 + + def _inverse( + self, + *y: torch.Tensor, + params: Optional[Sequence[torch.Tensor]], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + x1, y2 = y + assert params is not None + assert isinstance(x1, BijectiveTensor) + loc, scale = params + scale = softplus(scale + self.BIAS_SOFTPLUS) + x2 = y2 * scale + loc + ldj = self._transform.log_abs_det_jacobian( + x1, x1.get_parent_from_bijector(self._transform) + ) + ldj1, ldj2 = _sum_rightmost_over_tuple(ldj, -scale.log()) + return torch.cat([x1, x2], self.chunk_dim), ldj1 + ldj2 + + def param_shapes(self, shape: torch.Size) -> Tuple[torch.Size, torch.Size]: + # A mean and log variance for every dimension of the event shape + return shape, shape diff --git a/flowtorch/bijectors/tanh.py b/flowtorch/bijectors/tanh.py index 5aae732e..199f64ac 100644 --- a/flowtorch/bijectors/tanh.py +++ b/flowtorch/bijectors/tanh.py @@ -16,20 +16,26 @@ class Tanh(Fixed): codomain = constraints.interval(-1.0, 1.0) def _forward( - self, x: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + self, *inputs: torch.Tensor, params: Optional[Sequence[torch.Tensor]] ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + x = inputs[0] y = torch.tanh(x) ladj = self._log_abs_det_jacobian(x, y, params) return y, ladj def _inverse( - self, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + self, *inputs: torch.Tensor, params: Optional[Sequence[torch.Tensor]] ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + y = inputs[0] + x = torch.atanh(y) ladj = self._log_abs_det_jacobian(x, y, params) return x, ladj def _log_abs_det_jacobian( - self, x: torch.Tensor, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + self, + x: torch.Tensor, + y: torch.Tensor, + params: Optional[Sequence[torch.Tensor]], ) -> torch.Tensor: return 2.0 * (math.log(2.0) - x - F.softplus(-2.0 * x)) diff --git a/flowtorch/bijectors/utils.py b/flowtorch/bijectors/utils.py index 376751f1..7eba1015 100644 --- a/flowtorch/bijectors/utils.py +++ b/flowtorch/bijectors/utils.py @@ -1,6 +1,8 @@ # Copyright (c) Meta Platforms, Inc import functools -from typing import Any, Callable, List, Sequence 
+from typing import Any, Callable, List, Sequence, Tuple + +import torch _RECORD_FLOW = True @@ -56,3 +58,14 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: def requires_log_detJ() -> bool: return _REQUIRES_LOG_DETJ + + +def _sum_rightmost_over_tuple(*x: torch.Tensor) -> Tuple[torch.Tensor, ...]: + min_dim = min(_x.ndimension() for _x in x) + x = tuple( + _x.sum(dim=list(range(min_dim, _x.ndimension()))) + if _x.ndimension() > min_dim + else _x + for _x in x + ) + return x diff --git a/flowtorch/bijectors/volume_preserving.py b/flowtorch/bijectors/volume_preserving.py index 4f8b6bd7..66ba2f83 100644 --- a/flowtorch/bijectors/volume_preserving.py +++ b/flowtorch/bijectors/volume_preserving.py @@ -9,7 +9,10 @@ class VolumePreserving(Bijector): def _log_abs_det_jacobian( - self, x: torch.Tensor, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + self, + x: torch.Tensor, + y: torch.Tensor, + params: Optional[Sequence[torch.Tensor]], ) -> torch.Tensor: # TODO: Confirm that this should involve `x`/`self.domain` and not # `y`/`self.codomain` diff --git a/flowtorch/distributions/flow.py b/flowtorch/distributions/flow.py index bfb0e97d..6b7d6488 100644 --- a/flowtorch/distributions/flow.py +++ b/flowtorch/distributions/flow.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc -from typing import Any, Dict, Optional, Union, Iterator +from typing import Any, Dict, Iterator, Optional, Union import flowtorch import torch diff --git a/flowtorch/nn/made.py b/flowtorch/nn/made.py index 2de1e007..7180cc31 100644 --- a/flowtorch/nn/made.py +++ b/flowtorch/nn/made.py @@ -110,7 +110,11 @@ class MaskedLinear(nn.Linear): """ def __init__( - self, in_features: int, out_features: int, mask: torch.Tensor, bias: bool = True + self, + in_features: int, + out_features: int, + mask: torch.Tensor, + bias: bool = True, ) -> None: super().__init__(in_features, out_features, bias) self.register_buffer("mask", mask.data) diff --git a/flowtorch/parameters/__init__.py b/flowtorch/parameters/__init__.py index 86f8045c..da167d57 100644 --- a/flowtorch/parameters/__init__.py +++ b/flowtorch/parameters/__init__.py @@ -7,7 +7,17 @@ """ from flowtorch.parameters.base import Parameters +from flowtorch.parameters.conv2d import ZeroConv2d +from flowtorch.parameters.coupling import ConvCoupling +from flowtorch.parameters.coupling import DenseCoupling from flowtorch.parameters.dense_autoregressive import DenseAutoregressive from flowtorch.parameters.tensor import Tensor -__all__ = ["Parameters", "DenseAutoregressive", "Tensor"] +__all__ = [ + "Parameters", + "ZeroConv2d", + "ConvCoupling", + "DenseCoupling", + "DenseAutoregressive", + "Tensor", +] diff --git a/flowtorch/parameters/base.py b/flowtorch/parameters/base.py index 72e4b69f..558bded9 100644 --- a/flowtorch/parameters/base.py +++ b/flowtorch/parameters/base.py @@ -24,15 +24,17 @@ def __init__( def forward( self, - x: Optional[torch.Tensor] = None, + *input: torch.Tensor, + inverse: bool, context: Optional[torch.Tensor] = None, ) -> Optional[Sequence[torch.Tensor]]: # TODO: Caching etc. 
- return self._forward(x, context) + return self._forward(*input, inverse=inverse, context=context) def _forward( self, - x: Optional[torch.Tensor] = None, + *inputs: torch.Tensor, + inverse: bool, context: Optional[torch.Tensor] = None, ) -> Optional[Sequence[torch.Tensor]]: # I raise an exception rather than using @abstractmethod and diff --git a/flowtorch/parameters/conv2d.py b/flowtorch/parameters/conv2d.py new file mode 100644 index 00000000..0f962c3a --- /dev/null +++ b/flowtorch/parameters/conv2d.py @@ -0,0 +1,52 @@ +# Copyright (c) Meta Platforms, Inc +from typing import Sequence, Optional + +import torch +from flowtorch.parameters import Parameters +from torch import nn + + +class ZeroConv2d(Parameters): + autoregressive = False + + def __init__( + self, + param_shapes: Sequence[torch.Size], + input_shape: torch.Size, + context_shape: Optional[torch.Size], + *, + kernel_size: int = 3, + padding: int = 1, + ) -> None: + super().__init__(param_shapes, input_shape, context_shape) + + self.kernel_size = kernel_size + self.channels = self.input_shape[-3] // 2 + self.padding = padding + + self._build() + + def _build( + self, + ) -> None: + self.conv2d = nn.Conv2d( + self.channels, + 2 * self.channels, + kernel_size=self.kernel_size, + padding=self.padding, + ) + for p in self.conv2d.parameters(): + p.data.zero_() + + def _forward( + self, + *input: torch.Tensor, + inverse: bool, + context: Optional[torch.Tensor] = None, + ) -> Optional[Sequence[torch.Tensor]]: + x1, x2_or_y2 = input + return self.conv2d(x1).chunk(2, dim=-3) + + def __repr__(self) -> str: + string = f"{self.__class__.__name__}(conv={self.conv2d})" + return string diff --git a/flowtorch/parameters/coupling.py b/flowtorch/parameters/coupling.py new file mode 100644 index 00000000..330f8839 --- /dev/null +++ b/flowtorch/parameters/coupling.py @@ -0,0 +1,357 @@ +# Copyright (c) Meta Platforms, Inc + +from typing import Callable, Optional, Sequence + +import torch +import torch.nn as nn +from flowtorch.nn.made import MaskedLinear +from flowtorch.parameters.base import Parameters + + +def _make_mask(shape: torch.Size, mask_type: str) -> torch.Tensor: + if mask_type.startswith("neg_"): + return _make_mask(shape, mask_type[4:]) + elif mask_type == "chessboard": + z = torch.zeros(shape, dtype=torch.bool) + z[:, ::2, ::2] = 1 + z[:, 1::2, 1::2] = 1 + return z + elif mask_type == "quadrant": + z = torch.zeros(shape, dtype=torch.bool) + z[:, shape[1] // 2 :, : shape[2] // 2] = 1 + z[:, : shape[1] // 2, shape[2] // 2 :] = 1 + return z + else: + raise NotImplementedError(shape) + + +class DenseCoupling(Parameters): + autoregressive = False + + def __init__( + self, + param_shapes: Sequence[torch.Size], + input_shape: torch.Size, + context_shape: Optional[torch.Size], + *, + hidden_dims: Sequence[int] = (256, 256), + nonlinearity: Callable[[], nn.Module] = nn.ReLU, + permutation: Optional[torch.LongTensor] = None, + skip_connections: bool = False, + ) -> None: + super().__init__(param_shapes, input_shape, context_shape) + + # Check consistency of input_shape with param_shapes + # We need each param_shapes to match input_shape in + # its leftmost dimensions + for s in param_shapes: + assert (len(s) >= len(input_shape)) and ( + s[: len(input_shape)] == input_shape + ) + + self.hidden_dims = hidden_dims + self.nonlinearity = nonlinearity + self.skip_connections = skip_connections + self._build(input_shape, param_shapes, context_shape, permutation) + + def _build( + self, + input_shape: torch.Size, + param_shapes: Sequence[torch.Size], + 
context_shape: Optional[torch.Size], + permutation: Optional[torch.LongTensor], + ) -> None: + + # Work out flattened input and output shapes + param_shapes_ = list(param_shapes) + input_dims = sum(input_shape) + self.input_dims = input_dims + if input_dims == 0: + input_dims = 1 # scalars represented by torch.Size([]) + if permutation is None: + # permutation will define the split of the input + permutation = torch.LongTensor( + torch.randperm(input_dims, device="cpu").to( + torch.LongTensor((1,)).device + ) + ) + else: + # The permutation is chosen by the user + permutation = torch.LongTensor(permutation) + + self.param_dims = [ + int(max(torch.prod(torch.tensor(s[len(input_shape) :])).item(), 1)) + for s in param_shapes_ + ] + + self.output_multiplier = sum(self.param_dims) + + if input_dims == 1: + raise ValueError( + "Coupling input_dim = 1. Coupling transforms require at least " + "two features." + ) + + self.register_buffer("permutation", permutation) + self.register_buffer("inv_permutation", permutation.argsort()) + + # Create masks + hidden_dims = self.hidden_dims + + # Create masked layers: + # input is [x1 ; 0] + # output is [0 ; mu2], [0 ; sig2] + mask_input = torch.ones(hidden_dims[0], input_dims) + self.x1_dim = x1_dim = input_dims // 2 + mask_input[:, x1_dim:] = 0.0 + mask_input = mask_input[:, self.permutation] + + out_dims = input_dims * self.output_multiplier + mask_output = torch.ones( + self.output_multiplier, + input_dims, + hidden_dims[-1], + dtype=torch.bool, + ) + mask_output[:, :x1_dim] = 0.0 + mask_output = mask_output[:, self.permutation] + mask_output_reg = mask_output[0, :, 0] + mask_output = mask_output.view(-1, hidden_dims[-1]) + + self._bias = nn.Parameter( + torch.zeros(self.output_multiplier, x1_dim, requires_grad=True) + ) + + layers = [ + MaskedLinear( + input_dims, # + context_dims, + hidden_dims[0], + mask_input, + ), + self.nonlinearity(), + ] + for i in range(1, len(hidden_dims)): + layers.extend( + [ + nn.Linear(hidden_dims[i - 1], hidden_dims[i]), + self.nonlinearity(), + ] + ) + layers.append( + MaskedLinear( + hidden_dims[-1], + out_dims, + mask_output, + bias=False, + ) + ) + + if self.skip_connections: + self.skip_layer = MaskedLinear( + input_dims, # + context_dims, + out_dims, + mask_output, + bias=False, + ) + + self.layers = nn.Sequential(*layers) + self.register_buffer("mask_output", mask_output_reg.to(torch.bool)) + self._init_weights() + + def _init_weights(self) -> None: + for layer in self.modules(): + if hasattr(layer, "weight"): + layer.weight.data.normal_(0.0, 1e-3) # type: ignore + if hasattr(layer, "bias") and layer.bias is not None: + layer.bias.data.fill_(0.0) # type: ignore + + @property + def bias(self) -> torch.Tensor: + z = torch.zeros( + self.output_multiplier, + self.input_dims - self.x1_dim, + device=self._bias.device, + dtype=self._bias.dtype, + ) + return torch.cat([z, self._bias], -1).view(-1) + + def _forward( + self, + *inputs: torch.Tensor, + inverse: bool, + context: Optional[torch.Tensor] = None, + ) -> Optional[Sequence[torch.Tensor]]: + + input = inputs[0] + input_masked = input.masked_fill(self.mask_output, 0.0) # type: ignore + if context is not None: + input_aug = torch.cat( + [context.expand((*input.shape[:-1], -1)), input_masked], dim=-1 + ) + else: + input_aug = input_masked + + h = self.layers(input_aug) + self.bias + + # TODO: Get skip_layers working again! 
+ if self.skip_connections: + h = h + self.skip_layer(input_aug) + + # Shape the output + h = h.view(*input.shape[:-1], self.output_multiplier, -1) + + result = h.unbind(-2) + result = tuple( + r.masked_fill(~self.mask_output.expand_as(r), 0.0) # type: ignore + for r in result # type: ignore + ) + return result + + +class ConvCoupling(Parameters): + autoregressive = False + _mask_types = [ + "chessboard", + "quadrants", + "inv_chessboard", + "inv_quadrants", + ] + + def __init__( + self, + param_shapes: Sequence[torch.Size], + input_shape: torch.Size, + context_shape: Optional[torch.Size], + *, + cnn_activate_input: bool = True, + cnn_channels: int = 256, + cnn_kernel: Sequence[int] = None, + cnn_padding: Sequence[int] = None, + cnn_stride: Sequence[int] = None, + nonlinearity: Callable[[], nn.Module] = nn.ReLU, + skip_connections: bool = False, + mask_type: str = "chessboard", + ) -> None: + super().__init__(param_shapes, input_shape, context_shape) + + # Check consistency of input_shape with param_shapes + # We need each param_shapes to match input_shape in + # its leftmost dimensions + for s in param_shapes: + assert (len(s) >= len(input_shape)) and ( + s[: len(input_shape)] == input_shape + ) + + if cnn_kernel is None: + cnn_kernel = [3, 1, 3] + if cnn_padding is None: + cnn_padding = [1, 0, 1] + if cnn_stride is None: + cnn_stride = [1, 1, 1] + + self.cnn_channels = cnn_channels + self.cnn_activate_input = cnn_activate_input + self.cnn_kernel = cnn_kernel + self.cnn_padding = cnn_padding + self.cnn_stride = cnn_stride + + self.nonlinearity = nonlinearity + self.skip_connections = skip_connections + self._build(input_shape, param_shapes, context_shape, mask_type) + + def _build( + self, + input_shape: torch.Size, # something like [C, W, H] + param_shapes: Sequence[torch.Size], # [[C, W, H], [C, W, H]] + context_shape: Optional[torch.Size], + mask_type: str, + ) -> None: + + mask = _make_mask(input_shape, mask_type) + self.register_buffer("mask", mask) + self.output_multiplier = len(param_shapes) + + out_channels, width, height = input_shape + + layers = [] + if self.cnn_activate_input: + layers.append(self.nonlinearity()) + layers.append( + nn.LazyConv2d( + out_channels=self.cnn_channels, + kernel_size=self.cnn_kernel[0], + padding=self.cnn_padding[0], + stride=self.cnn_stride[0], + ) + ) + layers.append(self.nonlinearity()) + layers.append( + nn.Conv2d( + in_channels=self.cnn_channels, + out_channels=self.cnn_channels, + kernel_size=self.cnn_kernel[1], + padding=self.cnn_padding[1], + stride=self.cnn_stride[1], + ) + ) + layers.append(self.nonlinearity()) + layers.append( + nn.Conv2d( + in_channels=self.cnn_channels, + out_channels=out_channels * self.output_multiplier, + kernel_size=self.cnn_kernel[2], + padding=self.cnn_padding[2], + stride=self.cnn_stride[2], + ) + ) + + self.layers = nn.Sequential(*layers) + self._init_weights() + + def _init_weights(self) -> None: + for layer in self.modules(): + if hasattr(layer, "weight"): + layer.weight.data.normal_(0.0, 1e-3) # type: ignore + if hasattr(layer, "bias") and layer.bias is not None: + layer.bias.data.fill_(0.0) # type: ignore + + def _forward( + self, + *inputs: torch.Tensor, + inverse: bool, + context: Optional[torch.Tensor] = None, + ) -> Optional[Sequence[torch.Tensor]]: + + input = inputs[0] + unsqueeze = False + if input.ndimension() == 3: + # mostly for initialization + unsqueeze = True + input = input.unsqueeze(0) + + input_masked = input.masked_fill(self.mask, 0.0) # type: ignore + if context is not None: + context_shape 
= [shape for shape in input_masked.shape] + context_shape[-3] = context.shape[-3] + input_aug = torch.cat( + [context.expand(*context_shape), input_masked], dim=-1 + ) + else: + input_aug = input_masked + + h = self.layers(input_aug) + + if self.skip_connections: + h = h + input_masked + + # Shape the output + + if unsqueeze: + h = h.squeeze(0) + result = h.chunk(2, -3) + + result = tuple( + r.masked_fill(~self.mask.expand_as(r), 0.0) for r in result # type: ignore + ) + + return result diff --git a/flowtorch/parameters/dense_autoregressive.py b/flowtorch/parameters/dense_autoregressive.py index 8110e5a6..b400c426 100644 --- a/flowtorch/parameters/dense_autoregressive.py +++ b/flowtorch/parameters/dense_autoregressive.py @@ -45,6 +45,7 @@ def _build( ) -> None: # Work out flattened input and output shapes param_shapes_ = list(param_shapes) + # Why not just (sum(input_shape))? input_dims = int(torch.sum(torch.tensor(input_shape)).int().item()) if input_dims == 0: input_dims = 1 # scalars represented by torch.Size([]) @@ -60,6 +61,7 @@ def _build( # The permutation is chosen by the user permutation = torch.LongTensor(permutation) + # why not math.pod(s[len(input_shape):]), where math.prod([])=1? self.param_dims = [ int(max(torch.prod(torch.tensor(s[len(input_shape) :])).item(), 1)) for s in param_shapes_ @@ -141,33 +143,36 @@ def _build( ) ) + # Why not using regular sequential? self.layers = nn.ModuleList(layers) def _forward( self, - x: Optional[torch.Tensor] = None, + *inputs: torch.Tensor, + inverse: bool, context: Optional[torch.Tensor] = None, ) -> Optional[Sequence[torch.Tensor]]: - assert x is not None - # Flatten x - batch_shape = x.shape[: len(x.shape) - len(self.input_shape)] + # Flatten input + input = inputs[0] + batch_shape = input.shape[: len(input.shape) - len(self.input_shape)] if len(batch_shape) > 0: - x = x.reshape(batch_shape + (-1,)) + input = input.reshape(batch_shape + (-1,)) if context is not None: # TODO: Fix the following! - h = torch.cat([context.expand((x.shape[0], -1)), x], dim=-1) + h = torch.cat([context.expand((input.shape[0], -1)), input], dim=-1) else: - h = x + h = input + # Why not using regular sequential? for idx in range(len(self.layers) // 2): h = self.layers[2 * idx + 1](self.layers[2 * idx](h)) h = self.layers[-1](h) # TODO: Get skip_layers working again! 
# if self.skip_layer is not None: - # h = h + self.skip_layer(x) + # h = h + self.skip_layer(input) # Shape the output # h ~ (batch_dims * input_dims, total_params_per_dim) diff --git a/flowtorch/parameters/tensor.py b/flowtorch/parameters/tensor.py index 3de8680a..ce497056 100644 --- a/flowtorch/parameters/tensor.py +++ b/flowtorch/parameters/tensor.py @@ -22,6 +22,9 @@ def __init__( ) def _forward( - self, x: Optional[torch.Tensor] = None, context: Optional[torch.Tensor] = None + self, + *input: torch.Tensor, + inverse: bool, + context: Optional[torch.Tensor] = None, ) -> Optional[Sequence[torch.Tensor]]: return list(self.params) diff --git a/scripts/copyright_headers.py b/scripts/copyright_headers.py index 670b8674..dd755379 100644 --- a/scripts/copyright_headers.py +++ b/scripts/copyright_headers.py @@ -108,7 +108,10 @@ def print_results(count_changed, args): help="just checks files and does not change any", ) parser.add_argument( - "-v", "--verbose", action="store_true", help="prints extra information on files" + "-v", + "--verbose", + action="store_true", + help="prints extra information on files", ) parser.add_argument( "paths", nargs="+", help="paths to search for Python source files" diff --git a/scripts/generate_api_docs.py b/scripts/generate_api_docs.py index a6a8fd9a..37912a64 100644 --- a/scripts/generate_api_docs.py +++ b/scripts/generate_api_docs.py @@ -180,7 +180,10 @@ def search_symbols(config): # Construct regular expressions for includes and excludes # Default include/exclude rules patterns = { - "include": {"modules": re.compile(r".+"), "symbols": re.compile(r".+")}, + "include": { + "modules": re.compile(r".+"), + "symbols": re.compile(r".+"), + }, "exclude": {"modules": re.compile(r""), "symbols": re.compile(r"")}, } @@ -315,11 +318,14 @@ def create_paths(path: str) -> None: with open( os.path.join( os.path.join( - main_path, config["paths"]["markdown"], article_name + ".mdx" + main_path, + config["paths"]["markdown"], + article_name + ".mdx", ) ), "w", ) as file: print( - generate_markdown(article_name, symbol_name, symbol_object), file=file + generate_markdown(article_name, symbol_name, symbol_object), + file=file, ) diff --git a/tests/test_bijectivetensor.py b/tests/test_bijectivetensor.py index 72bbdf70..fa340f57 100644 --- a/tests/test_bijectivetensor.py +++ b/tests/test_bijectivetensor.py @@ -15,7 +15,6 @@ def get_net() -> AffineAutoregressive: [ AffineAutoregressive(params.DenseAutoregressive()), AffineAutoregressive(params.DenseAutoregressive()), - AffineAutoregressive(params.DenseAutoregressive()), ] ) ar = ar( diff --git a/tests/test_bijector.py b/tests/test_bijector.py index e9344ef6..eec9f949 100644 --- a/tests/test_bijector.py +++ b/tests/test_bijector.py @@ -1,4 +1,5 @@ # Copyright (c) Meta Platforms, Inc +import math import warnings import flowtorch.bijectors as bijectors @@ -21,11 +22,13 @@ def test_bijector_constructor(): @pytest.fixture(params=[bij_name for _, bij_name in bijectors.standard_bijectors]) def flow(request): + torch.set_default_dtype(torch.double) bij = request.param event_dim = max(bij.domain.event_dim, 1) event_shape = event_dim * [3] base_dist = dist.Independent( - dist.Normal(torch.zeros(event_shape), torch.ones(event_shape)), event_dim + dist.Normal(torch.zeros(event_shape), torch.ones(event_shape)), + event_dim, ) flow = Flow(base_dist, bij) @@ -41,10 +44,12 @@ def test_jacobian(flow, epsilon=1e-2): x = torch.randn(*flow.event_shape) x = torch.distributions.transform_to(bij.domain)(x) y = bij.forward(x) - if bij.domain.event_dim == 
1: - analytic_ldt = bij.log_abs_det_jacobian(x, y).data + if bij.domain.event_dim == 0: + analytic_ldt = bij.log_abs_det_jacobian(x, y).data.sum(-1) else: - analytic_ldt = bij.log_abs_det_jacobian(x, y).sum(-1).data + analytic_ldt = bij.log_abs_det_jacobian(x, y).data + for _ in range(bij.domain.event_dim - 1): + analytic_ldt = analytic_ldt.sum(-1) # Calculate numerical Jacobian # TODO: Better way to get all indices of array/tensor? @@ -86,7 +91,8 @@ def test_jacobian(flow, epsilon=1e-2): if hasattr(params, "permutation"): numeric_ldt = torch.sum(torch.log(torch.diag(jacobian))) else: - numeric_ldt = torch.log(torch.abs(jacobian.det())) + jacobian = jacobian.view(int(math.sqrt(jacobian.numel())), -1) + numeric_ldt = torch.log(torch.abs(jacobian.det())).sum() ldt_discrepancy = (analytic_ldt - numeric_ldt).abs() assert ldt_discrepancy < epsilon @@ -109,6 +115,7 @@ def test_inverse(flow, epsilon=1e-5): # Test g^{-1}(g(x)) = x x_true = base_dist.sample(torch.Size([10])) + assert x_true.dtype is torch.double x_true = torch.distributions.transform_to(bij.domain)(x_true) y = bij.forward(x_true) diff --git a/tests/test_distribution.py b/tests/test_distribution.py index db7c9095..25c065a5 100644 --- a/tests/test_distribution.py +++ b/tests/test_distribution.py @@ -15,7 +15,8 @@ def test_tdist_standalone(): def make_tdist(): # train a flow here base_dist = torch.distributions.Independent( - torch.distributions.Normal(torch.zeros(input_dim), torch.ones(input_dim)), 1 + torch.distributions.Normal(torch.zeros(input_dim), torch.ones(input_dim)), + 1, ) bijector = bijs.AffineAutoregressive() tdist = dist.Flow(base_dist, bijector) @@ -37,9 +38,9 @@ def test_neals_funnel_vi(): flow = dist.Flow(base_dist, bijector) bijector = flow.bijector - opt = torch.optim.Adam(flow.parameters(), lr=2e-3) + opt = torch.optim.Adam(flow.parameters(), lr=1e-2) num_elbo_mc_samples = 200 - for _ in range(100): + for _ in range(500): z0 = flow.base_dist.rsample(sample_shape=(num_elbo_mc_samples,)) zk = bijector.forward(z0) ldj = zk._log_detJ
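
For context, the round-trip below sketches how the coupling bijectors added in this patch are meant to be used, following the CouplingBijector/ConvCouplingBijector docstrings and the lazy-construction pattern from the tests above; the module aliases (bijs, params), batch sizes, and the roundtrip check are illustrative assumptions rather than part of the patch itself.

# Sketch only -- mirrors the docstring examples from flowtorch/bijectors/coupling.py.
import torch
import flowtorch.bijectors as bijs
import flowtorch.parameters as params

# Dense coupling over a 32-dimensional event.
bij = bijs.CouplingBijector(params.DenseCoupling())
bij = bij(shape=torch.Size([32]))
for p in bij.parameters():
    p.data += torch.randn_like(p) / 10  # move off the near-identity initialization
x = torch.randn(4, 32)
y = bij.forward(x)  # a BijectiveTensor while flow-graph recording is enabled
x_bis = bij.inverse(y.detach_from_flow())
torch.testing.assert_allclose(x, x_bis)

# Convolutional coupling expects a 3d event shape (C, H, W).
conv_bij = bijs.ConvCouplingBijector(params.ConvCoupling())
conv_bij = conv_bij(shape=torch.Size([3, 16, 16]))
img = torch.randn(4, 3, 16, 16)
out = conv_bij.forward(img)
img_bis = conv_bij.inverse(out.detach_from_flow())
torch.testing.assert_allclose(img, img_bis)

Both classes route through the updated Affine op, so the scale is produced by the configurable positive_map (softplus by default) rather than exp, and clamping of the pre-activation scale is now opt-in via clamp_values.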