[Feature request] A delay differential equations solver #406

Open
miguelgondu opened this issue Apr 20, 2024 · 4 comments
Labels: feature (New feature)

@miguelgondu

Dear Patrick,

Thanks for this library, it's pretty neat!

I'm currently supervising a bachelor student's thesis on Neural ODEs, and we've been meaning to use diffrax to study delay differential equations. I'm trying to implement a delay differential equation solver inside diffrax, and I would like to know whether I'm on the right track.

For context, let me give a brief overview of how delay differential equations work, and how such a solver could be implemented. In its simplest form, a (constant-delay) Delay Differential Equation (DDE) has a vector field $f$ that depends not only on the current state $y(t)$, but also on $y(t-\tau)$, where $\tau\in\mathbb{R}_{>0}$. In other words,

$$y'(t) = f(t, y(t), y(t-\tau)).$$

Initial value problems involving DDEs provide a history function instead of a single initial value ($y(t) = \phi(t)$ for $t \leq 0$, for example), and are solved in chunks using the "method of steps": briefly, one solves a sequence of IVPs on intervals of the form $[t_0 + k\tau, t_0 + (k+1)\tau]$. (More details in Chap. 9 of this reference.)
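
Concretely, on the first interval $[t_0, t_0 + \tau]$ the delayed argument falls inside the known history, so the DDE reduces to an ordinary IVP,

$$y'(t) = f(t, y(t), \phi(t - \tau)), \qquad t \in [t_0, t_0 + \tau], \qquad y(t_0) = \phi(t_0),$$

and the solution on this interval becomes the history for the next interval, and so on.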

In practice, solving a DDE numerically can be done by selecting the step size such that $y(t-\tau)$ always lies on the grid. To predict $y_{t+1}$ we need to evaluate the vector field at $y_t$ and at some $y_{t-k}$ corresponding to $y(t-\tau)$. Other approaches would involve e.g. building a Hermite interpolation between the two nearest grid points if $y(t-\tau)$ happens to fall off the grid, but I plan to focus on the first alternative.
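
To make the grid-aligned approach concrete, here is a rough sketch (plain Python, not diffrax) of a fixed-step Euler method where the delay is an integer multiple of the step size; euler_dde and its arguments are names I made up for illustration:

import jax.numpy as jnp

def euler_dde(f, phi, t0, t1, dt, tau):
    # Sketch only: fixed-step Euler for y'(t) = f(t, y(t), y(t - tau)),
    # assuming tau is an integer multiple of dt so y(t - tau) stays on the grid.
    k = int(round(tau / dt))  # the delay measured in steps
    ts = jnp.arange(t0, t1 + dt, dt)
    # Pre-fill the history phi on [t0 - tau, t0] so delayed lookups never miss.
    ys = [phi(t0 - (k - i) * dt) for i in range(k + 1)]
    for i in range(len(ts) - 1):
        y, y_delayed = ys[k + i], ys[i]  # y(t_i) and y(t_i - tau)
        ys.append(y + dt * f(ts[i], y, y_delayed))
    return ts, jnp.stack(ys[k:])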

How could I adapt diffrax to let me pass terms of the form f(t, y(t), y(t - tau)) dt? I imagine I have to implement a new DelayTerm that inherits from AbstractTerm with a different vf method; and since solvers need to evaluate these vector fields, I imagine I would also need to modify a solver (or create a new one) in which vf is called with the right signature, right?
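
To show what I have in mind, here is a very rough sketch; DelayTerm, tau, and history are hypothetical names, not existing diffrax API, and threading the history interpolation through the solver is precisely the part I'm unsure about:

from typing import Callable

import jax.tree_util as jtu

import diffrax

class DelayTerm(diffrax.AbstractTerm):
    # Hypothetical sketch; none of these fields exist in diffrax today.
    vector_field: Callable  # f(t, y, y_delayed, args)
    tau: float
    history: Callable  # t -> y(t), assumed valid for t <= current time

    def vf(self, t, y, args):
        return self.vector_field(t, y, self.history(t - self.tau), args)

    def contr(self, t0, t1, **kwargs):
        return t1 - t0  # the same dt control as ODETerm

    def prod(self, vf, control):
        return jtu.tree_map(lambda v: control * v, vf)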

I'm of course happy to contribute the implementation to diffrax once it's up and running.

@patrick-kidger (Owner)

So support for DDEs is something we've been noodling over in Diffrax for a while, see #169.

The main reason that PR stalled is that solving general DDEs requires solving several nonlinear optimisation problems, and at the time Optimistix did not exist yet.

Now that it does, we have been meaning to revisit that PR and fix it up to use the root-finding functionality that is now available in Optimistix.
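
(For reference, the Optimistix root-finding API looks like the following; the residual here is a made-up toy standing in for the kind of problem the DDE code needs to solve, and only optx.Newton and optx.root_find are real:)

import jax.numpy as jnp
import optimistix as optx

def residual(t, args):
    # Toy residual with its root at t = t0 + tau.
    t0, tau = args
    return t - (t0 + tau)

solver = optx.Newton(rtol=1e-8, atol=1e-8)
sol = optx.root_find(residual, solver, jnp.array(0.5), args=(0.0, 1.0))
# sol.value is approximately 1.0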

I must acknowledge that (a) this is fairly technical code, but conversely (b) the hard parts are already written.

If you'd be interested in reviving that PR then this is still a feature I'd be happy to see in Diffrax.

@patrick-kidger added the feature label on Apr 21, 2024
@thibmonsel (Contributor) commented Apr 30, 2024

Hello there,

As mentioned by @patrick-kidger, most of the code is already there (I would say 90%) and functional, but some JAX-related bugs remain (e.g. tracer leakage) that make backpropagation an issue. Happy to discuss if you are interested in lending a hand.

@miguelgondu (Author)

Hi both,

Thanks for the implementation, @thibmonsel! My student and I have been using your dde.ipynb example on a different system of DDEs, and it worked almost out-of-the-box. If I understand correctly, the neuraldde.ipynb example is not finished yet, right?

I would be happy to help, but I fear this is above my skill level (I'm only now starting to use jax). If you give me pointers on how to get started, I could give it a try, but I can't promise much.

@thibmonsel (Contributor)

That's great to hear!
Integrating DDEs itself should be more than robust (so dde.ipynb should work fine).
However, fitting a DDE with a neural net will throw an Exception: Leaked trace when combined with jax.check_tracer_leaks().

An MWE for this could be:

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jrandom
import numpy as np

import diffrax


class Func(eqx.Module):
    linear: eqx.nn.Linear

    def __init__(self, data_size, *, key, **kwargs):
        super().__init__(**kwargs)
        self.linear = eqx.nn.Linear(2 * data_size, data_size, key=key)

    def __call__(self, t, y, args, *, history):
        return self.linear(jnp.hstack([y, *history]))


class NeuralDDE(eqx.Module):
    func: Func
    delays: diffrax.Delays

    def __init__(self, data_size, delays, *, key, **kwargs):
        super().__init__(**kwargs)
        self.func = Func(data_size, key=key)
        self.delays = delays

    def __call__(self, ts, y0):
        solution = diffrax.diffeqsolve(
            diffrax.ODETerm(self.func),
            diffrax.Euler(),
            t0=ts[0],
            t1=ts[-1],
            dt0=ts[1] - ts[0],
            y0=lambda t: y0,
            saveat=diffrax.SaveAt(ts=ts, dense=True),
            adjoint=diffrax.DirectAdjoint(),
            delays=self.delays,
        )
        return solution.ys


@eqx.filter_value_and_grad
def grad_loss(model, ti, yi):
    y_pred = model(ti, yi[0])
    return jnp.mean((yi - y_pred) ** 2)


@eqx.filter_value_and_grad
def grad_loss_batch(model, ti, yi):
    y_pred = jax.vmap(model, (None, 0))(ti, yi[:, 0])
    return jnp.mean((yi - y_pred) ** 2)


if __name__ == "__main__":
    seed = np.random.randint(0, 1000)
    key = jrandom.PRNGKey(seed)
    ts = jnp.linspace(0.0, 1.0, 10)
    ys = jnp.ones_like(ts)[..., None]
    length_size, datasize = ys.shape

    delays = diffrax.Delays(delays=[lambda t, y, args: 1.0])
    model_dde = NeuralDDE(datasize, delays, key=key)

    with jax.check_tracer_leaks():
        loss, grads = grad_loss(model_dde, ts, ys)

    # Batched version
    ys = jnp.concatenate([2 * jnp.ones((1, 10, 1)), 3 * jnp.ones((1, 10, 1))], axis=0)

    # Without the tracer-leak check this runs with no error -- silently side-effecting?
    loss, grads = grad_loss_batch(model_dde, ts, ys)

    # Batched leaked tracer, or a false positive as per the Notes in:
    # https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
    with jax.check_tracer_leaks():
        loss, grads = grad_loss_batch(model_dde, ts, ys)

Happy to discuss on this thread, or in more depth via email (or another medium).
