
Commit

refactor multinormal for better masking
Signed-off-by: nstarman <[email protected]>
nstarman committed Aug 11, 2023
1 parent 78c5549 commit 9fc9c55
Showing 3 changed files with 78 additions and 92 deletions.
6 changes: 1 addition & 5 deletions src/stream_ml/pytorch/builtin/__init__.py
@@ -20,7 +20,6 @@
"Parallax2DistMod",
# -- multivariate
"MultivariateNormal",
"MultivariateMissingNormal",
]

from dataclasses import field, make_dataclass
@@ -42,10 +41,7 @@
StreamMassFunction,
UniformStreamMassFunction,
)
from stream_ml.pytorch.builtin._multinormal import (
MultivariateMissingNormal,
MultivariateNormal,
)
from stream_ml.pytorch.builtin._multinormal import MultivariateNormal
from stream_ml.pytorch.builtin._skewnorm import SkewNormal
from stream_ml.pytorch.builtin._sloped import Sloped
from stream_ml.pytorch.builtin._truncskewnorm import TruncatedSkewNormal
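Net effect of this hunk: ``MultivariateMissingNormal`` is removed from the public API, and handling of missing data is folded into ``MultivariateNormal`` itself through the ``where`` keyword added in the ``_multinormal.py`` diff below.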
1 change: 0 additions & 1 deletion src/stream_ml/pytorch/builtin/_isochrone/core.py
@@ -336,7 +336,6 @@ def ln_likelihood(
- mean[sel]
)

lnliks = xp.zeros((len(data), len(self._gamma_points))) # (N, I)
lnliks = -0.5 * ( # (N, I, 1, 1) -> (N, I)
D[:, None] * _log2pi
+ logdet
163 changes: 77 additions & 86 deletions src/stream_ml/pytorch/builtin/_multinormal.py
@@ -4,11 +4,12 @@

__all__: list[str] = []

from dataclasses import KW_ONLY, dataclass
from dataclasses import dataclass
from typing import TYPE_CHECKING

import torch as xp
from torch.distributions import MultivariateNormal as TorchMultivariateNormal

from stream_ml.core.builtin._utils import WhereRequiredError

from stream_ml.pytorch._base import ModelBase

@@ -34,7 +35,9 @@ def ln_likelihood(
/,
data: Data[Array],
*,
where: Data[Array] | None = None,
correlation_matrix: Array | None = None,
correlation_det: Array | None = None,
**kwargs: Array,
) -> Array:
r"""Log-likelihood of the distribution.
@@ -47,16 +50,26 @@
data : Data[Array]
Data (phi1, phi2, ...).
where : Data[Array[(N,), bool]] | None, optional keyword-only
Where to evaluate the log-likelihood. If not provided, then the
log-likelihood is evaluated at all data points. ``where`` must
contain the fields in ``coord_names``. Each field must be a boolean
array of the same length as ``data``: `True` indicates that the data
point is available, and `False` indicates that it is not.
correlation_matrix : Array[(N,F,F)] | None, optional keyword-only
The correlation matrix. If not provided, then the covariance matrix
is assumed to be diagonal. The covariance matrix is computed as:
.. math::
\rm{cov}(X) = \rm{diag}(\vec{\sigma}) \cdot \rm{corr} \cdot \rm{diag}(\vec{\sigma})
correlation_det : Array[(N,)] | None, optional keyword-only
The determinant of the correlation matrix. If not provided, the
correlation matrix is assumed to be the identity, so the covariance
determinant reduces to the product of the diagonal elements of the
covariance matrix.
**kwargs : Array
Additional arguments.
@@ -65,88 +78,66 @@
-------
Array
"""
marginals = xp.diag_embed(
self.xp.exp(self._stack_param(mpars, "ln-sigma", self.coord_names))
# 'where' is used to indicate which data points are available. If
# 'where' is not provided, then all data points are assumed to be
# available.
where_: Array # (N, F)
if where is not None:
where_ = where[self.coord_names].array
elif self.require_where:
raise WhereRequiredError
else:
where_ = self.xp.ones((len(data), self.nF), dtype=bool)

if correlation_matrix is not None and correlation_det is None:
msg = "Must provide `correlation_det`."
raise ValueError(msg)

# Covariance: model (N, F, F)
lnsigma = self._stack_param(mpars, "ln-sigma", self.coord_names)
cov_model = xp.diag_embed(xp.exp(2 * lnsigma))
# Covariance data "The covariance matrix can be written as the rescaling
# of a correlation matrix by the marginal variances:"
# (https://en.wikipedia.org/wiki/Covariance_matrix#Correlation_matrix)
std_data = (
xp.diag_embed(data[self.coord_err_names].array)
if self.coord_err_names is not None
else self.xp.zeros(1)
)
cov = (
marginals @ marginals
cov_data = (
std_data**2
if correlation_matrix is None
else marginals @ correlation_matrix @ marginals
else std_data @ correlation_matrix[:, :, :] @ std_data
)

return TorchMultivariateNormal(
self._stack_param(mpars, "mu", self.coord_names),
covariance_matrix=cov,
).log_prob(data[self.coord_names].array)


##############################################################################


@dataclass(unsafe_hash=True)
class MultivariateMissingNormal(MultivariateNormal): # (MultivariateNormal)
"""Multivariate Normal with missing data.
.. note::
Currently this requires a diagonal covariance matrix.
"""

_: KW_ONLY
require_mask: bool = True

def ln_likelihood(
self,
mpars: Params[Array],
/,
data: Data[Array],
*,
mask: Data[Array] | None = None,
**kwargs: Array,
) -> Array:
"""Negative log-likelihood.
Parameters
----------
mpars : Params[Array], positional-only
Model parameters. Note that these are different from the ML
parameters.
data : Data[Array]
Labelled data.
mask : Data[Array[bool]] | None, optional
Data availability. `True` if data is available, `False` if not.
Should have the same keys as `data`.
**kwargs : Array
Additional arguments.
"""
# Normal
x = data[self.coord_names].array
# The covariance, setting non-observed dimensions to 0. (N, F, F)
# positive definite.
idx_cov = xp.diag_embed(where_.to(dtype=data.dtype)) # (N, F, F)
cov = idx_cov @ (cov_data + cov_model) @ idx_cov
# The determinant, dropping the dimensionality of non-observed
# dimensions.
logdet = xp.log(
xp.linalg.det(cov + (xp.eye(self.nF)[None, None] - idx_cov))
) # (N, [I])

# Dimensionality, dropping missing dimensions (N, [I])
D = where_.sum(dim=-1) # noqa: N806

# Construct the data - mean (N, I, F), setting non-observed dimensions to 0.
mu = self._stack_param(mpars, "mu", self.coord_names)
sigma = self.xp.exp(self._stack_param(mpars, "ln-sigma", self.coord_names))

idx: Array
if mask is not None:
idx = mask[tuple(self.coord_bounds.keys())].array
elif self.require_mask:
msg = "mask is required"
raise ValueError(msg)
else:
idx = xp.ones_like(x, dtype=xp.int)
# shape (1, F) so that it can broadcast with (N, F)

D = idx.sum(dim=1) # Dimensionality (N,) # noqa: N806
dmm = idx * (x - mu) # Data - model (N, F)

# Covariance related
cov = idx * sigma**2 # (N, F) positive definite
det = (cov * idx + (1 - idx)).prod(dim=1) # (N,)
sel = where_[:, None, :].expand(-1, self.nI, -1)
x = xp.zeros((len(data), self.nI, self.nF), dtype=data.dtype)
x[sel] = (
data[self.coord_names].array[:, None, :].expand(-1, self.nI, -1)[sel]
- mu[sel]
)

return -0.5 * (
return -0.5 * ( # (N, I, 1, 1) -> (N, I)
D * _log2pi
+ xp.log(det)
+ logdet
+ (
dmm[:, None, :] # (N, 1, F)
@ xp.linalg.pinv(xp.diag_embed(cov)) # (N, F, F)
@ dmm[..., None] # (N, F, 1)
).flatten() # (N, 1, 1) -> (N,)
) # (N,)
x[:, None, :] # (N, 1, F)
@ xp.linalg.pinv(cov) # (N, F, F)
@ x[..., None] # (N, F, 1)
)[..., 0, 0]
)
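For readers skimming the interleaved old/new lines above, here is a compact, self-contained sketch of the masking strategy the refactored ``ln_likelihood`` uses, written against plain ``torch``. The sizes, the diagonal covariance, and all variable names are illustrative only, and the extra component dimension ``I`` handled by the committed code is omitted: unobserved coordinates are projected out of the covariance, the determinant is padded with ones on the masked diagonal, and the dimensionality term counts only the observed coordinates.

import math

import torch as xp

_log2pi = math.log(2 * math.pi)

N, F = 5, 3  # illustrative: 5 data points, 3 coordinates
x = xp.randn(N, F)  # data
mu = xp.zeros(N, F)  # model means
cov = xp.diag_embed(xp.ones(N, F))  # model + data covariance, kept diagonal here
where = xp.rand(N, F) > 0.3  # True = coordinate observed

# Project the covariance onto the observed coordinates: masked rows/columns -> 0.
idx_cov = xp.diag_embed(where.to(dtype=x.dtype))  # (N, F, F)
cov_obs = idx_cov @ cov @ idx_cov  # (N, F, F)

# Pad the masked diagonal with 1s so the determinant is that of the observed block.
logdet = xp.log(xp.linalg.det(cov_obs + (xp.eye(F) - idx_cov)))  # (N,)

D = where.sum(dim=-1)  # (N,) number of observed coordinates per point
dx = where * (x - mu)  # residuals, zeroed on masked coordinates

# Quadratic form via the pseudo-inverse; masked coordinates contribute nothing.
quad = (dx[:, None, :] @ xp.linalg.pinv(cov_obs) @ dx[..., None])[..., 0, 0]  # (N,)
lnlik = -0.5 * (D * _log2pi + logdet + quad)  # (N,) per-point log-likelihood

Padding with the identity on the masked block keeps the determinant equal to that of the observed sub-matrix, and the pseudo-inverse of the projected covariance reproduces the inverse of that sub-matrix, so the result matches evaluating a lower-dimensional normal on only the observed coordinates of each row.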
