Launching the file once is OK.
Then, if I run export EQX_ON_ERROR=breakpoint in another terminal and re-run the MWE, the traceback reports leaked tracers (and EQX_ON_ERROR=breakpoint doesn't open a jax.debug.breakpoint where the error arises):
Traceback (most recent call last):
File "/home/monsel/Desktop/sandbox_diffrax/mwe.py", line 10, in <module>
ys = diffrax.diffeqsolve(
File "/home/monsel/miniconda3/envs/sandbox_diffrax/lib/python3.10/site-packages/equinox/_jit.py", line 239, in __call__
return self._call(False, args, kwargs)
File "/home/monsel/miniconda3/envs/sandbox_diffrax/lib/python3.10/site-packages/equinox/_module.py", line 1093, in __call__
return self.__func__(self.__self__, *args, **kwargs)
File "/home/monsel/miniconda3/envs/sandbox_diffrax/lib/python3.10/site-packages/equinox/_jit.py", line 212, in _call
out = self._cached(dynamic_donate, dynamic_nodonate, static)
File "/home/monsel/miniconda3/envs/sandbox_diffrax/lib/python3.10/site-packages/equinox/_errors.py", line 187, in fixed_jit_impl
return jit_fun(*args2, **kwargs2)
File "/home/monsel/miniconda3/envs/sandbox_diffrax/lib/python3.10/contextlib.py", line 142, in __exit__
next(self.gen)
Exception: Leaked trace MainTrace(1,DynamicJaxprTrace). Leaked tracer(s):
Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/monsel/Desktop/sandbox_diffrax/mwe.py:10 (<module>)
<DynamicJaxprTracer 139338190526784> is referred to by <function 139338169797072> (_allocate_output) closed-over variable y0
<function 139338169797072> is referred to by <list 139338138048256>[11]
<list 139338138048256> is referred to by <tuple 139338137653056>[1]
Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/monsel/Desktop/sandbox_diffrax/mwe.py:10 (<module>)
<DynamicJaxprTracer 139338169959664> is referred to by <function 139338169797072> (_allocate_output) closed-over variable t0
<function 139338169797072> is referred to by <list 139338138048256>[11]
<list 139338138048256> is referred to by <tuple 139338137653056>[1]
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/monsel/Desktop/sandbox_diffrax/mwe.py:10 (<module>)
<DynamicJaxprTracer 139338169486448> is referred to by <function 139338169797072> (_allocate_output) closed-over variable direction
<function 139338169797072> is referred to by <list 139338138048256>[11]
<list 139338138048256> is referred to by <tuple 139338137653056>[1]
Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/monsel/Desktop/sandbox_diffrax/mwe.py:10 (<module>)
<DynamicJaxprTracer 139338190525104> is referred to by <function 139338169246064> (_check_subsaveat_ts) closed-over variable t1
<function 139338169246064> is referred to by <list 139338138048256>[5]
<list 139338138048256> is referred to by <tuple 139338137653056>[1]
Equinox already has a workaround for the specific reported version above (when EQX_ON_ERROR=breakpoint is set then we monkey-patch jax.jit to conditionally disable it), but I believe it can also occur for some other JAX operations, like jax.custom_vjp.
I think what this really needs is someone to fix this in JAX itself, unfortunately.
Other than that, you can try setting JAX_DISABLE_JIT=1 and sidestep the issue that way.
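For reference, here is a minimal sketch of what disabling JIT changes. It uses the `jax.disable_jit()` context manager as an in-process stand-in for the `JAX_DISABLE_JIT=1` environment variable (that substitution is an assumption made for illustration, and the function `double` and the list `seen` are hypothetical names, not part of the MWE). With JIT disabled, the function body runs eagerly on concrete arrays rather than tracers, so ordinary Python debugging tools can inspect values.

```python
import jax
import jax.numpy as jnp

seen = []  # concrete values observed inside the function (illustrative only)


@jax.jit
def double(x):
    # float() needs a concrete value; on a traced value it would raise an error.
    seen.append(float(x[0]))
    return x * 2.0


x = jnp.array([1.0, 2.0, 3.0])

# In-process equivalent (assumed here) of launching with JAX_DISABLE_JIT=1:
# inside this block, jax.jit-decorated functions execute op by op.
with jax.disable_jit():
    y = double(x)

print(seen, y)
```

Expect this to run much more slowly on a real solve, so it is probably best kept for debugging sessions only.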
Hi Patrick, I'm getting weird behavior with HEAD when it is used with EQX_ON_ERROR=breakpoint. Here is an MWE: