Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Isochrone kroupa #128

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"tqdm",
"nflows",
"zuko",
"torchcubicspline @ git+https://github.com/patrick-kidger/torchcubicspline.git",
]
test = [
"coverage[toml]",
Expand Down Expand Up @@ -136,6 +137,7 @@ where = ["src"]
"asdf.*",
"scipy.*",
"torch.*",
"torchcubicspline.*",
]
ignore_missing_imports = true

Expand Down
50 changes: 19 additions & 31 deletions src/stream_ml/pytorch/builtin/__init__.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,5 @@
"""Stream models."""

__all__ = [
# modules
"compat",
# classes
"Uniform",
"Sloped",
"Exponential",
"Normal",
"TruncatedNormal",
"SkewNormal",
"TruncatedSkewNormal",
# -- isochrone
"IsochroneMVNorm",
"StreamMassFunction",
"UniformStreamMassFunction",
"HardCutoffMassFunction",
"StepwiseMassFunction",
"Parallax2DistMod",
# -- multivariate
"MultivariateNormal",
]

from dataclasses import field, make_dataclass

import torch as xp
Expand All @@ -32,21 +10,31 @@
from stream_ml.core.builtin._uniform import Uniform as CoreUniform

from stream_ml.pytorch._base import ModelBase
from stream_ml.pytorch.builtin import compat
from stream_ml.pytorch.builtin._isochrone import (
HardCutoffMassFunction,
IsochroneMVNorm,
Parallax2DistMod,
StepwiseMassFunction,
StreamMassFunction,
UniformStreamMassFunction,
)
from stream_ml.pytorch.builtin import _isochrone, compat
from stream_ml.pytorch.builtin._isochrone import * # noqa: F403
from stream_ml.pytorch.builtin._multinormal import MultivariateNormal
from stream_ml.pytorch.builtin._skewnorm import SkewNormal
from stream_ml.pytorch.builtin._sloped import Sloped
from stream_ml.pytorch.builtin._truncskewnorm import TruncatedSkewNormal
from stream_ml.pytorch.typing import Array, ArrayNamespace, NNModel

__all__ = [
# modules
"compat",
# classes
"Uniform",
"Sloped",
"Exponential",
"Normal",
"TruncatedNormal",
"SkewNormal",
"TruncatedSkewNormal",
# -- multivariate
"MultivariateNormal",
]
__all__ += _isochrone.__all__


Normal = make_dataclass(
"Normal",
[("array_namespace", ArrayNamespace[Array], field(default=xp, kw_only=True))],
Expand Down
20 changes: 11 additions & 9 deletions src/stream_ml/pytorch/builtin/_isochrone/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,25 @@

from __future__ import annotations

from stream_ml.pytorch.builtin._isochrone.core import IsochroneMVNorm
from stream_ml.pytorch.builtin._isochrone.mf import (
HardCutoffMassFunction,
StepwiseMassFunction,
StreamMassFunction,
UniformStreamMassFunction,
)
from stream_ml.pytorch.builtin._isochrone.utils import Parallax2DistMod

__all__: list[str] = [
# Mass Function
"StreamMassFunction",
"UniformStreamMassFunction",
"HardCutoffMassFunction",
"StepwiseMassFunction",
"KroupaIMF",
# Core
"IsochroneMVNorm",
# Utils
"Parallax2DistMod",
]

from stream_ml.pytorch.builtin._isochrone.core import IsochroneMVNorm
from stream_ml.pytorch.builtin._isochrone.mf import (
HardCutoffMassFunction,
KroupaIMF,
StepwiseMassFunction,
StreamMassFunction,
UniformStreamMassFunction,
)
from stream_ml.pytorch.builtin._isochrone.utils import Parallax2DistMod
88 changes: 77 additions & 11 deletions src/stream_ml/pytorch/builtin/_isochrone/mf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@
__all__: list[str] = []

from dataclasses import dataclass
from typing import TYPE_CHECKING, Protocol
from math import log
from typing import TYPE_CHECKING, ClassVar, Protocol

import torch as xp
from torchcubicspline import NaturalCubicSpline, natural_cubic_spline_coeffs

if TYPE_CHECKING:
from stream_ml.pytorch import Data
from stream_ml.pytorch.typing import Array, ArrayNamespace

# =============================================================================
# Cluster Mass Function


class StreamMassFunction(Protocol):
"""Stream Mass Function.
Expand All @@ -26,7 +27,7 @@ class StreamMassFunction(Protocol):
"""

def __call__(
self, gamma: Array, x: Data[Array], *, xp: ArrayNamespace[Array]
self, gamma: Array, x: Data[Array] | None, *, xp: ArrayNamespace[Array]
) -> Array:
r"""Log-probability of stars at position 'x' having mass 'gamma'.

Expand All @@ -51,9 +52,9 @@ def __call__(
@dataclass(frozen=True)
class UniformStreamMassFunction(StreamMassFunction):
def __call__(
self, gamma: Array, x: Data[Array], *, xp: ArrayNamespace[Array]
self, gamma: Array, x: Data[Array] | None, *, xp: ArrayNamespace[Array]
) -> Array:
return xp.zeros((len(x), len(gamma)))
return xp.zeros((1 if x is None else len(x), len(gamma)))


@dataclass(frozen=True)
Expand All @@ -64,9 +65,9 @@ class HardCutoffMassFunction(StreamMassFunction):
upper: float = 1

def __call__(
self, gamma: Array, x: Data[Array], *, xp: ArrayNamespace[Array]
self, gamma: Array, x: Data[Array] | None, *, xp: ArrayNamespace[Array]
) -> Array:
out = xp.full((len(x), len(gamma)), -xp.inf)
out = xp.full((1 if x is None else len(x), len(gamma)), -xp.inf)
out[:, (gamma >= self.lower) & (gamma <= self.upper)] = 0
return out

Expand All @@ -79,11 +80,76 @@ class StepwiseMassFunction(StreamMassFunction):
log_probs: tuple[float, ...] # (B,)

def __call__(
self, gamma: Array, x: Data[Array], *, xp: ArrayNamespace[Array]
self, gamma: Array, x: Data[Array] | None, *, xp: ArrayNamespace[Array]
) -> Array:
out = xp.full((len(x), len(gamma)), -xp.inf)
out = xp.full((1 if x is None else len(x), len(gamma)), -xp.inf)
for lower, upper, lnp in zip(
self.boundaries[:-1], self.boundaries[1:], self.log_probs, strict=True
):
out[:, (gamma >= lower) & (gamma < upper)] = lnp
return out


# --------------------------------------------------------------------------------


@dataclass(frozen=True)
class KroupaIMF(StreamMassFunction):
"""Kroupa IMF.

https://arxiv.org/pdf/astro-ph/0102155.pdf

Parameters
----------
gamma_to_mass : :class:`torchcubicspline.NaturalCubicSpline`
A callable that takes in gamma and returns the mass.
"""

gamma_to_mass: NaturalCubicSpline

ranges: ClassVar[tuple[float, ...]] = (0.01, 0.08, 0.5, 100)
ln_ranges: ClassVar[tuple[float, ...]] = (1e-2, 8e-2, 5e-1, 1e2)

def __post_init__(self) -> None:
# TODO: need to normalize to gamma?
# Compute the normalization by integrating over the mass range.
self.ln_norm: float
object.__setattr__(self, "ln_norm", 0) # to avoid missing self-reference.
gammas = xp.linspace(0, 1, 10_000)
masses = self.gamma_to_mass.evaluate(gammas)[:, 0].to(dtype=gammas.dtype)
lnpdfs = self(gammas, None, xp=xp)
norm = float(xp.sum(xp.exp(lnpdfs)[:-1] * xp.diff(masses)))
object.__setattr__(self, "ln_norm", log(norm))

def __call__(
self, gamma: Array, x: Data[Array] | None, *, xp: ArrayNamespace[Array]
) -> Array:
out = xp.empty_like(gamma)
mass = self.gamma_to_mass.evaluate(gamma)[:, 0].to(dtype=gamma.dtype)
rng = xp.asarray(self.ranges, dtype=gamma.dtype)
ln_rng = xp.asarray(self.ln_ranges, dtype=gamma.dtype)

# https://arxiv.org/pdf/astro-ph/0102155.pdf
if xp.any((mass < rng[0]) | (mass >= rng[-1])):
msg = f"mass must be >= {rng[0]}."
raise ValueError(msg)

sel = (mass >= rng[0]) & (mass < rng[1])
out[sel] = -0.3 * (xp.log(mass[sel]) - ln_rng[1])

sel = (mass >= rng[1]) & (mass < rng[2])
out[sel] = -1.3 * (xp.log(mass[sel]) - ln_rng[1])

sel = (mass >= rng[2]) & (mass < rng[3])
out[sel] = ln_rng[2] + 1.3 * ln_rng[1] - 2.3 * xp.log(mass[sel])

return out - self.ln_norm

# ===============================================================
# Constructors

@classmethod
def from_arrays(cls, gamma: Array, mass: Array) -> KroupaIMF:
"""Construct from arrays."""
coeffs = natural_cubic_spline_coeffs(gamma, mass[:, None])
return cls(gamma_to_mass=NaturalCubicSpline(coeffs))