
First DDE version #169

Open · wants to merge 14 commits into base: main
Conversation

@thibmonsel (Contributor) commented Oct 7, 2022

New files:

  • discontinuity.py, which does the root finding during integration steps.

Modified files:

  • integrate.py: changed a bit of the code; it essentially looks the same, but with more if statements. The discontinuity handling before each integration step is also done here. Added 2 new arguments to _State (discontinuities, discontinuities_save_index).
  • constant.py: does the discontinuity checking and returns the next integration step. But as said in the WIP, the prevbefore and nextafter are done in loop().

I followed your suggestion regarding dropping y0_history and putting it in y0. However, by doing this we must now pass y0 to the loop function. Haven't done the PyTree handling of delays yet. Only works for the constant stepsize controller; doing adaptive now.

PS: I don't have the same saving format as you, so terms.py shows some deletions and additions for no reason....

Boilerplate code for a DDE:

import diffrax
import jax.numpy as jnp
import matplotlib.pyplot as plt


def vector_field(t, y, args, *, history):
    return 1.8 * y * (1 - history[0])


delays = [lambda t, y, args: 1]
y0_history = lambda t: 1.2
discontinuity = (0.0,)

made_jump = discontinuity is None
t0, t1 = 0.0, 100.0
ts = jnp.linspace(t0, t1, 1000)
sol = diffrax.diffeqsolve(
    diffrax.ODETerm(vector_field),
    diffrax.Dopri8(),
    t0=ts[0],
    t1=ts[-1],
    dt0=ts[1] - ts[0],
    y0=y0_history,
    max_steps=2**16,
    stepsize_controller=diffrax.ConstantStepSize(),
    saveat=diffrax.SaveAt(ts=ts, dense=True),
    delays=delays,
    discontinuity=discontinuity,
    made_jump=made_jump,
)

plt.plot(sol.ts, sol.ys)
plt.show()

@thibmonsel changed the title to First DDE version Oct 7, 2022
@patrick-kidger (Owner) commented Oct 7, 2022

Okay, so there are a lot of spurious changes here, mostly due to unneeded formatting changes. Take a look at CONTRIBUTING.md, and in particular the pre-commit hooks; these will autoformat etc. the code. I'll be able to do a proper review then.

Regarding passing y0_history into loop: I'm thinking what we should do is something like:

def diffeqsolve(y0, delays, ...):
    if delays is None:
        y0_history = None
    else:
        y0_history = y0
        y0 = y0_history(t0)

    adjoint.loop(..., y0_history=y0_history)

so that internally we still disambiguate between y0 and y0_history.

Regarding the changes to constant step sizing: hmm, this seems strange to me. I don't think we should need to change any stepsize controller at all. I think the stepsize controller changes we need to make (due to discontinuities) should happen entirely within integrate.py, so that they can apply to all stepsize controllers. (e.g. even user-specified stepsize controllers, that don't have any special support)
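For concreteness, the kind of controller-agnostic clipping being described might look like the following sketch (the helper name and signature are illustrative, not actual Diffrax API):

import jax.numpy as jnp


def clip_to_discontinuity(tprev, tnext, discontinuities):
    # Smallest recorded discontinuity strictly inside the proposed step.
    # (Assumes `discontinuities` is a non-empty array; entries outside the
    # step are masked out with +inf.)
    in_step = (discontinuities > tprev) & (discontinuities < tnext)
    next_disc = jnp.min(jnp.where(in_step, discontinuities, jnp.inf))
    # Clip the step so it ends exactly on the discontinuity, if any.
    return jnp.minimum(tnext, next_disc)

Because this runs on the proposed tnext after the controller has chosen it, it applies uniformly to every stepsize controller, including user-specified ones.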

[Resolved review thread on diffrax/integrate.py]
@thibmonsel (Contributor, Author) commented

Hello,
I'm back with some updates!

1/ I used the pre-commit hooks but still have one small spurious change in terms.py 😲.
2/ y0 and y0_history are disambiguated!
3/ The controllers are untouched, which is better for modular code.
4/ Discontinuity handling is done in the loop file.

5/ One edge case was found; I haven't thought about it too much and did a "sloppy" fix for now (https://github.com/patrick-kidger/diffrax/pull/169/files#r991488410). Essentially it arises when we integrate a step from tprev to tnext and the integration bound new_tprev = tnext is right next to a discontinuity, forcing us to redo the step with a new interval [tnext; tnext + epsilon = discontinuity_jump]. Combined with the code's handling, this puts tprev > tnext and throws an error.
Regardless, for most cases the solver seems to work.

Things not done:
1/ y0_history is still not a PyTree of callables; when that gets done, I suppose one should use eqx.tree_at to change its structure?
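(For reference, this is standard Equinox rather than anything PR-specific: eqx.tree_at does exactly that kind of out-of-place structural replacement. A minimal sketch, using a hypothetical container:)

from typing import Callable

import equinox as eqx


class Config(eqx.Module):
    y0_history: Callable
    atol: float


cfg = Config(y0_history=lambda t: 1.2, atol=1e-6)
# Swap out the `y0_history` leaf without mutating the original module.
new_cfg = eqx.tree_at(lambda c: c.y0_history, cfg, replace=lambda t: 0.0)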

[Resolved review threads on diffrax/integrate.py, diffrax/term.py and diffrax/misc/discontinuity.py]
@thibmonsel (Contributor, Author) commented Oct 14, 2022

The latest commit does what you suggested:

  • Reverted to not touching state.tnext and such.
  • Integrated your bisection code into the discontinuity detection.
  • Only search for discontinuities if the step is rejected.
  • delays are PyTrees.

To do:

  • Code cleanup, i.e. removing discontinuity.py, which is now obsolete (will do next commit, once something stable is available).
  • Implicit stepping when state.tnext - state.tprev > min(delays).
    • Regarding this, I checked the Julia paper (section 4.2); I understand the issue there, but I find the explanation of how to handle it rather opaque... (*). In this case, the use of the continuous extension (i.e. dense_interp from dense_infos) makes the overall method implicit even if the discrete method we are using is explicit (this is apparently called overlapping); a sketch of the idea follows this list.
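(A minimal sketch of that fixed-point "overlapping" idea, with purely illustrative names: take_step stands in for whatever re-takes the step using a candidate solution over [tprev, tnext] to evaluate the delayed history.)

import jax.numpy as jnp
from jax import lax


def overlapping_step(take_step, y_init, tol=1e-8, max_iters=100):
    # Iterate until the candidate for the current step stops changing: each
    # pass evaluates the delayed history using the previous candidate, so
    # the converged result solves the implicit problem.
    def cond(val):
        i, y, err = val
        return (err > tol) & (i < max_iters)

    def body(val):
        i, y, _ = val
        y_new = take_step(y)
        return i + 1, y_new, jnp.max(jnp.abs(y_new - y))

    _, y, _ = lax.while_loop(cond, body, (jnp.array(0), y_init, jnp.array(jnp.inf)))
    return y


# Toy usage: a contraction whose fixed point is 2.0.
print(overlapping_step(lambda y: 0.5 * y + 1.0, jnp.array(0.0)))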

Comments:

  • Regarding the first commit, where lots of things were changed: I wasn't aware of the two points of view (philosophies) on how to integrate DDEs. The first is to track the discontinuities; the other relies on the error estimation of the solver's method, under which discontinuity tracking is somewhat unnecessary. Julia therefore uses the second idea.
  • Since the implicit stepping (cf. *) is not done yet, some integrations are faulty.

@thibmonsel (Contributor, Author) commented Oct 18, 2022

The latest commit:

  • added a discont_update in the discontinuity-checking part, to make sure we correctly update state.discontinuity
  • added a discont_check argument to diffeqsolve, to check for breaking points at each step
  • implemented the implicit step in history_extrapolation_implicit
  • removed discontinuity.py, so everything is in integrate.py
  • enhanced discontinuity tracking to also check for roots in subintervals of [tprev, tnext]; this helps avoid missing potential jumps that would otherwise go undetected. Not sure, but I think the vmapping here doesn't destroy the computational gain of the unvmap?

@patrick-kidger (Owner) left a comment


Okay I've reviewed about half of it! I'll try and get to the other half in the next week or so.

Overall I like where this is going. I think we now have some first implementations for every piece.

Regarding my comments about avoiding in-place updates: I think this is probably doable by updating HistoryVectorField to operate on three regions: y0_history, the recorded dense_infos, and finally the current step. (Much like how it operates over two regions at the moment.) I think this should allow us to generate efficient code.

[Resolved review threads on diffrax/integrate.py]
@thibmonsel (Contributor, Author) commented Oct 28, 2022

The latest commit addresses some remarks from the latest review, #169 (review).

As you suggested, I've bundled the delay terms together for an easier API.

class _Delays(eqx.Module):
    delays: Optional[PyTree[Callable]]
    initial_discontinuities: Union[Array, Tuple]
    max_discontinuities: Int
    recurrent_checking: Bool
    rtol: float
    atol: float
    eps: float

delays is the regular delayed-term definition; initial_discontinuities corresponds to the discontinuities that the user needs to supply to get proper DDE integration (I refer you to my response in #169 (comment)); discont_checking is now recurrent_checking; atol and rtol are for the Newton solver in the implicit step; and eps is the error tolerance of the step.
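(For illustration, constructing it might look like this; all field values below are placeholders rather than recommended defaults:)

delays = _Delays(
    delays=[lambda t, y, args: 1.0],
    initial_discontinuities=(0.0,),
    max_discontinuities=100,
    recurrent_checking=False,
    rtol=1e-3,
    atol=1e-6,
    eps=1e-8,
)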

What you suggested is great for later iterations, because we can just slap any new arguments into _Delays!

What needs to be done:

@patrick-kidger (Owner) left a comment


Okay, let's focus on this bit before we move on to discussing the discontinuity handling below. I'll leave you to make the changes already discussed here. Let me know if any of them aren't clear.

[Resolved review threads on diffrax/integrate.py]
@thibmonsel (Contributor, Author) commented

> Okay, let's focus on this bit before we move on to discussing the discontinuity handling below. I'll leave you to make the changes already discussed here. Let me know if any of them aren't clear.

Sounds good. I'll take care of the first bullet point later, and the second one should be OK on my side. However, I'd like to have your take on the third one, about the in-place operations (for dense_ts, dense_infos), since this is a very sharp JAX bit 🔪.

@patrick-kidger (Owner) commented

Sure thing. I'm suggesting that _HistoryVectorField should look something like this:

class _HistoryVectorField(eqx.Module):
    ...
    tprev: float
    tnext: float
    dense_info: PyTree[Array]
    interpolation_cls: Type[AbstractLocalInterpolation]

    def __call__(self, t, y, args):
        ...
        if self.dense_interp is None:
            ...
        else:
            for delay in delays:
                delay_val = delay(t, y, args)
                alpha_val = t - delay_val

                is_before_t0 = alpha_val < self.t0
                is_before_tprev = alpha_val < self.tprev
                # jnp.minimum/jnp.maximum (elementwise), not jnp.min/jnp.max.
                at_most_t0 = jnp.minimum(alpha_val, self.t0)
                t0_to_tprev = jnp.clip(alpha_val, self.t0, self.tprev)
                at_least_tprev = jnp.maximum(alpha_val, self.tprev)

                step_interpolation = self.interpolation_cls(
                    t0=self.tprev, t1=self.tnext, **self.dense_info
                )
                switch = jnp.where(is_before_t0, 0, jnp.where(is_before_tprev, 1, 2))
                history_val = lax.switch(
                    switch,
                    [
                        lambda: self.y0_history(at_most_t0),
                        lambda: self.dense_interp(t0_to_tprev),
                        lambda: step_interpolation.evaluate(at_least_tprev),
                    ],
                )
                history_vals.append(history_val)
        ...
        return ...

And then when it is called inside the implicit routine:

def body_fun(val):
    dense_info, ... = val
    ...
    _HistoryVectorField(..., state.tprev, state.tnext, dense_info, solver.interpolation_cls)
    ...
    return new_dense_info, ...

@thibmonsel (Contributor, Author) commented Nov 4, 2022

The latest commit should have handled all of the issues mentioned above:

  • No extra step after the implicit step: the explicit step is integrated de facto into the lax.for_loop of the function history_extrapolation_implicit, with the conditional you proposed (#169 (comment)).
  • In-place updates handled with _HistoryVectorField.

Discussion/bottleneck for the implicit step

Regarding the implicit step, we have an issue when it comes to large steps, because a step_interpolation with only 2 points won't suffice. This hinges on whether the snippet below is indeed a 2-point interpolation:

step_interpolation = self.interpolation_cls(t0=self.tprev, t1=self.tnext, **self.dense_info)

To elaborate: suppose we take an implicit step from state.tprev to state.tnext. Our associated history function for the equation y'(t) = f(t, y(t - tau)) will be known from state.tprev up to state.tprev + tau, and from state.tprev + tau to state.tnext we need its extrapolation. If the history function on the interval [state.tprev + tau, state.tnext] is non-monotonic (imagine half a period of a sine, for example), a 2-point extrapolation won't capture the function correctly; we would rather need, say, 10 points to get a good estimate. In this regard, we would also need to change the condition of your implicit step from

_pred = (((y - y_prev) / y) > delays.eps).any()

to something that checks the MSE of the extrapolated history function before and after the integration step. With this in mind, I'm not sure the _HistoryVectorField from #169 (comment) as-is will do the trick.

This also impacts the population of the ys in dense_ts, since the values are interpolated from the computed steps of y. (If we take an implicit step from state.tprev to state.tnext and have to save some points in between, the current procedure is to use a 2-point interpolation (yprev=state.y and ynext=y), right?) That said, this discussion is only relevant if the time mesh we have (i.e. dense_ts) is precise enough.

@patrick-kidger (Owner) commented

> Our associated history function for the equation y'(t) = f(t, y(t-tau)) will be known from state.tprev up to state.tprev + tau

I don't think this is true. Anything after tprev hasn't been evaluated yet. The whole [tprev, tnext] region is initialised as an extrapolation from the previous step.

Regarding 2-point interpolations: this isn't the case for most solvers. Each solver evaluates various intermediate quantities during its step (e.g. the stages of an RK solver) and these also feed into the interpolation.

Even if it were, though, I don't think it matters: we just need to converge to a solution of the implicit problem.

@thibmonsel (Contributor, Author) commented Nov 4, 2022

OK, this makes sense, so I'll take back what I said in my "Discussion/bottleneck for the implicit step"; thanks for the clarification! The _pred condition should be on the dense_info then!

Edit: the _pred condition seems to be working fine!
For this part of the code, I'd say it's ready for a review before moving on to discontinuity checking! (Unconstrained time stepping works well.)

[Resolved review threads on diffrax/integrate.py]
@thibmonsel (Contributor, Author) commented Nov 7, 2022

Relevant changes are made:

  • All the code is now managed with PyTree operations (I think).
  • Moved Delays and the implicit step into a new file (but there is a circular import there... not sure if this is an issue).
  • Removed the eps feature of the implicit stepping.

In order to get something backprop-compatible in history_extrapolation_implicit, should I use your bounded_while_loop instead of lax.while_loop in the meantime, until your PR is merged?

@patrick-kidger (Owner) left a comment


Okay, this is probably the final round of review on the implicit stepping. (Nice work!) I'll go over the discontinuity checking at some point shortly.

Regarding backprop through the lax.while_loop: the correct thing to do here is actually to use the implicit function theorem, e.g. as is already done with Newton's method. This can be wired up using the misc.implicit_jvp helper. Let me know if you're not familiar with this and I can give you some pointers on how it works.

[Resolved review threads on diffrax/integrate.py, diffrax/misc/__init__.py and diffrax/misc/delays.py]
@thibmonsel (Contributor, Author) commented Nov 12, 2022

Great news!
I'm not familiar with wiring up implicit_jvp, so some pointers would be welcome.
Other than that, the relevant changes were made:

@@ -6,6 +6,7 @@
RecursiveCheckpointAdjoint,
)
from .brownian import AbstractBrownianPath, UnsafeBrownianPath, VirtualBrownianTree
from .delays import _HistoryVectorField, Delays, history_extrapolation_implicit
@patrick-kidger (Owner) commented
Nit: I think we only want to expose Delays in the public interface. Everything else is Diffrax-internal.

@thibmonsel (Contributor, Author) commented

I'm sorry, I'm not sure I understand what you mean.

[Resolved review threads on diffrax/delays.py and diffrax/integrate.py]
    )
    return _next_ta, _next_tb, _pred, _step, max_step

_init_val = (sub_tprev, sub_tnext, True, 0, 400)
@patrick-kidger (Owner) commented
What if we hit max_steps here? I think we need a mechanism in this whole discontinuity procedure to allow the discontinuity-finding to fail. (Namely, reject the step and use some more naive way of picking tnext, e.g. tprev plus 0.1 times the current interval length or something.)

@thibmonsel (Contributor, Author) commented Nov 18, 2022

If we hit max_steps, that would just mean that the NewtonSolver takes over with the current root approximation? Speaking from experience, since each sub_tprev and sub_tnext used are pretty small intervals, this would probably not happen. The interval [tprev, tnext] is split into N subintervals in which root tracking is done; an illustrative sketch follows.
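(A sketch of that subinterval search: g is a hypothetical scalar function whose sign change marks the discontinuity; none of these names are the PR's actual code.)

import jax
import jax.numpy as jnp


def bisect(g, ta, tb, steps=30):
    # Fixed-iteration bisection, keeping the half containing the sign change.
    # (In a subinterval with no sign change this returns a meaningless
    # midpoint, which real code would mask out with a predicate.)
    def body(_, val):
        ta, tb = val
        tm = 0.5 * (ta + tb)
        left = jnp.sign(g(ta)) != jnp.sign(g(tm))
        return jnp.where(left, ta, tm), jnp.where(left, tm, tb)

    ta, tb = jax.lax.fori_loop(0, steps, body, (ta, tb))
    return 0.5 * (ta + tb)


def search_subintervals(g, tprev, tnext, n=8):
    # Split [tprev, tnext] into n subintervals and bisect each in parallel.
    edges = jnp.linspace(tprev, tnext, n + 1)
    return jax.vmap(lambda a, b: bisect(g, a, b))(edges[:-1], edges[1:])


# Toy usage: the root of g(t) = t - 0.3 lies in the third subinterval.
print(search_subintervals(lambda t: t - 0.3, 0.0, 1.0))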

[Resolved review thread on diffrax/integrate.py]
_discont = _discont_solver(_h, _tb, args).root
_disconts.append(_discont)
if _discont.size == 0:
    return jnp.inf
@patrick-kidger (Owner) commented
I don't think we can ever hit this branch?
(And indeed, as currently written it would crash, since it doesn't also return a bool.)

@thibmonsel (Contributor, Author) commented Nov 18, 2022

Yup, we would never hit this. No problem should be encountered, since NewtonNonlinearSolver returns NaNs if it fails, right? (If that's the case, we should be good.)

[Resolved review thread on diffrax/integrate.py]
@patrick-kidger (Owner) commented Nov 18, 2022

Alright, on to the next block of code!

As for implicit_jvp -- I actually have some in-progress work that may simplify this. If you'd like to be able to backpropagate through this code soon-ish, I can expand on what I mean here? (But if it's not a rush then I'm happy to put a pin in that for now.)

@thibmonsel (Contributor, Author) commented

  • Relevant changes were made for the last code block; everything is in delays.py now!
  • Added 2 new attributes to Delays: nb_sub_intervals and max_steps.
  • Lots of wrappers and deletions make the code more readable!

As for implicit_jvp, backpropagating through this part would be great!
Looking into #169 (comment) a bit deeper now.

@patrick-kidger (Owner) left a comment


This is looking good. Let's start adding this into the documentation and so on:

  • Add a new doc page under the "Advanced API" section.
  • Add a short example. (See the steady-state example for inspiration.)

Let's also add a test_delays.py file, checking all the edge cases. Off the top of my head:

  • Basic checks that the solver actually runs without crashing.
  • Numerical checks that the expected solution is obtained.
  • Check what happens when we exceed Delays.max_discontinuities.
  • Check what happens when we exceed Delays.max_steps.
  • Test combining delays with stochastic terms.
  • Test combining delays with PIDController(jump_ts=...).
  • Test a 'smooth' DDE (whose initial history is such that there is no discontinuity), and check that this can be solved without any expensive discontinuity detection at all. (Actually, I don't think this is possible at the moment -- maybe the recurrent_checking argument should be generalised: discontinuity_checking=True/False/None for always check / only check on rejected steps / never check?)
  • Test DDEs that hit both the implicit-step and the explicit-step branches, and that those branches are taken. (To test this: perhaps we can count the number of implicit and explicit steps, and return that in the auxiliary stats.)

Other changes that come to mind:

  • Another good auxiliary statistic could be how many discontinuities were encountered. (Including any discontinuities arising from jump_ts?)
  • I think the numerics might still have a subtle bug, from something like
    def _clip_to_end(tprev, tnext, t1, keep_step):
    not being used every time we clip tnext.

As for implicit_jvp. Fix some function f, and define y(θ) as being the value of y satisfying f(y, θ) = 0. (And we assume there is a unique such y.) Then we can see that the function f has implicitly defined a function θ -> y(θ).

We can now seek to evaluate the derivative dy/dθ. This actually involves a linear solve, and is what implicit_jvp does. This is pretty easy to do. Probably the best reference is:
http://implicit-layers-tutorial.org/implicit_functions/
Also see the final equation of this section:
https://en.wikipedia.org/wiki/Implicit_function_theorem#Statement_of_the_theorem

You can see this in action for the existing nonlinear solvers, which use the implicit function theorem to differentiate through their root-finding operation.
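(To make that concrete: a minimal self-contained sketch of the mechanism, using jax.custom_jvp directly. This is not Diffrax's implicit_jvp helper, just the same idea spelled out.)

import jax
import jax.numpy as jnp


def f(y, theta):
    # Residual whose root implicitly defines y(theta): y = tanh(theta + y).
    return y - jnp.tanh(theta + y)


@jax.custom_jvp
def solve(theta):
    # Find y with f(y, theta) = 0 by fixed-point iteration.
    # (tanh is a contraction here, so this converges.)
    y = jnp.zeros_like(theta)
    for _ in range(100):
        y = jnp.tanh(theta + y)
    return y


@solve.defjvp
def solve_jvp(primals, tangents):
    (theta,), (dtheta,) = primals, tangents
    y = solve(theta)
    # Implicit function theorem: (df/dy) dy + (df/dtheta) dtheta = 0, so
    # dy = -(df/dy)^{-1} (df/dtheta) dtheta -- one linear solve, with no
    # need to differentiate through the iteration above.
    dfdy = jax.jacfwd(f, argnums=0)(y, theta)
    dfdtheta = jax.jacfwd(f, argnums=1)(y, theta)
    dy = -jnp.linalg.solve(dfdy, dfdtheta @ dtheta)
    return y, dy


theta = jnp.array([0.1, 0.2])
print(jax.jacfwd(solve)(theta))  # dy/dtheta, computed via the linear solve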

[Review thread on diffrax/__init__.py -- resolved]
from .local_interpolation import AbstractLocalInterpolation
from .misc import rms_norm
from .misc.omega import ω
from .misc.unvmap import unvmap_any
@patrick-kidger (Owner) commented

Before merging, we'll have to rebase your branch onto the current main version of Diffrax. When that happens, heads-up that these imports will change to:

import equinox.internal as eqxi
from equinox.internal import ω

eqxi.unvmap_any

[Resolved review threads on diffrax/delays.py and diffrax/integrate.py]
@patrick-kidger changed the base branch from delay to main November 26, 2022 22:03
@thibmonsel (Contributor, Author) commented Nov 28, 2022

Thanks for the review, Patrick :). From what I understood, what I need to do is create a new class DDEImplicitNonLinearSolve (for example) that inherits from AbstractNonlinearSolver, in order to use implicit_jvp. Basically, I'll need to rewrite the _solve method and then, in history_extrapolation_implicit, call an instance of DDEImplicitNonLinearSolve.

class DDEImplicitNonLinearSolve(AbstractNonlinearSolver):
    def _solve(
        self,
        fn: Callable,  # would be terms
        x: PyTree,  # here would be state.y
        nondiff_args: PyTree,  # all the other args from the current history_extrapolation_implicit(...)
        diff_args: PyTree,
    ):
        ...


def history_extrapolation_implicit(...):
    nonlinearsolver = DDEImplicitNonLinearSolve(...)
    results = nonlinearsolver(terms, y, args).root
    y, y_error, dense_info, solver_state, solver_result = results
    return y, y_error, dense_info, solver_state, solver_result

If that's the case, could you explain how you usually work with your nondiff_args and diff_args? I think it's more native to JAX's jvp/vjp, and they are later used with implicit_jvp and _rewrite, etc.
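(One standard pattern for making that split is Equinox's partition/combine; a sketch of the general idea, not necessarily how Diffrax wires it up internally:)

import equinox as eqx
import jax.numpy as jnp

args = {"theta": jnp.ones(3), "n_steps": 100}
# Differentiable leaves (inexact arrays) go one way; everything else
# (ints, bools, functions, ...) is treated as non-differentiable.
diff_args, nondiff_args = eqx.partition(args, eqx.is_inexact_array)
# They can later be recombined into the original structure:
args_again = eqx.combine(diff_args, nondiff_args)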

@thibmonsel (Contributor, Author) commented

Adding [...] did the trick for ts:

ts = state.dense_ts[...]

Unfortunately, unwrapping the Buffer with other structures like DenseInfos didn't seem to work:

infos = jtu.tree_map(lambda x: x[...], state.dense_infos)

The problem came from the _pred argument of the _Buffer.

However, I found another way:

unwrapped_buffer = jtu.tree_leaves(
    eqx.filter(state.dense_infos, eqx.is_inexact_array),
    is_leaf=eqx.is_inexact_array,
)
unwrapped_dense_infos = dict(zip(state.dense_infos.keys(), unwrapped_buffer))

@patrick-kidger (Owner) commented

Notice how I pass in an additional argument in my previous snippet: jtu.tree_map(lambda _, x: x[...], dense_info, state.dense_infos).

This is because each _Buffer object is actually a PyTree, that's sort of masquerading as an array. So to handle this we actually iterate over a different pytree, that happens to have the correct structure.

The use of buffers is a pretty advanced/annoying detail. I'm pondering using something like Quax to create a safer API for this, but that's a long way down the to-do list.

@thibmonsel (Contributor, Author) commented

Indeed you're right; I just realised that from the documentation yesterday. This makes sense now!
However, we only have access to dense_info after creating DenseInterpolation.

i.e.

dense_interp = DenseInterpolation(
    ts=state.dense_ts[...],
    infos=jtu.tree_map(lambda _, x: x[...], dense_info, state.dense_infos),
    ...
)

(
    y,
    y_error,
    dense_info,
    solver_state,
    solver_result,
) = history_extrapolation_implicit(
    ...
)

So I would agree that this works if dense_info were available, which seems not to be the case; hence I'm not sure how to do it. Moreover, I did some preliminary testing, and doing #169 (comment) yields wrong gradients. Nonetheless, the code makes the promise to only read from locations that have already been written to with an in-place update, so my unwrapping method is/seems erroneous.

@patrick-kidger (Owner) commented

Since I'm starting to see activity here again -- you can track my progress updating Diffrax in #217 and https://github.com/patrick-kidger/diffrax/tree/big-refactor.

(This currently depends on the unreleased versions of Equinox and jaxtyping.)

Mostly there now!

@thibmonsel changed the base branch from main to dev December 12, 2023 12:58
@patrick-kidger deleted the branch patrick-kidger:main January 8, 2024 22:27
@thibmonsel (Contributor, Author) commented Oct 28, 2024

Reviving this PR!
Given #512, the issue with tracer leaks seems to be fixed by disabling jit, i.e. export JAX_DISABLE_JIT=1.

The MWE to test it out:

import jax
import equinox as eqx
import jax.numpy as jnp
import jax.random as jrandom
import diffrax


class Func(eqx.Module):
    linear: eqx.nn.Linear

    def __init__(self, data_size, *, key, **kwargs):
        super().__init__(**kwargs)
        self.linear = eqx.nn.Linear(2 * data_size, data_size, key=key)

    def __call__(self, t, y, args, *, history):
        return self.linear(jnp.hstack([y, *history]))


class NeuralDDE(eqx.Module):
    func: Func
    delays: diffrax.Delays

    def __init__(self, data_size, delays, *, key, **kwargs):
        super().__init__(**kwargs)
        self.func = Func(data_size, key=key)
        self.delays = delays

    def __call__(self, ts, y0):
        solution = diffrax.diffeqsolve(
            diffrax.ODETerm(self.func),
            diffrax.Bosh3(),
            t0=ts[0],
            t1=ts[-1],
            dt0=ts[1] - ts[0],
            y0=lambda t: y0,
            saveat=diffrax.SaveAt(ts=ts, dense=True),
            delays=self.delays,
        )
        return solution.ys


@eqx.filter_value_and_grad
def grad_loss(model, ti, yi):
    y_pred = model(ti, yi[0])
    return jnp.mean((yi - y_pred) ** 2)


@eqx.filter_value_and_grad
def grad_loss_batch(model, ti, yi):
    y_pred = jax.vmap(model, (None, 0))(ti, yi[:, 0])
    return jnp.mean((yi - y_pred) ** 2)

key = jrandom.PRNGKey(0)
ys = jnp.ones((1, 3, 1))
ts = jnp.linspace(0, 1.0, 3)
_, length_size, datasize = ys.shape
delays = diffrax.Delays(delays=(lambda t, y, args: 1.0,))
model_dde = NeuralDDE(datasize, delays, key=key)

print("Starting check_tracer_leaks()")
with jax.check_tracer_leaks():
    loss, grads = grad_loss(model_dde, ts, ys[0])
    print("SUCCESS with check_tracer_leaks() with grad_loss()")
    loss2, grads2 = grad_loss_batch(model_dde, ts, ys)
    print("SUCCESS with check_tracer_leaks() with grad_loss_batch()")

Neural DDE works with RecursiveCheckpointAdjoint().

A modified version of optimistix's fixed-point algorithm is used in order to have unconstrained time stepping:
https://github.com/thibmonsel/diffrax/blob/7e7d1b443e76c2573458e2d7f4a72223967cb01d/diffrax/_delays.py#L32
I found that using optx.RecursiveCheckpointAdjoint() for the fixed-point algorithm provided more stable training.
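(For orientation, the vanilla optimistix fixed-point API looks like the following; the PR uses a modified version of it, and the function here is just a toy contraction:)

import jax.numpy as jnp
import optimistix as optx

solver = optx.FixedPointIteration(rtol=1e-6, atol=1e-6)
# Solve y = 0.5 * y + 1.0, whose fixed point is 2.0.
sol = optx.fixed_point(lambda y, args: 0.5 * y + 1.0, solver, y0=jnp.array(0.0))
print(sol.value)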

To integrate a DDE, two methods can be used:

  • 1/ check for discontinuities at every step, or
  • 2/ trust that the adaptive stepsize controller will reject steps that are too far from the ground truth.

In hindsight, I would like to go for the first option (1/). This would mean that some code (e.g. https://github.com/thibmonsel/diffrax/blob/7e7d1b443e76c2573458e2d7f4a72223967cb01d/diffrax/_delays.py#L297) and many attributes of Delays(eqx.Module) (https://github.com/thibmonsel/diffrax/blob/7e7d1b443e76c2573458e2d7f4a72223967cb01d/diffrax/_delays.py#L97) could be removed.
