Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: Implement chmc #644

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 154 additions & 1 deletion blackjax/mcmc/integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
"""Symplectic, time-reversible, integrators for Hamiltonian trajectories."""
from typing import Any, Callable, NamedTuple, Tuple

import chex
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from jax.flatten_util import ravel_pytree

from blackjax.mcmc.metrics import KineticEnergy
from blackjax.types import ArrayTree
from blackjax.types import Array, ArrayTree

__all__ = [
"mclachlan",
Expand All @@ -29,6 +31,7 @@
"isokinetic_leapfrog",
"isokinetic_mclachlan",
"isokinetic_yoshida",
"rattle",
]


Expand Down Expand Up @@ -479,3 +482,153 @@ def _step(args: ArrayTree) -> Tuple[ArrayTree, ArrayTree]:
return IntegratorState(q, p, *logdensity_and_grad_fn(q))

return one_step


@chex.dataclass
class NewtonState:
    """State carried through the Newton iteration in ``solve_newton``."""

    x: ArrayTree  # current iterate (a flat array while inside the solver loop)
    delta: ArrayTree  # last Newton step; its norm drives the convergence test
    n: chex.Scalar  # iteration counter
    aux: ArrayTree  # auxiliary output of the residual function, passed through


def solve_newton(
    func: Callable[[ArrayTree], Tuple[ArrayTree, ArrayTree]],
    x0: ArrayTree,
    *,
    convergence_tol: float = 1e-6,
    max_iters: int = 100,
):
    """Solve ``func(x)[0] == 0`` with Newton's method on a flattened pytree.

    Parameters
    ----------
    func
        Residual function mapping a pytree to ``(residual, aux)``. ``aux`` is
        carried along unchanged and returned in the final state.
    x0
        Initial guess, a pytree with the same total size as the residual.
    convergence_tol
        Stop once the Euclidean norm of the Newton step falls below this.
    max_iters
        Hard cap on the number of Newton iterations.

    Returns
    -------
    NewtonState
        Final state with ``x`` (solution) and ``delta`` (last step)
        unflattened back to the structure of ``x0``.
    """
    x0_flat, unflatten = ravel_pytree(x0)

    def surrogate_func(x_flat):
        # Work on the flat representation so the Jacobian is a plain matrix.
        residual, aux = func(unflatten(x_flat))
        residual_flat, _ = ravel_pytree(residual)
        return residual_flat, aux

    jacobian_fn = jax.jacobian(surrogate_func, has_aux=True)

    def step_fun(state: NewtonState) -> NewtonState:
        jac, _ = jacobian_fn(state.x)
        residual, aux = surrogate_func(state.x)

        # Newton update: solve J @ delta = -F.
        delta = jnp.linalg.solve(jac, -residual)
        return NewtonState(x=state.x + delta, delta=delta, n=state.n + 1, aux=aux)

    def cond(state: NewtonState):
        return jnp.logical_and(
            state.n < max_iters, jnp.linalg.norm(state.delta) > convergence_tol
        )

    # NOTE: ``delta`` is seeded with the initial guess, so the loop body never
    # runs when ``norm(x0) <= convergence_tol`` — callers relying on at least
    # one iteration must pass a non-trivial starting point.
    sol = jax.lax.while_loop(
        cond,
        step_fun,
        NewtonState(
            x=x0_flat, delta=x0_flat, n=jnp.zeros((), dtype=jnp.int32), aux=func(x0)[1]
        ),
    )
    return sol.replace(x=unflatten(sol.x), delta=unflatten(sol.delta))


class RattleVars(NamedTuple):
    """Unknowns solved for implicitly in one RATTLE step."""

    p_1_2: Array  # Midpoint momentum
    q_1: Array  # Final position
    lam: Array  # Lagrange multiplier enforcing the position constraint (state)
    p_1: Array  # Final momentum
    mu: Array  # Lagrange multiplier enforcing the momentum (tangency) constraint


def rattle(
    logdensity_fn: Callable,
    kinetic_energy_fn: KineticEnergy,
    constrain_fn: Callable,
    *,
    solver: Callable = solve_newton,
    **solver_kwargs: Any,
) -> Integrator:
    """RATTLE integrator for Hamiltonian dynamics on an implicit manifold.

    Each step solves the nonlinear RATTLE system (via ``solver``) for the
    midpoint momentum, the final position and momentum, and the two Lagrange
    multipliers that keep the trajectory on the zero level set of
    ``constrain_fn`` (see :class:`RattleVars`).

    Symplectic method. Does not support adaptive step sizing. Uses 1st order local
    linear interpolation for dense/ts output.

    Parameters
    ----------
    logdensity_fn
        Log-density of the target distribution (negative potential energy).
    kinetic_energy_fn
        Kinetic energy; may depend on the position (non-separable Hamiltonian).
    constrain_fn
        Constraint function whose zero level set defines the manifold.
    solver
        Nonlinear root solver with the same calling convention as
        ``solve_newton``; must return a state with ``x`` and ``aux`` fields.
    solver_kwargs
        Extra keyword arguments forwarded to ``solver``.
    """
    logdensity_and_grad_fn = jax.value_and_grad(logdensity_fn)
    kinetic_energy_grad_fn = jax.grad(
        lambda q, p: kinetic_energy_fn(p, position=q), argnums=(0, 1)
    )

    def one_step(state: IntegratorState, step_size: float) -> IntegratorState:
        q0, p0, _, _ = state
        h = 0.5 * step_size

        def eq(x: RattleVars) -> tuple:
            # Constraint-Jacobian VJPs at the initial and final positions,
            # used to apply the Lagrange-multiplier forces.
            _, vjp_fun = jax.vjp(constrain_fn, q0)
            _, vjp_fun_mu = jax.vjp(constrain_fn, x.q_1)

            dUdq = state.logdensity_grad

            # dH/dq at the initial point: position-dependent kinetic term
            # minus the log-density gradient.
            dTdq, _ = kinetic_energy_grad_fn(q0, p0)
            dHdq = jax.tree_util.tree_map(jnp.subtract, dTdq, dUdq)

            # Gradients at the midpoint momentum and the final position.
            # TODO check whether dTdq12 should be evaluated at x.q_1 instead.
            dTdq12, dHdp12 = kinetic_energy_grad_fn(q0, x.p_1_2)
            Uq1, dUdq1 = logdensity_and_grad_fn(x.q_1)
            dHdq12 = jtu.tree_map(jnp.subtract, dTdq12, dUdq1)

            # The five residuals of the RATTLE system; all zero at a solution.
            zero = (
                # 1. constrained first half-kick defining p_{1/2}
                jtu.tree_map(
                    lambda _p0, _dhdq, _dcl, _p12: _p0 - h * (_dhdq + _dcl) - _p12,
                    p0,
                    dHdq,
                    vjp_fun(x.lam)[0],
                    x.p_1_2,
                ),
                # 2. drift defining q_1 (trapezoidal in dH/dp)
                jtu.tree_map(
                    lambda _q0, _dhdp0, _dhdp1, _q1: _q0 + h * (_dhdp0 + _dhdp1) - _q1,
                    q0,
                    dHdp12,
                    kinetic_energy_grad_fn(x.q_1, x.p_1_2)[1],
                    x.q_1,
                ),
                # 3. position constraint at q_1
                constrain_fn(x.q_1),
                # 4. constrained second half-kick defining p_1
                jtu.tree_map(
                    lambda _p12, _dhdq, _dc, _p1: _p12 - h * (_dhdq + _dc) - _p1,
                    x.p_1_2,
                    dHdq12,
                    vjp_fun_mu(x.mu)[0],
                    x.p_1,
                ),
                # 5. velocity tangency constraint at (q_1, p_1)
                jax.jvp(
                    constrain_fn, (x.q_1,), (kinetic_energy_grad_fn(x.q_1, x.p_1)[1],)
                )[1],
            )

            # Carry the log-density and gradient at q_1 as auxiliary output so
            # they need not be recomputed after the solve.
            return zero, (Uq1, dUdq1)

        # Static shape of the constraint output, used to size the multipliers.
        cs = jax.eval_shape(constrain_fn, q0)

        init_vars = RattleVars(
            p_1_2=p0,
            # TODO check better starting point
            q_1=q0,
            p_1=p0,
            lam=jtu.tree_map(jnp.zeros_like, cs),  # TODO keep this in a state
            mu=jtu.tree_map(jnp.zeros_like, cs),
        )

        sol = solver(eq, init_vars, **solver_kwargs)
        Uq1, dUdq1 = sol.aux
        next_state = IntegratorState(
            position=sol.x.q_1,
            momentum=sol.x.p_1,
            logdensity=Uq1,
            logdensity_grad=dUdq1,
        )
        return next_state

    return one_step
89 changes: 89 additions & 0 deletions blackjax/mcmc/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@
We can also generate a relativistic dynamic :cite:p:`lu2017relativistic`.

"""
from functools import partial
from typing import Callable, NamedTuple, Optional, Protocol, Union

import jax
import jax.numpy as jnp
import jax.scipy as jscipy
from jax.flatten_util import ravel_pytree
Expand Down Expand Up @@ -284,3 +286,90 @@ def is_turning(
# return turning_at_left | turning_at_right

return Metric(momentum_generator, kinetic_energy, is_turning)



def gaussian_implicit_riemannian(
    mass_matrix_fn: Callable,
    constrain_fn: Callable
) -> Metric:
    """Gaussian Riemannian metric restricted to an implicit manifold.

    The momentum is drawn from a Gaussian with position-dependent mass matrix
    ``mass_matrix_fn(position)`` and then projected onto the cotangent space
    of the constraint ``constrain_fn(position) == 0``. The kinetic energy
    includes the pseudo-log-determinant correction for the projected mass.

    Parameters
    ----------
    mass_matrix_fn
        Maps a position to a dense, positive-definite mass matrix.
    constrain_fn
        Constraint function defining the manifold as its zero level set.
    """

    def factorize_mass(position: ArrayLikeTree):
        # Cholesky factor of the mass matrix and its full inverse, the latter
        # obtained via two triangular solves against the identity.
        M = mass_matrix_fn(position)
        cholesky = jscipy.linalg.cholesky(M, True)
        inverse = jscipy.linalg.solve_triangular(
            cholesky.T,
            jscipy.linalg.solve_triangular(cholesky, jnp.eye(*M.shape), lower=True),
            lower=False,
        )
        return cholesky, inverse

    @partial(jax.vmap, in_axes=(None, 1), out_axes=1)
    def jmp(x, v):
        """Jacobian-matrix product: each column of ``v`` pushed through the
        Jacobian of ``constrain_fn`` at ``x`` via a JVP."""
        return jax.jvp(constrain_fn, (x,), (v,))[1]

    # https://github.com/krzysztofrusek/jax_chmc/blob/d8c12e4b55b8a9877228de1c130937a971de5b52/jax_chmc/kernels.py#L81
    def momentum_generator(rng_key: PRNGKey,
                           position: ArrayLikeTree) -> ArrayLikeTree:
        flat_position, unflatten = ravel_pytree(position)
        cholesky, inverse = factorize_mass(position)

        # Sample p0 ~ N(0, M) via its Cholesky factor.
        z = jax.random.normal(rng_key, shape=flat_position.shape)
        p0 = cholesky @ z

        # D = dc/dq . M^-1
        # Jacobian matrix product, TODO handle diagonal and scalar mass
        D = jmp(position, inverse)

        # Project onto the cotangent space of the constraint:
        # p0 <- (I - D^T (D D^T)^{-1} D) p0.
        # TODO check jaxopt projection here
        p0 = p0 - D.T @ jnp.linalg.solve(D @ D.T, D @ p0)
        return unflatten(p0)

    # https://github.com/krzysztofrusek/jax_chmc/blob/d8c12e4b55b8a9877228de1c130937a971de5b52/jax_chmc/kernels.py#L54C1-L62C1
    def kinetic_energy(
        momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None
    ) -> float:
        cholesky, inverse = factorize_mass(position)
        flat_p, _ = ravel_pytree(momentum)

        # Cholesky factor projected onto the constraint cotangent space.
        D = jmp(position, inverse)
        cholMhat = cholesky - D.T @ jnp.linalg.solve(D @ D.T, D @ cholesky)
        # NOTE(review): ``cholMhat`` is not symmetric in general, so
        # ``hermitian=True`` (eigendecomposition path) may not yield its true
        # singular values — confirm against the reference implementation.
        d = jnp.linalg.svd(cholMhat, compute_uv=False, hermitian=True)

        def _shape_fn(position):
            x, _ = ravel_pytree(position)
            c, _ = ravel_pytree(constrain_fn(position))
            return (x, c)

        # Static sizes: n = dim(position), m = dim(constraint output).
        dc_shape = jax.eval_shape(_shape_fn, position)

        # Pseudo log-determinant: sum of logs of the n - m largest singular
        # values (the projection zeroes out the remaining m directions).
        top_d, _ = jax.lax.top_k(d, dc_shape[0].shape[0] - dc_shape[1].shape[0])
        pseudo_log_det = jnp.sum(jnp.log(top_d))

        # Quadratic form p^T M^{-1} p / 2 plus the log-det correction.
        T = flat_p.T @ inverse @ flat_p / 2.0
        return T + pseudo_log_det

    def is_turning(
        momentum_left: ArrayLikeTree,
        momentum_right: ArrayLikeTree,
        momentum_sum: ArrayLikeTree,
        position_left: Optional[ArrayLikeTree] = None,
        position_right: Optional[ArrayLikeTree] = None,
    ) -> bool:
        """Always raises: NUTS is not supported on implicit manifolds."""
        del momentum_left, momentum_right, momentum_sum, position_left, position_right
        raise NotImplementedError(
            "NUTS sampling is not yet implemented for implicitly defined "
            "manifolds"
        )

    return Metric(momentum_generator, kinetic_energy, is_turning)
Loading