Skip to content

Commit

Permalink
move isochrone stuff to core
Browse files Browse the repository at this point in the history
  • Loading branch information
nstarman committed Oct 11, 2023
1 parent 4a95821 commit b2a87b0
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 157 deletions.
36 changes: 27 additions & 9 deletions src/stream_ml/pytorch/builtin/_isochrone/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,6 @@

from __future__ import annotations

from stream_ml.pytorch.builtin._isochrone.core import IsochroneMVNorm
from stream_ml.pytorch.builtin._isochrone.mf import (
HardCutoffMassFunction,
StepwiseMassFunction,
StreamMassFunction,
UniformStreamMassFunction,
)
from stream_ml.pytorch.builtin._isochrone.utils import Parallax2DistMod

__all__: list[str] = [
# Mass Function
"StreamMassFunction",
Expand All @@ -22,3 +13,30 @@
# Utils
"Parallax2DistMod",
]

from dataclasses import field, make_dataclass

import torch as xp

from stream_ml.core.builtin._isochrone.mf import (
HardCutoffMassFunction,
StepwiseMassFunction,
StreamMassFunction,
UniformStreamMassFunction,
)
from stream_ml.core.builtin._isochrone.utils import (
Parallax2DistMod as CoreParallax2DistMod,
)

from stream_ml.pytorch.builtin._isochrone.core import IsochroneMVNorm
from stream_ml.pytorch.typing import Array, ArrayNamespace

# -----------------------------------------------------------------------------

Parallax2DistMod = make_dataclass(
"Parallax2DistMod",
[("array_namespace", ArrayNamespace[Array], field(default=xp, kw_only=True))],
bases=(CoreParallax2DistMod[Array],),
unsafe_hash=True,
repr=False,
)
11 changes: 5 additions & 6 deletions src/stream_ml/pytorch/builtin/_isochrone/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,16 @@

import torch as xp

from stream_ml.core._core.field import NNField
from stream_ml.core import Data, NNField
from stream_ml.core.builtin._isochrone.mf import (
StreamMassFunction,
UniformStreamMassFunction,
)
from stream_ml.core.builtin._utils import WhereRequiredError
from stream_ml.core.utils.frozen_dict import FrozenDict, FrozenDictField
from stream_ml.core.utils.funcs import within_bounds

from stream_ml.pytorch import Data
from stream_ml.pytorch._base import ModelBase
from stream_ml.pytorch.builtin._isochrone.mf import (
StreamMassFunction,
UniformStreamMassFunction,
)

if TYPE_CHECKING:
from scipy.interpolate import CubicSpline
Expand Down
89 changes: 0 additions & 89 deletions src/stream_ml/pytorch/builtin/_isochrone/mf.py

This file was deleted.

53 changes: 0 additions & 53 deletions src/stream_ml/pytorch/builtin/_isochrone/utils.py

This file was deleted.

0 comments on commit b2a87b0

Please sign in to comment.