From af6bb6baf4f523ce3c0792e7f7a98ff17d94fe74 Mon Sep 17 00:00:00 2001 From: nstarman Date: Wed, 26 Jul 2023 12:19:13 -0400 Subject: [PATCH 1/2] Separate isochrone into private module Signed-off-by: nstarman --- .../pytorch/builtin/_isochrone/__init__.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/stream_ml/pytorch/builtin/_isochrone/__init__.py b/src/stream_ml/pytorch/builtin/_isochrone/__init__.py index 485f028..32634fe 100644 --- a/src/stream_ml/pytorch/builtin/_isochrone/__init__.py +++ b/src/stream_ml/pytorch/builtin/_isochrone/__init__.py @@ -2,15 +2,6 @@ from __future__ import annotations -from stream_ml.pytorch.builtin._isochrone.core import IsochroneMVNorm -from stream_ml.pytorch.builtin._isochrone.mf import ( - HardCutoffMassFunction, - StepwiseMassFunction, - StreamMassFunction, - UniformStreamMassFunction, -) -from stream_ml.pytorch.builtin._isochrone.utils import Parallax2DistMod - __all__: list[str] = [ # Mass Function "StreamMassFunction", @@ -22,3 +13,12 @@ # Utils "Parallax2DistMod", ] + +from stream_ml.pytorch.builtin._isochrone.core import IsochroneMVNorm +from stream_ml.pytorch.builtin._isochrone.mf import ( + HardCutoffMassFunction, + StepwiseMassFunction, + StreamMassFunction, + UniformStreamMassFunction, +) +from stream_ml.pytorch.builtin._isochrone.utils import Parallax2DistMod From e020385d95aa3a35b684155c06093dfe79bee21c Mon Sep 17 00:00:00 2001 From: nstarman Date: Wed, 26 Jul 2023 12:41:19 -0400 Subject: [PATCH 2/2] Kroupa IMF Signed-off-by: nstarman --- pyproject.toml | 2 + src/stream_ml/pytorch/builtin/__init__.py | 50 ++++------- .../pytorch/builtin/_isochrone/__init__.py | 2 + .../pytorch/builtin/_isochrone/mf.py | 88 ++++++++++++++++--- 4 files changed, 100 insertions(+), 42 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d6dd048..3014695 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ "tqdm", "nflows", "zuko", + "torchcubicspline @ git+https://github.com/patrick-kidger/torchcubicspline.git", ] test = [ "coverage[toml]", @@ -136,6 +137,7 @@ where = ["src"] "asdf.*", "scipy.*", "torch.*", + "torchcubicspline.*", ] ignore_missing_imports = true diff --git a/src/stream_ml/pytorch/builtin/__init__.py b/src/stream_ml/pytorch/builtin/__init__.py index 1757bc4..7c4c198 100644 --- a/src/stream_ml/pytorch/builtin/__init__.py +++ b/src/stream_ml/pytorch/builtin/__init__.py @@ -1,27 +1,5 @@ """Stream models.""" -__all__ = [ - # modules - "compat", - # classes - "Uniform", - "Sloped", - "Exponential", - "Normal", - "TruncatedNormal", - "SkewNormal", - "TruncatedSkewNormal", - # -- isochrone - "IsochroneMVNorm", - "StreamMassFunction", - "UniformStreamMassFunction", - "HardCutoffMassFunction", - "StepwiseMassFunction", - "Parallax2DistMod", - # -- multivariate - "MultivariateNormal", -] - from dataclasses import field, make_dataclass import torch as xp @@ -32,21 +10,31 @@ from stream_ml.core.builtin._uniform import Uniform as CoreUniform from stream_ml.pytorch._base import ModelBase -from stream_ml.pytorch.builtin import compat -from stream_ml.pytorch.builtin._isochrone import ( - HardCutoffMassFunction, - IsochroneMVNorm, - Parallax2DistMod, - StepwiseMassFunction, - StreamMassFunction, - UniformStreamMassFunction, -) +from stream_ml.pytorch.builtin import _isochrone, compat +from stream_ml.pytorch.builtin._isochrone import * # noqa: F403 from stream_ml.pytorch.builtin._multinormal import MultivariateNormal from stream_ml.pytorch.builtin._skewnorm import SkewNormal from stream_ml.pytorch.builtin._sloped import Sloped from stream_ml.pytorch.builtin._truncskewnorm import TruncatedSkewNormal from stream_ml.pytorch.typing import Array, ArrayNamespace, NNModel +__all__ = [ + # modules + "compat", + # classes + "Uniform", + "Sloped", + "Exponential", + "Normal", + "TruncatedNormal", + "SkewNormal", + "TruncatedSkewNormal", + # -- multivariate + "MultivariateNormal", +] +__all__ += _isochrone.__all__ + + Normal = make_dataclass( "Normal", [("array_namespace", ArrayNamespace[Array], field(default=xp, kw_only=True))], diff --git a/src/stream_ml/pytorch/builtin/_isochrone/__init__.py b/src/stream_ml/pytorch/builtin/_isochrone/__init__.py index 32634fe..5d35fec 100644 --- a/src/stream_ml/pytorch/builtin/_isochrone/__init__.py +++ b/src/stream_ml/pytorch/builtin/_isochrone/__init__.py @@ -8,6 +8,7 @@ "UniformStreamMassFunction", "HardCutoffMassFunction", "StepwiseMassFunction", + "KroupaIMF", # Core "IsochroneMVNorm", # Utils @@ -17,6 +18,7 @@ from stream_ml.pytorch.builtin._isochrone.core import IsochroneMVNorm from stream_ml.pytorch.builtin._isochrone.mf import ( HardCutoffMassFunction, + KroupaIMF, StepwiseMassFunction, StreamMassFunction, UniformStreamMassFunction, diff --git a/src/stream_ml/pytorch/builtin/_isochrone/mf.py b/src/stream_ml/pytorch/builtin/_isochrone/mf.py index 4d43832..8c49cd4 100644 --- a/src/stream_ml/pytorch/builtin/_isochrone/mf.py +++ b/src/stream_ml/pytorch/builtin/_isochrone/mf.py @@ -5,15 +5,16 @@ __all__: list[str] = [] from dataclasses import dataclass -from typing import TYPE_CHECKING, Protocol +from math import log +from typing import TYPE_CHECKING, ClassVar, Protocol + +import torch as xp +from torchcubicspline import NaturalCubicSpline, natural_cubic_spline_coeffs if TYPE_CHECKING: from stream_ml.pytorch import Data from stream_ml.pytorch.typing import Array, ArrayNamespace -# ============================================================================= -# Cluster Mass Function - class StreamMassFunction(Protocol): """Stream Mass Function. @@ -26,7 +27,7 @@ class StreamMassFunction(Protocol): """ def __call__( - self, gamma: Array, x: Data[Array], *, xp: ArrayNamespace[Array] + self, gamma: Array, x: Data[Array] | None, *, xp: ArrayNamespace[Array] ) -> Array: r"""Log-probability of stars at position 'x' having mass 'gamma'. @@ -51,9 +52,9 @@ def __call__( @dataclass(frozen=True) class UniformStreamMassFunction(StreamMassFunction): def __call__( - self, gamma: Array, x: Data[Array], *, xp: ArrayNamespace[Array] + self, gamma: Array, x: Data[Array] | None, *, xp: ArrayNamespace[Array] ) -> Array: - return xp.zeros((len(x), len(gamma))) + return xp.zeros((1 if x is None else len(x), len(gamma))) @dataclass(frozen=True) @@ -64,9 +65,9 @@ class HardCutoffMassFunction(StreamMassFunction): upper: float = 1 def __call__( - self, gamma: Array, x: Data[Array], *, xp: ArrayNamespace[Array] + self, gamma: Array, x: Data[Array] | None, *, xp: ArrayNamespace[Array] ) -> Array: - out = xp.full((len(x), len(gamma)), -xp.inf) + out = xp.full((1 if x is None else len(x), len(gamma)), -xp.inf) out[:, (gamma >= self.lower) & (gamma <= self.upper)] = 0 return out @@ -79,11 +80,76 @@ class StepwiseMassFunction(StreamMassFunction): log_probs: tuple[float, ...] # (B,) def __call__( - self, gamma: Array, x: Data[Array], *, xp: ArrayNamespace[Array] + self, gamma: Array, x: Data[Array] | None, *, xp: ArrayNamespace[Array] ) -> Array: - out = xp.full((len(x), len(gamma)), -xp.inf) + out = xp.full((1 if x is None else len(x), len(gamma)), -xp.inf) for lower, upper, lnp in zip( self.boundaries[:-1], self.boundaries[1:], self.log_probs, strict=True ): out[:, (gamma >= lower) & (gamma < upper)] = lnp return out + + +# -------------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class KroupaIMF(StreamMassFunction): + """Kroupa IMF. + + https://arxiv.org/pdf/astro-ph/0102155.pdf + + Parameters + ---------- + gamma_to_mass : :class:`torchcubicspline.NaturalCubicSpline` + A callable that takes in gamma and returns the mass. + """ + + gamma_to_mass: NaturalCubicSpline + + ranges: ClassVar[tuple[float, ...]] = (0.01, 0.08, 0.5, 100) + ln_ranges: ClassVar[tuple[float, ...]] = (1e-2, 8e-2, 5e-1, 1e2) + + def __post_init__(self) -> None: + # TODO: need to normalize to gamma? + # Compute the normalization by integrating over the mass range. + self.ln_norm: float + object.__setattr__(self, "ln_norm", 0) # to avoid missing self-reference. + gammas = xp.linspace(0, 1, 10_000) + masses = self.gamma_to_mass.evaluate(gammas)[:, 0].to(dtype=gammas.dtype) + lnpdfs = self(gammas, None, xp=xp) + norm = float(xp.sum(xp.exp(lnpdfs)[:-1] * xp.diff(masses))) + object.__setattr__(self, "ln_norm", log(norm)) + + def __call__( + self, gamma: Array, x: Data[Array] | None, *, xp: ArrayNamespace[Array] + ) -> Array: + out = xp.empty_like(gamma) + mass = self.gamma_to_mass.evaluate(gamma)[:, 0].to(dtype=gamma.dtype) + rng = xp.asarray(self.ranges, dtype=gamma.dtype) + ln_rng = xp.asarray(self.ln_ranges, dtype=gamma.dtype) + + # https://arxiv.org/pdf/astro-ph/0102155.pdf + if xp.any((mass < rng[0]) | (mass >= rng[-1])): + msg = f"mass must be >= {rng[0]}." + raise ValueError(msg) + + sel = (mass >= rng[0]) & (mass < rng[1]) + out[sel] = -0.3 * (xp.log(mass[sel]) - ln_rng[1]) + + sel = (mass >= rng[1]) & (mass < rng[2]) + out[sel] = -1.3 * (xp.log(mass[sel]) - ln_rng[1]) + + sel = (mass >= rng[2]) & (mass < rng[3]) + out[sel] = ln_rng[2] + 1.3 * ln_rng[1] - 2.3 * xp.log(mass[sel]) + + return out - self.ln_norm + + # =============================================================== + # Constructors + + @classmethod + def from_arrays(cls, gamma: Array, mass: Array) -> KroupaIMF: + """Construct from arrays.""" + coeffs = natural_cubic_spline_coeffs(gamma, mass[:, None]) + return cls(gamma_to_mass=NaturalCubicSpline(coeffs))