Additive SDE throws error with SRK style solvers #474

Open
ParticularlyPythonicBS opened this issue Aug 1, 2024 · 8 comments
Labels
question User queries


@ParticularlyPythonicBS
Contributor

Hi,
Can you help me debug why this SDE throws an error with SRK solvers but integrates fine with ERK and Milstein?
Here is a simplified version of the code:

```python
import os
import multiprocessing

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(
    multiprocessing.cpu_count()
)
# os.environ["JAX_TRACEBACK_FILTERING"] = "off"
# os.environ["EQX_ON_ERROR"] = "breakpoint"

import jax
import diffrax as dfx
import jax.numpy as jnp
import time

SEED = 0
KEY = jax.random.PRNGKey(SEED)

m = 0.01  # inertia
gamma = 0.1  # viscosity

amplitude = 0.42  # amplitude of the driving force
omega = 1  # frequency of the driving force
drive_period = 2 * jnp.pi / omega

alpha = -1  # linear spring constant
beta = 1  # cubic spring constant

sigma = 0.123  # noise intensity

x0 = 1.0  # initial position
v0 = 0.0  # initial velocity
state0 = jnp.array([x0, v0])

t_min = 0.0
t_max = 2**10 * drive_period
dt = 2**-8 * drive_period


def functional_duffing(t: float, state: jax.Array, args: list[float]) -> jax.Array:
    # Forced, damped Duffing oscillator written as a first-order system in (x, v).
    x, v = state
    dx = v

    gamma, alpha, beta, amplitude, omega, m = args

    driving = amplitude * jnp.cos(omega * t)
    damping = gamma * v
    spring = alpha * x + beta * x**3
    dv = (driving - damping - spring) / m

    dstate = jnp.array([dx, dv])
    return dstate


KEY, noise_key = jax.random.split(KEY)
term = dfx.ODETerm(functional_duffing)
args = [gamma, alpha, beta, amplitude, omega, m]

brownian_noise = dfx.VirtualBrownianTree(t_min, t_max, tol=1e-3, shape=(), key=noise_key)


def noise(t, y, args):
    # Additive noise: constant diffusion acting on the velocity only.
    return jnp.array([0.0, sigma])


noise_term = dfx.ControlTerm(noise, brownian_noise)
terms = dfx.MultiTerm(term, noise_term)
solver = dfx.ShARK()
saveat = dfx.SaveAt(ts=jnp.arange(t_min, t_max, dt))

begin = time.time()
sol = dfx.diffeqsolve(terms, solver, t_min, t_max, dt, state0, args, saveat=saveat, max_steps=2**20)
end = time.time()
print(f"Elapsed time: {end-begin:.2f} s")
```

throws this error:

```
ValueError: `terms` must be a PyTree of `AbstractTerms` (such as `ODETerm`), with structure diffrax._term.MultiTerm[tuple[diffrax._term.ODETerm, diffrax._term.AbstractTerm[typing.Any, diffrax._custom_types.AbstractSpaceTimeLevyArea]]]
```

but I am already using the MultiTerm(ODETerm, ControlTerm) format, unless I am misunderstanding something.

Also, the same simulation runs much faster in Mathematica (KloedenPlatenSchurz method). Any suggestions on how to speed this up would be very helpful.

Thanks for this great library!

@ParticularlyPythonicBS
Contributor Author

The error was fixed by specifying levy_area=dfx.SpaceTimeLevyArea when constructing the Brownian motion (the dfx.VirtualBrownianTree). Leaving this open in the hope of performance suggestions and possible improvements to the error message.

@patrick-kidger
Owner

Definitely agreed that the error message could be improved. I'd be happy to take a PR on that!

As for performance, you appear to be including the compile time in your timing as well.

@patrick-kidger patrick-kidger added the question User queries label Aug 2, 2024
@ParticularlyPythonicBS
Contributor Author

I would love to submit a PR for this!
This is the error that is currently thrown:

diffrax/diffrax/_integrate.py

Lines 1025 to 1031 in a37a276

```python
if not _term_compatible(
    y0, args, terms, solver.term_structure, solver.term_compatible_contr_kwargs
):
    raise ValueError(
        "`terms` must be a PyTree of `AbstractTerms` (such as `ODETerm`), with "
        f"structure {solver.term_structure}"
    )
```

Should this be raised as a different error, or is it better to augment this error message with a suggestion to check the Levy area?

@patrick-kidger
Owner

I think let's augment this error message.
Tagging @lockwo as I think he may have some idea on this one.

Probably we should write out something quite verbose: in particular, what structure we actually got! And if it's the vector field / control type that goes wrong, we should call that out explicitly. (We probably don't need to mention Levy area anywhere; that will come out naturally from a message of the form f"expected control type {foo} but got control type {bar}".)

@ParticularlyPythonicBS
Contributor Author

That sounds like a great idea, it would make that error more useful even outside the scope of SDE solvers!
I look forward to lockwo's input.

@lockwo
Contributor

lockwo commented Aug 4, 2024

Augmenting the error message is definitely a good idea (related issues: #461, #446). The core problem is that the current message says little about why the terms are failing. To that end, I think a straightforward improvement would be to characterize the failure inside the term checker (https://github.com/patrick-kidger/diffrax/blob/main/diffrax/_integrate.py#L119) and generate the error message from that characterization, which would help people narrow down where the problem is. Additionally, badly formed shapes (which prevent even drift.vf from running correctly) are a fairly common cause of this message, especially for scalar/1D systems where some squeezing and unsqueezing is involved, so raising specific errors when the eval_shape checks fail could be an option too.
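As an illustration of the shape point, a minimal sketch of a pre-flight check built on jax.eval_shape, which traces a function abstractly without doing any real computation. It assumes array-valued (not PyTree) states and vector fields, and the ControlTerm convention that the diffusion returns shape y0.shape followed by the Brownian shape (in the issue: y0 of shape (2,), Brownian shape (), diffusion output (2,)). check_term_shapes is a hypothetical helper, not part of diffrax.

```python
import jax
import jax.numpy as jnp


def check_term_shapes(drift, diffusion, t0, y0, args, brownian_shape=()):
    """Hypothetical pre-flight check; returns a list of human-readable problems."""
    problems = []
    y_shape = jnp.shape(y0)

    # eval_shape only traces the function, so this is cheap even for big systems.
    drift_shape = jax.eval_shape(drift, t0, y0, args).shape
    if drift_shape != y_shape:
        problems.append(
            f"drift returned shape {drift_shape}, expected {y_shape} (the shape of y0)"
        )

    expected = y_shape + tuple(brownian_shape)
    diffusion_shape = jax.eval_shape(diffusion, t0, y0, args).shape
    if diffusion_shape != expected:
        problems.append(
            f"diffusion returned shape {diffusion_shape}, expected {expected} "
            "(y0 shape followed by the Brownian motion shape)"
        )
    return problems


def good_drift(t, y, args):
    return jnp.array([y[1], -y[0]])


def bad_drift(t, y, args):
    return jnp.array([[y[1]], [-y[0]]])  # stray extra axis: shape (2, 1)


def diffusion(t, y, args):
    return jnp.array([0.0, 0.123])


y0 = jnp.array([1.0, 0.0])
print(check_term_shapes(good_drift, diffusion, 0.0, y0, None))
print(check_term_shapes(bad_drift, diffusion, 0.0, y0, None))
```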

For the Levy area cases, these are generally covered by the "expected X but got Y" format, but I think it's worth making them extra clear. The default expected/got message would look not too dissimilar from the one above, with something like: expected control term diffrax._term.AbstractTerm[typing.Any, diffrax._custom_types.AbstractSpaceTimeLevyArea]]] got diffrax._term.AbstractTerm[typing.Any, diffrax._custom_types.AbstractBrownianIncrement]]]. That is reasonably clear, but I'd also add a flag or specific text saying "This solver requires a levy area calculation, you need to add levy_area=diffrax.SpaceTimeLevyArea to your Brownian process", since I think that will be about the second most common error here.

Tangentially, the Levy area docs could probably also be improved: they print a bunch of default attributes that aren't important, and some explanation of what a Levy area is (and why they are integrals over space-time or space-time-time) would probably be beneficial.

Happy to take a crack at the above to show what I mean, or if you'd like to do it, @ParticularlyPythonicBS, that also works.

@ParticularlyPythonicBS
Contributor Author

@lockwo you seem to have expertise with the library that will let you do this much faster than I could. So I would be happy to just follow along.
If you are otherwise occupied, I am happy to take an attempt at it though.

@lockwo lockwo mentioned this issue Aug 7, 2024
@lockwo
Contributor

lockwo commented Aug 7, 2024

My sort of idea: #478
