EQX_ON_ERROR=breakpoint and diffeqsolve causes tracer leakage ? #512

Open
thibmonsel opened this issue Oct 10, 2024 · 2 comments

@thibmonsel
Contributor

thibmonsel commented Oct 10, 2024

Hi Patrick, I'm getting weird behavior with HEAD when using EQX_ON_ERROR=breakpoint.
Here is an MWE:

import jax
import jax.numpy as jnp
import diffrax

with jax.checking_leaks():

    ts = jnp.linspace(0.0, 1.0, 10)
    ys = diffrax.diffeqsolve(
        diffrax.ODETerm(lambda t, y, args: -y),
        diffrax.Bosh3(),
        t0=ts[0],
        t1=ts[-1],
        dt0=ts[1] - ts[0],
        y0=jnp.ones((1, )),
        saveat=diffrax.SaveAt(ts=ts),
    )

Running the file once works fine.
If I then run export EQX_ON_ERROR=breakpoint in another terminal and re-run the MWE there, the traceback reports leaked tracers (and EQX_ON_ERROR=breakpoint doesn't open a jax.debug.breakpoint where the error arises):

Traceback (most recent call last):
  File "/home/monsel/Desktop/sandbox_diffrax/mwe.py", line 10, in <module>
    ys = diffrax.diffeqsolve(
  File "/home/monsel/miniconda3/envs/sandbox_diffrax/lib/python3.10/site-packages/equinox/_jit.py", line 239, in __call__
    return self._call(False, args, kwargs)
  File "/home/monsel/miniconda3/envs/sandbox_diffrax/lib/python3.10/site-packages/equinox/_module.py", line 1093, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
  File "/home/monsel/miniconda3/envs/sandbox_diffrax/lib/python3.10/site-packages/equinox/_jit.py", line 212, in _call
    out = self._cached(dynamic_donate, dynamic_nodonate, static)
  File "/home/monsel/miniconda3/envs/sandbox_diffrax/lib/python3.10/site-packages/equinox/_errors.py", line 187, in fixed_jit_impl
    return jit_fun(*args2, **kwargs2)
  File "/home/monsel/miniconda3/envs/sandbox_diffrax/lib/python3.10/contextlib.py", line 142, in __exit__
    next(self.gen)
Exception: Leaked trace MainTrace(1,DynamicJaxprTrace). Leaked tracer(s):

Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/monsel/Desktop/sandbox_diffrax/mwe.py:10 (<module>)
<DynamicJaxprTracer 139338190526784> is referred to by <function 139338169797072> (_allocate_output) closed-over variable y0
<function 139338169797072> is referred to by <list 139338138048256>[11]
<list 139338138048256> is referred to by <tuple 139338137653056>[1]

Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/monsel/Desktop/sandbox_diffrax/mwe.py:10 (<module>)
<DynamicJaxprTracer 139338169959664> is referred to by <function 139338169797072> (_allocate_output) closed-over variable t0
<function 139338169797072> is referred to by <list 139338138048256>[11]
<list 139338138048256> is referred to by <tuple 139338137653056>[1]

Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/monsel/Desktop/sandbox_diffrax/mwe.py:10 (<module>)
<DynamicJaxprTracer 139338169486448> is referred to by <function 139338169797072> (_allocate_output) closed-over variable direction
<function 139338169797072> is referred to by <list 139338138048256>[11]
<list 139338138048256> is referred to by <tuple 139338137653056>[1]

Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/monsel/Desktop/sandbox_diffrax/mwe.py:10 (<module>)
<DynamicJaxprTracer 139338190525104> is referred to by <function 139338169246064> (_check_subsaveat_ts) closed-over variable t1
<function 139338169246064> is referred to by <list 139338138048256>[5]
<list 139338138048256> is referred to by <tuple 139338137653056>[1]

equinox==0.11.7
jax==0.4.34
jaxlib==0.4.34
jaxtyping==0.2.34
lineax==0.0.6
ml_dtypes==0.5.0
numpy==2.1.2
opt_einsum==3.4.0
optimistix==0.0.8
scipy==1.14.1
typeguard==2.13.3
typing_extensions==4.12.2
@patrick-kidger
Owner

This is probably a variant of jax-ml/jax#16732.

Equinox already has a workaround for the specific case reported there (when EQX_ON_ERROR=breakpoint is set, we monkey-patch jax.jit to conditionally disable it), but I believe the same problem can also occur with some other JAX operations, like jax.custom_vjp.

I think what this really needs is for someone to fix this in JAX itself, unfortunately.

Other than that, you can try setting JAX_DISABLE_JIT=1 to sidestep the issue.
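
For reference, here is a minimal sketch of what disabling JIT changes, using only JAX itself. The names `decay` and `saw_tracer` are illustrative and not part of the MWE above. With JIT disabled, a jitted function runs eagerly on concrete arrays, so its body never sees a tracer. `jax.disable_jit()` is the in-process counterpart of the JAX_DISABLE_JIT=1 environment variable; whether it avoids the leak in this particular Diffrax setup is an assumption that has not been checked here.

```python
import jax
import jax.numpy as jnp

saw_tracer = []  # illustrative: records whether each call received a tracer


@jax.jit
def decay(y):
    # This Python body runs at trace time under jit,
    # and on every call when jit is disabled.
    saw_tracer.append(isinstance(y, jax.core.Tracer))
    return -y


y0 = jnp.ones((1,))

decay(y0)  # normal jit: the body is traced with an abstract tracer

with jax.disable_jit():
    decay(y0)  # jit disabled: the body runs eagerly on a concrete array

print(saw_tracer)
```

Running everything eagerly is usually much slower, so this is best treated as a debugging aid rather than a fix.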

@patrick-kidger added the question User queries label Oct 10, 2024
@thibmonsel
Contributor Author

thibmonsel commented Oct 10, 2024

Thanks for the clear explanation! I'll give JAX_DISABLE_JIT=1 a try.
This bug is definitely misleading.
