Skip to content

Commit

Permalink
In progress commit on branch delay.
Browse files Browse the repository at this point in the history
  • Loading branch information
thibmonsel committed Jan 12, 2024
1 parent 3020bb5 commit 56d8c0b
Show file tree
Hide file tree
Showing 24 changed files with 12,388 additions and 107 deletions.
179 changes: 93 additions & 86 deletions diffrax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,107 +1,114 @@
import importlib.metadata

from ._adjoint import (
AbstractAdjoint as AbstractAdjoint,
BacksolveAdjoint as BacksolveAdjoint,
DirectAdjoint as DirectAdjoint,
ImplicitAdjoint as ImplicitAdjoint,
RecursiveCheckpointAdjoint as RecursiveCheckpointAdjoint,
)
from ._autocitation import citation as citation, citation_rules as citation_rules
from ._brownian import (
AbstractBrownianPath as AbstractBrownianPath,
UnsafeBrownianPath as UnsafeBrownianPath,
VirtualBrownianTree as VirtualBrownianTree,
)
from ._adjoint import AbstractAdjoint as AbstractAdjoint
from ._adjoint import BacksolveAdjoint as BacksolveAdjoint
from ._adjoint import DirectAdjoint as DirectAdjoint
from ._adjoint import ImplicitAdjoint as ImplicitAdjoint
from ._adjoint import RecursiveCheckpointAdjoint as RecursiveCheckpointAdjoint
from ._autocitation import citation as citation
from ._autocitation import citation_rules as citation_rules
from ._brownian import AbstractBrownianPath
from ._brownian import AbstractBrownianPath as AbstractBrownianPath
from ._brownian import UnsafeBrownianPath
from ._brownian import UnsafeBrownianPath as UnsafeBrownianPath
from ._brownian import VirtualBrownianTree
from ._brownian import VirtualBrownianTree as VirtualBrownianTree
from ._custom_types import LevyVal as LevyVal
from ._event import (
AbstractDiscreteTerminatingEvent as AbstractDiscreteTerminatingEvent,
DiscreteTerminatingEvent as DiscreteTerminatingEvent,
SteadyStateEvent as SteadyStateEvent,
)
from ._delays import Delays as Delays
from ._delays import bind_history as bind_history
from ._delays import history_extrapolation_implicit as history_extrapolation_implicit
from ._delays import maybe_find_discontinuity as maybe_find_discontinuity
from ._event import AbstractDiscreteTerminatingEvent
from ._event import AbstractDiscreteTerminatingEvent as AbstractDiscreteTerminatingEvent
from ._event import DiscreteTerminatingEvent
from ._event import DiscreteTerminatingEvent as DiscreteTerminatingEvent
from ._event import SteadyStateEvent
from ._event import SteadyStateEvent as SteadyStateEvent
from ._global_interpolation import (
AbstractGlobalInterpolation as AbstractGlobalInterpolation,
)
from ._global_interpolation import CubicInterpolation as CubicInterpolation
from ._global_interpolation import DenseInterpolation as DenseInterpolation
from ._global_interpolation import LinearInterpolation as LinearInterpolation
from ._global_interpolation import (
backward_hermite_coefficients as backward_hermite_coefficients,
CubicInterpolation as CubicInterpolation,
DenseInterpolation as DenseInterpolation,
linear_interpolation as linear_interpolation,
LinearInterpolation as LinearInterpolation,
)
from ._global_interpolation import linear_interpolation as linear_interpolation
from ._global_interpolation import (
rectilinear_interpolation as rectilinear_interpolation,
)
from ._integrate import diffeqsolve as diffeqsolve
from ._local_interpolation import (
AbstractLocalInterpolation as AbstractLocalInterpolation,
)
from ._local_interpolation import (
FourthOrderPolynomialInterpolation as FourthOrderPolynomialInterpolation,
LocalLinearInterpolation as LocalLinearInterpolation,
ThirdOrderHermitePolynomialInterpolation as ThirdOrderHermitePolynomialInterpolation, # noqa: E501
)
from ._local_interpolation import LocalLinearInterpolation as LocalLinearInterpolation
from ._local_interpolation import (
ThirdOrderHermitePolynomialInterpolation as ThirdOrderHermitePolynomialInterpolation,
) # noqa: E501
from ._misc import adjoint_rms_seminorm as adjoint_rms_seminorm
from ._path import AbstractPath as AbstractPath
from ._root_finder import (
VeryChord as VeryChord,
with_stepsize_controller_tols as with_stepsize_controller_tols,
)
from ._saveat import SaveAt as SaveAt, SubSaveAt as SubSaveAt
from ._solution import (
is_event as is_event,
is_okay as is_okay,
is_successful as is_successful,
RESULTS as RESULTS,
Solution as Solution,
)
from ._solver import (
AbstractAdaptiveSolver as AbstractAdaptiveSolver,
AbstractDIRK as AbstractDIRK,
AbstractERK as AbstractERK,
AbstractESDIRK as AbstractESDIRK,
AbstractImplicitSolver as AbstractImplicitSolver,
AbstractItoSolver as AbstractItoSolver,
AbstractRungeKutta as AbstractRungeKutta,
AbstractSDIRK as AbstractSDIRK,
AbstractSolver as AbstractSolver,
AbstractStratonovichSolver as AbstractStratonovichSolver,
AbstractWrappedSolver as AbstractWrappedSolver,
Bosh3 as Bosh3,
ButcherTableau as ButcherTableau,
CalculateJacobian as CalculateJacobian,
Dopri5 as Dopri5,
Dopri8 as Dopri8,
Euler as Euler,
EulerHeun as EulerHeun,
HalfSolver as HalfSolver,
Heun as Heun,
ImplicitEuler as ImplicitEuler,
ItoMilstein as ItoMilstein,
KenCarp3 as KenCarp3,
KenCarp4 as KenCarp4,
KenCarp5 as KenCarp5,
Kvaerno3 as Kvaerno3,
Kvaerno4 as Kvaerno4,
Kvaerno5 as Kvaerno5,
LeapfrogMidpoint as LeapfrogMidpoint,
Midpoint as Midpoint,
MultiButcherTableau as MultiButcherTableau,
Ralston as Ralston,
ReversibleHeun as ReversibleHeun,
SemiImplicitEuler as SemiImplicitEuler,
Sil3 as Sil3,
StratonovichMilstein as StratonovichMilstein,
Tsit5 as Tsit5,
)
from ._root_finder import VeryChord as VeryChord
from ._root_finder import with_stepsize_controller_tols as with_stepsize_controller_tols
from ._saveat import SaveAt as SaveAt
from ._saveat import SubSaveAt as SubSaveAt
from ._solution import RESULTS as RESULTS
from ._solution import Solution as Solution
from ._solution import is_event as is_event
from ._solution import is_okay as is_okay
from ._solution import is_successful as is_successful
from ._solver import AbstractAdaptiveSolver as AbstractAdaptiveSolver
from ._solver import AbstractDIRK as AbstractDIRK
from ._solver import AbstractERK as AbstractERK
from ._solver import AbstractESDIRK as AbstractESDIRK
from ._solver import AbstractImplicitSolver as AbstractImplicitSolver
from ._solver import AbstractItoSolver as AbstractItoSolver
from ._solver import AbstractRungeKutta as AbstractRungeKutta
from ._solver import AbstractSDIRK as AbstractSDIRK
from ._solver import AbstractSolver as AbstractSolver
from ._solver import AbstractStratonovichSolver as AbstractStratonovichSolver
from ._solver import AbstractWrappedSolver as AbstractWrappedSolver
from ._solver import Bosh3 as Bosh3
from ._solver import ButcherTableau as ButcherTableau
from ._solver import CalculateJacobian as CalculateJacobian
from ._solver import Dopri5 as Dopri5
from ._solver import Dopri8 as Dopri8
from ._solver import Euler as Euler
from ._solver import EulerHeun as EulerHeun
from ._solver import HalfSolver as HalfSolver
from ._solver import Heun as Heun
from ._solver import ImplicitEuler as ImplicitEuler
from ._solver import ItoMilstein as ItoMilstein
from ._solver import KenCarp3 as KenCarp3
from ._solver import KenCarp4 as KenCarp4
from ._solver import KenCarp5 as KenCarp5
from ._solver import Kvaerno3 as Kvaerno3
from ._solver import Kvaerno4 as Kvaerno4
from ._solver import Kvaerno5 as Kvaerno5
from ._solver import LeapfrogMidpoint as LeapfrogMidpoint
from ._solver import Midpoint as Midpoint
from ._solver import MultiButcherTableau as MultiButcherTableau
from ._solver import Ralston as Ralston
from ._solver import ReversibleHeun as ReversibleHeun
from ._solver import SemiImplicitEuler as SemiImplicitEuler
from ._solver import Sil3 as Sil3
from ._solver import StratonovichMilstein as StratonovichMilstein
from ._solver import Tsit5 as Tsit5
from ._step_size_controller import (
AbstractAdaptiveStepSizeController as AbstractAdaptiveStepSizeController,
AbstractStepSizeController as AbstractStepSizeController,
ConstantStepSize as ConstantStepSize,
PIDController as PIDController,
StepTo as StepTo,
)
from ._term import (
AbstractTerm as AbstractTerm,
ControlTerm as ControlTerm,
MultiTerm as MultiTerm,
ODETerm as ODETerm,
WeaklyDiagonalControlTerm as WeaklyDiagonalControlTerm,
from ._step_size_controller import (
AbstractStepSizeController as AbstractStepSizeController,
)

from ._step_size_controller import ConstantStepSize as ConstantStepSize
from ._step_size_controller import PIDController as PIDController
from ._step_size_controller import StepTo as StepTo
from ._term import AbstractTerm as AbstractTerm
from ._term import ControlTerm as ControlTerm
from ._term import MultiTerm as MultiTerm
from ._term import ODETerm as ODETerm
from ._term import WeaklyDiagonalControlTerm as WeaklyDiagonalControlTerm

__version__ = importlib.metadata.version("diffrax")
12 changes: 12 additions & 0 deletions diffrax/_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def loop(
solver,
stepsize_controller,
discrete_terminating_event,
delays,
saveat,
t0,
t1,
Expand All @@ -128,6 +129,7 @@ def loop(
init_state,
passed_solver_state,
passed_controller_state,
y0_history,
) -> Any:
"""Runs the main solve loop. Subclasses can override this to provide custom
backpropagation behaviour; see for example the implementation of
Expand Down Expand Up @@ -552,13 +554,15 @@ def _loop_backsolve_bwd(
solver,
stepsize_controller,
discrete_terminating_event,
delays,
saveat,
t0,
t1,
dt0,
max_steps,
throw,
init_state,
y0_history,
):
assert discrete_terminating_event is None

Expand Down Expand Up @@ -596,6 +600,8 @@ def _loop_backsolve_bwd(
adjoint=self,
solver=solver,
stepsize_controller=stepsize_controller,
discrete_terminating_event=discrete_terminating_event,
delays=delays,
terms=adjoint_terms,
dt0=None if dt0 is None else -dt0,
max_steps=max_steps,
Expand Down Expand Up @@ -775,6 +781,7 @@ def loop(
passed_solver_state,
passed_controller_state,
discrete_terminating_event,
delays,
**kwargs,
):
if jtu.tree_structure(saveat.subs, is_leaf=_is_subsaveat) != jtu.tree_structure(
Expand Down Expand Up @@ -820,6 +827,10 @@ def loop(
raise NotImplementedError(
"`diffrax.BacksolveAdjoint` is not compatible with events."
)
if delays is not None:
raise NotImplementedError(
"Cannot use `delays` with `adjoint=BacksolveAdjoint()`"
)

y = init_state.y
init_state = eqx.tree_at(lambda s: s.y, init_state, object())
Expand All @@ -834,6 +845,7 @@ def loop(
init_state=init_state,
solver=solver,
discrete_terminating_event=discrete_terminating_event,
delays=delays,
**kwargs,
)
final_state = _only_transpose_ys(final_state)
Expand Down
Loading

0 comments on commit 56d8c0b

Please sign in to comment.