Feature Request: Complex-Valued Integration With ZVODE - CVODE in Jax (autodiff) #477

Open
onurdanaci opened this issue Aug 6, 2024 · 5 comments

@onurdanaci

Hi,

Unfortunately, all the (S)ODE integration routines available in auto-differentiable Python frameworks (RK45, Dopri, etc.) behave very poorly with complex-valued functions [*]. In the Python ecosystem, only SciPy's Fortran wrappers, ode (ZVODE) and complex_ode (using CVODE), seem to work reliably, but of course they are not differentiable and so not usable in the modern applications we love.
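
For reference, this is roughly what the requested workflow looks like with SciPy's ZVODE today (a minimal sketch with an arbitrary test matrix; not differentiable):

import numpy as np
from scipy.integrate import ode

A = np.array([[0.0 - 1j, 1.0 + 2j], [-100.0 + 3j, 0.0 + 4j]])

# Complex-valued RHS, integrated directly by the ZVODE (implicit Adams) backend.
r = ode(lambda t, y: A @ y).set_integrator("zvode", method="adams")
r.set_initial_value(np.array([1.0 + 0j, 0.0 + 0j]), 0.0)
y_final = r.integrate(1.0)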

I was wondering if anybody wants to adapt these features to Diffrax, and make them auto-differentiable.

[*] https://arxiv.org/abs/2406.06361

@patrick-kidger
Owner

We have some limited support for complex numbers in Diffrax. In particular I think all of the explicit solvers (Tsit5, Dopri etc.) should behave correctly. Glancing at the paper I can see they briefly mention Diffrax, but apparently indicate they had some difficulty getting reverse-mode working. I've not seen a bug report from them though so there's not much I can do until then. 🤷

More importantly though, I believe this whole thing is essentially a non-issue. It's trivial to make any real integrator work with complex numbers: just split into real and imaginary parts before passing your initial condition into the solver, and then combine them back together inside your vector field. Job done.
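
A minimal sketch of that split-and-recombine trick (not from the library docs; the matrix, tolerances, and the stacked real/imaginary layout are arbitrary example choices):

import diffrax
import jax.numpy as jnp

def complex_vf(t, y, args):
    A = args                               # complex system matrix
    y_complex = y[0] + 1j * y[1]           # recombine real/imaginary parts
    dy = A @ y_complex                     # the actual complex-valued RHS
    return jnp.stack([dy.real, dy.imag])   # split again for the real solver

A = jnp.array([[0.0 - 1j, 1.0 + 2j], [-100.0 + 3j, 0.0 + 4j]])
y0_complex = jnp.array([1.0 + 0j, 0.0 + 0j])
y0 = jnp.stack([y0_complex.real, y0_complex.imag])   # real-valued state

sol = diffrax.diffeqsolve(
    diffrax.ODETerm(complex_vf),
    diffrax.Tsit5(),
    t0=0.0, t1=1.0, dt0=0.01, y0=y0, args=A,
    stepsize_controller=diffrax.PIDController(rtol=1e-8, atol=1e-8),
)
ys = sol.ys[:, 0, :] + 1j * sol.ys[:, 1, :]           # back to complex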

@patrick-kidger patrick-kidger added the question User queries label Aug 6, 2024
@sriharikrishna

Hi. Thanks to @onurdanaci for asking the question and to @lockwo for pointing this question out to me. Apologies to @patrick-kidger for not posting the issue earlier (I am the author of the document mentioned above).

I have an MWE below. I would be happy to be told that this issue is minor or that I am using Diffrax incorrectly.

import diffrax
from diffrax import diffeqsolve, ODETerm, Tsit5
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)

def solver(y0, t_f, A, use_direct):
  def ode_fn(t, y, B):
    return jnp.matmul(B[0],y)
  term = ODETerm(ode_fn)
  ODEsolver = Tsit5()
  solver_args = dict(t0=0.0, t1=t_f.real, dt0=0.2, y0=y0, args=(A,))
  #Required for forward mode only
  if use_direct:
    solver_args |= dict(adjoint=diffrax.DirectAdjoint())
  solution = diffrax.diffeqsolve(term, ODEsolver, **solver_args)
  return solution.ys[0]

def driver(params, use_direct):
  #Create y0 from params
  time = params[0] * jnp.pi
  cos = jnp.cos(time / 2)
  sin = jnp.sin(time / 2)
  axis_angle = params[1] * jnp.pi
  KET_0 = jnp.array([1, 0], dtype=jnp.complex128)  # |0>, spin up
  KET_1 = jnp.array([0, 1], dtype=jnp.complex128)  # |1>, spin down
  y0 = cos * KET_1 - 1j * jnp.exp(-1j * axis_angle) * sin * KET_0

  A = jnp.array([[0 - 1j, 1.0 + 2j],
                 [-100.0 + 3j, 0 + 4j]], dtype=jnp.complex128)
  
  #Evolve y0. Time is influenced by params
  y = solver(y0, time, A, use_direct)
  return y

params_f = jnp.array([0.5,0.4], dtype=jnp.float64)
jacfwd_fun = jax.jacfwd(driver, argnums=(0))
jac_f = jacfwd_fun(params_f, True)

#Must be complex for reverse mode
params_b = jnp.array([0.5+0j,0.4+0j], dtype=jnp.complex128)
#Must set holomorphic=True for reverse mode
jacrev_fun = jax.jacrev(driver, argnums=(0), holomorphic=True)
jac_b = jacrev_fun(params_b, False)

print(jac_f-jac_b)

This generates

/usr/local/lib/python3.10/dist-packages/equinox/_jit.py:51: UserWarning: Complex dtype support is work in progress, please read https://github.com/patrick-kidger/diffrax/pull/197 and proceed carefully.
  out = fun(*args, **kwargs)
/usr/local/lib/python3.10/dist-packages/equinox/_jit.py:51: UserWarning: Complex dtype support is work in progress, please read https://github.com/patrick-kidger/diffrax/pull/197 and proceed carefully.
  out = fun(*args, **kwargs)
[[7.91624188e-09+1.36585934e+06j 1.16415322e-09-2.06637196e-09j]
 [4.65661287e-08+6.69479738e+06j 1.54832378e-08-1.45519152e-11j]]

The problem might be that params influences both the initial state y0 of the solver and the integration time t1.

Thanks for your attention and help!

@onurdanaci
Author

onurdanaci commented Aug 23, 2024

Dear Patrick @patrick-kidger ,

Thank you for your answer. Indeed, I can transform my complex-valued system of equations into:

dv/dt = M @ v
v = v_real + 1j * v_imag
M = M_real + 1j * M_imag

d/dt [v_real; v_imag] = [[M_real, -M_imag]; [M_imag, M_real]] @ [v_real; v_imag]

Then combine the two parts again in post-processing (a quick sanity check of this identity is sketched below). Of course, it would have been much more convenient for the quantum technologies community to have these features pre-defined in libraries, but I agree that this part is a non-issue.
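
A quick sanity check of that block-matrix identity (a sketch; the matrix and vector values are arbitrary):

import jax.numpy as jnp

def real_block(M):
    # Embed a complex matrix M into the equivalent real block matrix
    # [[Re M, -Im M], [Im M, Re M]] acting on the stacked vector [Re v; Im v].
    return jnp.block([[M.real, -M.imag], [M.imag, M.real]])

M = jnp.array([[0.0 - 1j, 1.0 + 2j], [-100.0 + 3j, 0.0 + 4j]])
v = jnp.array([1.0 + 0j, 0.0 + 0j])

lhs = real_block(M) @ jnp.concatenate([v.real, v.imag])
rhs = M @ v
assert jnp.allclose(lhs, jnp.concatenate([rhs.real, rhs.imag]))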

However, I am still suspicious. SciPy's VODE subroutine, which was inherited from Fortran libraries, uses implicit multi-step Adams methods (such as Adams-Moulton) for non-stiff problems and BDF for stiff problems. I couldn't parse all the archaic Fortran code, but my suspicion is that SciPy's ZVODE just uses this VODE library together with your vector-field trick.

Based on some small (but not systematic, elaborate, or conclusive by any metric) numerical experiments and the paper I shared above, I have doubts that the crème de la crème explicit Runge-Kutta methods y'all provide, such as Tsit5 and Dopri5, would work as well for these non-stiff quantum problems as implicit Adams, or that KenCarp4 would be as good as BDF for the stiff ones. Maybe I am wrong. I will need to run them on some important unit tests to make sure that I do not get non-physical results. I will get back to you.
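
For what it's worth, trying one of Diffrax's stiff solvers for such a test only requires swapping the solver argument. A minimal sketch, using Kvaerno5 purely as an illustrative implicit solver, with the same real/imaginary split as above and arbitrary tolerances:

import diffrax
import jax.numpy as jnp

def vf(t, y, args):
    A = args
    yc = y[0] + 1j * y[1]                  # recombine real/imaginary parts
    dy = A @ yc
    return jnp.stack([dy.real, dy.imag])

A = jnp.array([[0.0 - 1j, 1.0 + 2j], [-100.0 + 3j, 0.0 + 4j]])
y0 = jnp.stack([jnp.array([1.0, 0.0]), jnp.array([0.0, 0.0])])

sol = diffrax.diffeqsolve(
    diffrax.ODETerm(vf),
    diffrax.Kvaerno5(),                    # implicit solver in place of Tsit5
    t0=0.0, t1=1.0, dt0=0.01, y0=y0, args=A,
    stepsize_controller=diffrax.PIDController(rtol=1e-8, atol=1e-8),
)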

@patrick-kidger
Owner

Thank you @sriharikrishna for the MWE! That's really useful. I'm going to tag @Randl as our resident complex autodiff expert. Any thoughts?

Other than that, thank you @onurdanaci for your write-up above! I'd like it if Diffrax could be useful to you regardless :)

@Randl
Contributor

Randl commented Aug 25, 2024

@sriharikrishna
Isn't the mismatch simply because, in the first case, you calculate the gradient with respect to a real parameter, which is automatically real, whereas in the second case the gradient is with respect to a complex parameter and thus also has an imaginary part? I've tried running check_grads on a function equivalent to yours:

from functools import partial

import jax
import jax.numpy as jnp
import pytest
from jax.test_util import check_grads

import diffrax
from diffrax import ODETerm

jax.config.update("jax_enable_x64", True)

@pytest.mark.parametrize(
    "solver",
    [
        diffrax.Tsit5(),
    ],
)
def test_grad_complex(solver):

    def ode_fn(t, y, B):
        return jnp.matmul(B[0], y)

    term = ODETerm(ode_fn)
    @partial(jax.jit)
    def driver(pt, ang):
        # Create y0 from params
        time = pt * jnp.pi
        cos = jnp.cos(time / 2)
        sin = jnp.sin(time / 2)
        axis_angle = ang * jnp.pi
        KET_0 = jnp.array([1, 0], dtype=jnp.complex128)  # |0>, spin up
        KET_1 = jnp.array([0, 1], dtype=jnp.complex128)  # |1>, spin down
        y0 = cos * KET_1 - 1j * jnp.exp(-1j * axis_angle) * sin * KET_0
        jax.debug.print("{y0}",y0=y0)

        A = jnp.array([[0 - 1j, 1.0 + 2j],
                       [- 100.0 + 3j, 0 + 4j]], dtype=jnp.complex128)

        solver_args = dict(t0=0.0, t1=time.real, dt0=0.2, y0=y0, args=(A,))
        # # Required for forward mode only
        # if use_direct == True:
        solver_args |= dict(adjoint=diffrax.DirectAdjoint())
        # Evolve y0. Time is influenced by params
        solution = diffrax.diffeqsolve(term, solver, **solver_args)
        return solution.ys[0]


    # check_grads(driver, (0.5,0.4), order=2, modes=["fwd"])
    check_grads(driver, (0.5+0.j,0.4+0.j), order=2, modes=["rev"], atol=1e15)

Apart from the fact that the absolute differences are huge in the rev case, I couldn't see a failure. If you could point out a mismatch versus the numerical gradients (alternatively, there may be a bug in the solver itself, which would make both the analytic and numeric gradients wrong), that would be helpful.
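
For reference, a rough central-difference check against the forward-mode Jacobian from the MWE above could look like this (a sketch; eps is an arbitrary choice, and driver, params_f, and jac_f are the names from that MWE):

import jax.numpy as jnp

def fd_jacobian(f, x, eps=1e-6):
    # Central-difference Jacobian of f (possibly complex-valued) at a real point x.
    cols = []
    for i in range(x.shape[0]):
        e = jnp.zeros_like(x).at[i].set(eps)
        cols.append((f(x + e) - f(x - e)) / (2 * eps))
    return jnp.stack(cols, axis=-1)

jac_fd = fd_jacobian(lambda p: driver(p, True), params_f)
print(jnp.max(jnp.abs(jac_f - jac_fd)))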
