diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py
index f4009b16e..557924ec3 100644
--- a/blackjax/mcmc/integrators.py
+++ b/blackjax/mcmc/integrators.py
@@ -14,12 +14,14 @@
 """Symplectic, time-reversible, integrators for Hamiltonian trajectories."""
 from typing import Any, Callable, NamedTuple, Tuple
 
+import chex
 import jax
 import jax.numpy as jnp
+import jax.tree_util as jtu
 from jax.flatten_util import ravel_pytree
 
 from blackjax.mcmc.metrics import KineticEnergy
-from blackjax.types import ArrayTree
+from blackjax.types import Array, ArrayTree
 
 __all__ = [
     "mclachlan",
@@ -29,6 +31,7 @@
     "isokinetic_leapfrog",
     "isokinetic_mclachlan",
     "isokinetic_yoshida",
+    "rattle",
 ]
 
 
@@ -479,3 +482,153 @@ def _step(args: ArrayTree) -> Tuple[ArrayTree, ArrayTree]:
         return IntegratorState(q, p, *logdensity_and_grad_fn(q))
 
     return one_step
+
+
+@chex.dataclass
+class NewtonState:
+    x: ArrayTree
+    delta: ArrayTree
+    n: chex.Scalar
+    aux: ArrayTree
+
+
+def solve_newton(
+    func: Callable[[ArrayTree], Tuple[ArrayTree, ArrayTree]],
+    x0: ArrayTree,
+    *,
+    convergence_tol: float = 1e-6,
+    # divergence_tol: float = 1e10,
+    max_iters: int = 100,
+    # norm_fn: Callable[[ArrayTree], float] = lambda x: jnp.max(jnp.abs(x)),
+):
+    """Find a root of ``func`` with Newton's method, starting from ``x0``.
+
+    ``func`` must return a pair ``(residual, aux)``; the iterations drive
+    ``residual`` to zero and the auxiliary output of the last evaluation is
+    returned in ``NewtonState.aux``.
+    """
+    x0arr, unflatten = ravel_pytree(x0)
+
+    def surrogate_func(x: Array):
+        x_tree = unflatten(x)
+        y, aux = func(x_tree)
+        y, _ = ravel_pytree(y)
+        return y, aux
+
+    jf = jax.jacobian(surrogate_func, has_aux=True)
+
+    def step_fun(x: NewtonState) -> NewtonState:
+        J, _ = jf(x.x)
+        F, aux = surrogate_func(x.x)
+
+        delta = jnp.linalg.solve(J, -F)
+        return NewtonState(
+            x=x.x + delta, delta=delta, n=x.n + jnp.ones_like(x.n), aux=aux
+        )
+
+    def cond(x: NewtonState):
+        return jnp.logical_and(
+            x.n < max_iters, jnp.linalg.norm(x.delta) > convergence_tol
+        )
+
+    sol = jax.lax.while_loop(
+        cond,
+        step_fun,
+        NewtonState(
+            x=x0arr, delta=x0arr, n=jnp.zeros((), dtype=jnp.int32), aux=func(x0)[1]
+        ),
+    )
+    return sol.replace(x=unflatten(sol.x), delta=unflatten(sol.delta))
+
+
+class RattleVars(NamedTuple):
+    p_1_2: Array  # Midpoint momentum
+    q_1: Array  # Final position
+    lam: Array  # Lagrange multiplier (position constraint)
+    p_1: Array  # Final momentum
+    mu: Array  # Lagrange multiplier (velocity constraint)
+
+
+def rattle(
+    logdensity_fn: Callable,
+    kinetic_energy_fn: KineticEnergy,
+    constrain_fn: Callable,
+    *,
+    solver: Callable = solve_newton,
+    **solver_kwargs: Any,
+) -> Integrator:
+    """RATTLE integrator.
+
+    Symplectic, time-reversible integrator for Hamiltonian dynamics restricted
+    to the manifold implicitly defined by ``constrain_fn(q) = 0``. Both the
+    position and the velocity constraint are enforced at every step by solving
+    a nonlinear system of equations with ``solver`` (Newton's method by
+    default); any ``solver_kwargs`` are forwarded to the solver.
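+
+    With step size :math:`h`, Hamiltonian :math:`H(q, p)` and constraint
+    :math:`c(q)`, a single step solves the standard RATTLE system
+
+    .. math::
+
+        p_{1/2} &= p_0 - \tfrac{h}{2}\left(\partial_q H(q_0, p_{1/2})
+            + \partial_q c(q_0)^T \lambda\right)\\
+        q_1 &= q_0 + \tfrac{h}{2}\left(\partial_p H(q_0, p_{1/2})
+            + \partial_p H(q_1, p_{1/2})\right)\\
+        0 &= c(q_1)\\
+        p_1 &= p_{1/2} - \tfrac{h}{2}\left(\partial_q H(q_1, p_{1/2})
+            + \partial_q c(q_1)^T \mu\right)\\
+        0 &= \partial_q c(q_1)\,\partial_p H(q_1, p_1)
+
+    for the unknowns collected in ``RattleVars``. These correspond to the
+    residuals assembled in ``eq`` below; note that the implementation currently
+    evaluates the kinetic-energy position gradients at fixed arguments (see the
+    ``TODO`` comments), which coincides with the above for position-independent
+    (Euclidean) kinetic energies.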
+    """
+    logdensity_and_grad_fn = jax.value_and_grad(logdensity_fn)
+    kinetic_energy_grad_fn = jax.grad(
+        lambda q, p: kinetic_energy_fn(p, position=q), argnums=(0, 1)
+    )
+
+    def one_step(state: IntegratorState, step_size: float) -> IntegratorState:
+        q0, p0, _, _ = state
+        h = 0.5 * step_size
+
+        def eq(x: RattleVars) -> tuple:
+            _, vjp_fun = jax.vjp(constrain_fn, q0)
+            _, vjp_fun_mu = jax.vjp(constrain_fn, x.q_1)
+
+            dUdq = state.logdensity_grad
+
+            dTdq, _ = kinetic_energy_grad_fn(q0, p0)
+            dHdq = jtu.tree_map(jnp.subtract, dTdq, dUdq)
+
+            # TODO: check the evaluation points of the kinetic-energy position
+            # gradients for position-dependent (Riemannian) metrics.
+            dTdq12, dHdp12 = kinetic_energy_grad_fn(q0, x.p_1_2)
+            Uq1, dUdq1 = logdensity_and_grad_fn(x.q_1)
+            dHdq12 = jtu.tree_map(jnp.subtract, dTdq12, dUdq1)
+
+            zero = (
+                jtu.tree_map(
+                    lambda _p0, _dhdq, _dcl, _p12: _p0 - h * (_dhdq + _dcl) - _p12,
+                    p0,
+                    dHdq,
+                    vjp_fun(x.lam)[0],
+                    x.p_1_2,
+                ),
+                jtu.tree_map(
+                    lambda _q0, _dhdp0, _dhdp1, _q1: _q0 + h * (_dhdp0 + _dhdp1) - _q1,
+                    q0,
+                    kinetic_energy_grad_fn(q0, x.p_1_2)[1],
+                    kinetic_energy_grad_fn(x.q_1, x.p_1_2)[1],
+                    x.q_1,
+                ),
+                constrain_fn(x.q_1),
+                jtu.tree_map(
+                    lambda _p12, _dhdq, _dc, _p1: _p12 - h * (_dhdq + _dc) - _p1,
+                    x.p_1_2,
+                    dHdq12,
+                    vjp_fun_mu(x.mu)[0],
+                    x.p_1,
+                ),
+                jax.jvp(
+                    constrain_fn, (x.q_1,), (kinetic_energy_grad_fn(x.q_1, x.p_1)[1],)
+                )[1],
+            )
+
+            return zero, (Uq1, dUdq1)
+
+        cs = jax.eval_shape(constrain_fn, q0)
+
+        init_vars = RattleVars(
+            p_1_2=p0,
+            # TODO: check whether a better starting point is available
+            q_1=jtu.tree_map(lambda x: x, q0),
+            p_1=p0,
+            lam=jtu.tree_map(jnp.zeros_like, cs),  # TODO: keep this in a state
+            mu=jtu.tree_map(jnp.zeros_like, cs),
+        )
+
+        sol = solver(eq, init_vars, **solver_kwargs)
+        Uq1, dUdq1 = sol.aux
+        next_state = IntegratorState(
+            position=sol.x.q_1,
+            momentum=sol.x.p_1,
+            logdensity=Uq1,
+            logdensity_grad=dUdq1,
+        )
+        return next_state
+
+    return one_step
diff --git a/blackjax/mcmc/metrics.py b/blackjax/mcmc/metrics.py
index 1368a8441..c78c04f32 100644
--- a/blackjax/mcmc/metrics.py
+++ b/blackjax/mcmc/metrics.py
@@ -28,8 +28,10 @@
 We can also generate a relativistic dynamic :cite:p:`lu2017relativistic`.
 
 """
+from functools import partial
 from typing import Callable, NamedTuple, Optional, Protocol, Union
 
+import jax
 import jax.numpy as jnp
 import jax.scipy as jscipy
 from jax.flatten_util import ravel_pytree
@@ -284,3 +286,90 @@ def is_turning(
         # return turning_at_left | turning_at_right
 
     return Metric(momentum_generator, kinetic_energy, is_turning)
+
+
+def gaussian_implicit_riemannian(
+    mass_matrix_fn: Callable, constrain_fn: Callable
+) -> Metric:
+    """Gaussian metric on the manifold implicitly defined by ``constrain_fn(q) = 0``.
+
+    The momentum is Gaussian with (possibly position-dependent) covariance
+    ``mass_matrix_fn(q)``, projected so that the velocity constraint holds.
+    """
+
+    def factorize_mass(position: ArrayLikeTree):
+        M = mass_matrix_fn(position)
+        cholesky = jscipy.linalg.cholesky(M, lower=True)
+        inverse = jscipy.linalg.solve_triangular(
+            cholesky.T,
+            jscipy.linalg.solve_triangular(cholesky, jnp.eye(*M.shape), lower=True),
+            lower=False,
+        )
+        return cholesky, inverse
+
+    @partial(jax.vmap, in_axes=(None, 1), out_axes=1)
+    def jmp(x, v):
+        """Product of the constraint Jacobian at ``x`` with the columns of ``v``."""
+        return jax.jvp(constrain_fn, (x,), (v,))[1]
+
+    # https://github.com/krzysztofrusek/jax_chmc/blob/d8c12e4b55b8a9877228de1c130937a971de5b52/jax_chmc/kernels.py#L81
+    def momentum_generator(rng_key: PRNGKey, position: ArrayLikeTree) -> ArrayLikeTree:
+        flat_position, unflatten = ravel_pytree(position)
+        cholesky, inverse = factorize_mass(position)
+
+        z = jax.random.normal(rng_key, shape=flat_position.shape)
+        p0 = cholesky @ z
+
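+        # The unconstrained draw p0 ~ N(0, M) is projected below onto the null
+        # space of D = dc/dq @ M^{-1}, so that the returned momentum satisfies
+        # the velocity constraint dc/dq @ M^{-1} @ p = 0.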
+        # D = dc/dq @ M^{-1} (Jacobian-matrix product).
+        # TODO: handle diagonal and scalar mass matrices.
+        D = jmp(position, inverse)
+        # Equivalent to jax.jacobian(constrain_fn)(position) @ inverse.
+
+        # TODO: check whether a jaxopt projection could be used here.
+        p0 = p0 - D.T @ jnp.linalg.solve(D @ D.T, D @ p0)
+        return unflatten(p0)
+
+    # https://github.com/krzysztofrusek/jax_chmc/blob/d8c12e4b55b8a9877228de1c130937a971de5b52/jax_chmc/kernels.py#L54C1-L62C1
+    def kinetic_energy(
+        momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None
+    ) -> float:
+        cholesky, inverse = factorize_mass(position)
+        flat_p, _ = ravel_pytree(momentum)
+
+        D = jmp(position, inverse)
+        cholMhat = cholesky - D.T @ jnp.linalg.solve(D @ D.T, D @ cholesky)
+        d = jnp.linalg.svd(cholMhat, compute_uv=False, hermitian=True)
+
+        def _shape_fn(position):
+            x, _ = ravel_pytree(position)
+            c, _ = ravel_pytree(constrain_fn(position))
+            return (x, c)
+
+        dc_shape = jax.eval_shape(_shape_fn, position)
+
+        # Pseudo-log-determinant of the mass matrix restricted to the cotangent
+        # space: keep the dim(q) - dim(c) largest singular values of the
+        # projected Cholesky factor.
+        top_d, _ = jax.lax.top_k(d, dc_shape[0].shape[0] - dc_shape[1].shape[0])
+        pseudo_log_det = jnp.sum(jnp.log(top_d))
+
+        T = flat_p.T @ inverse @ flat_p / 2.0
+
+        return T + pseudo_log_det
+
+    def is_turning(
+        momentum_left: ArrayLikeTree,
+        momentum_right: ArrayLikeTree,
+        momentum_sum: ArrayLikeTree,
+        position_left: Optional[ArrayLikeTree] = None,
+        position_right: Optional[ArrayLikeTree] = None,
+    ) -> bool:
+        del momentum_left, momentum_right, momentum_sum, position_left, position_right
+        raise NotImplementedError(
+            "NUTS sampling is not yet implemented for implicitly defined manifolds"
+        )
+
+    return Metric(momentum_generator, kinetic_energy, is_turning)
diff --git a/tests/mcmc/test_integrators.py b/tests/mcmc/test_integrators.py
index ddb13ad57..60b79f934 100644
--- a/tests/mcmc/test_integrators.py
+++ b/tests/mcmc/test_integrators.py
@@ -1,3 +1,4 @@
+import functools
 import itertools
 
 import chex
@@ -10,7 +11,10 @@
 from scipy.special import ellipj
 
 import blackjax.mcmc.integrators as integrators
-from blackjax.mcmc.integrators import esh_dynamics_momentum_update_one_step
+from blackjax.mcmc.integrators import (
+    esh_dynamics_momentum_update_one_step,
+)
+from blackjax.mcmc.metrics import gaussian_euclidean
 from blackjax.util import generate_unit_vector
 
 
@@ -143,6 +147,10 @@
     "isokinetic_leapfrog": {"algorithm": integrators.isokinetic_leapfrog},
     "isokinetic_mclachlan": {"algorithm": integrators.isokinetic_mclachlan},
     "isokinetic_yoshida": {"algorithm": integrators.isokinetic_yoshida},
+    "rattle": {
+        "algorithm": functools.partial(integrators.rattle, constrain_fn=lambda x: None),
+        "precision": 1e-4,
+    },
 }
 
 
@@ -164,12 +172,7 @@ class IntegratorTest(chex.TestCase):
                 "planetary_motion",
                 "multivariate_normal",
             ],
-            [
-                "velocity_verlet",
-                "mclachlan",
-                "yoshida",
-                "implicit_midpoint",
-            ],
+            ["velocity_verlet", "mclachlan", "yoshida", "implicit_midpoint", "rattle"],
         )
     )
     def test_euclidean_integrator(self, example_name, integrator_name):
@@ -378,5 +381,100 @@ def scan_body(state, _):
         self.assertAlmostEqual(energy, new_energy, delta=1e-4)
 
 
+class ConstrainedIntegratorTest(chex.TestCase):
+    @chex.all_variants(with_pmap=False)
+    def test_rattle(self):
+        @self.variant
+        def constrain(q):
+            return jnp.sqrt(jnp.sum(q**2)) - 1.0
+
+        p0, q0 = (jnp.asarray([1.0, 0.0]), jnp.asarray([0.0, 1.0]))
+        t1 = 2 * jnp.pi / 4
+        n = 2**10
+        dt = t1 / n
+
+        g = gaussian_euclidean(jnp.ones(2))
+
+        def logdensity_fn(q):
+            return 0.0
+
+        one_step = integrators.rattle(
+            logdensity_fn, kinetic_energy_fn=g.kinetic_energy, constrain_fn=constrain
+        )
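+        # Free motion constrained to the unit circle: after a quarter of the
+        # period the state is rotated by 90 degrees, taking (q, p) from
+        # ((0, 1), (1, 0)) to ((1, 0), (0, -1)).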
+        one_step = self.variant(one_step)
+
+        state = integrators.new_integrator_state(
+            position=q0, momentum=p0, logdensity_fn=logdensity_fn
+        )
+
+        final_state = jax.lax.fori_loop(0, n, lambda i, x: one_step(x, dt), state)
+        q1 = final_state.position
+        p1 = final_state.momentum
+        self.assertTrue(
+            jnp.allclose(p1, jnp.asarray([0.0, -1.0]), rtol=1e-4, atol=1e-4)
+        )
+        self.assertTrue(jnp.allclose(q1, jnp.asarray([1.0, 0.0]), rtol=1e-4, atol=1e-4))
+
+    @chex.all_variants(with_pmap=False)
+    def test_rattle_inclined_plane(self):
+        def constrain(q):
+            return jnp.sum(q) - 1.0
+
+        m = 1
+        g = 9.81
+        sqrt2 = np.sqrt(2)
+        a = g / sqrt2
+        l = sqrt2
+
+        # l = a t^2 / 2  =>  t = sqrt(2 l / a)
+        p0 = m * jnp.asarray([0.0, 0.0])
+        q0 = jnp.asarray([0.0, 1.0])
+
+        t1 = np.sqrt(2 * l / a)
+        n = 2**10
+        dt = t1 / n
+
+        ge = gaussian_euclidean(jnp.asarray([m, m]))
+
+        def logdensity_fn(q):
+            return -m * g * q[1]
+
+        one_step = integrators.rattle(
+            logdensity_fn, kinetic_energy_fn=ge.kinetic_energy, constrain_fn=constrain
+        )
+        one_step = self.variant(one_step)
+
+        state = integrators.new_integrator_state(
+            position=q0, momentum=p0, logdensity_fn=logdensity_fn
+        )
+
+        final_state = jax.lax.fori_loop(0, n, lambda i, x: one_step(x, dt), state)
+        q1 = final_state.position
+        self.assertTrue(jnp.allclose(q1, jnp.asarray([1.0, 0.0]), rtol=1e-4, atol=1e-4))
+
+        p1 = final_state.momentum
+        # Energy conservation: m g h = |p|^2 / (2 m)  =>  |p| = m sqrt(2 g h)
+        p_expected = jnp.asarray([jnp.sqrt(g), -jnp.sqrt(g)])
+        self.assertTrue(jnp.allclose(p1, p_expected, rtol=1e-4, atol=1e-4))
+
+    @chex.all_variants(with_pmap=False)
+    def test_newton_solver_2d(self):
+        def f(x):
+            return dict(y=(x - 2.0) * (x - 3.0)), None
+
+        x0 = jnp.asarray([1.0, 4.0])
+
+        @self.variant
+        def _solve(x0):
+            sol = integrators.solve_newton(jax.vmap(f), x0, max_iters=20)
+            return sol
+
+        sol = _solve(x0)
+
+        self.assertTrue(jnp.allclose(sol.x, jnp.asarray([2.0, 3.0])))
+
+
 if __name__ == "__main__":
     absltest.main()
diff --git a/tests/mcmc/test_metrics.py b/tests/mcmc/test_metrics.py
index f806a375c..59fb2f8a8 100644
--- a/tests/mcmc/test_metrics.py
+++ b/tests/mcmc/test_metrics.py
@@ -142,6 +142,33 @@ def test_gaussian_euclidean_dim_2(self):
         np.testing.assert_allclose(expected_momentum_val, momentum_val)
         np.testing.assert_allclose(kinetic_energy_val, expected_kinetic_energy_val)
 
+
+class GaussianImplicitRiemannianMetricsTest(chex.TestCase):
+    def setUp(self):
+        super().setUp()
+        self.key = random.key(0)
+        self.dtype = "float32"
+
+    @chex.all_variants(with_pmap=False)
+    def test_metric(self):
+        condfun = lambda x: x.sum(keepdims=True)
+        M = 3.0 * jnp.diag(jnp.ones(4))
+
+        g = metrics.gaussian_implicit_riemannian(lambda q: M, condfun)
+
+        q0 = jnp.zeros(4)
+        p0 = self.variant(g.sample_momentum)(self.key, q0)
+        v = self.variant(g.kinetic_energy)(p0, q0)
+        self.assertTrue(jnp.isfinite(v))
+
+    @chex.all_variants(with_pmap=False)
+    def test_metric_tree(self):
+        condfun = lambda x: x["var"].sum(keepdims=True)
+        M = 3.0 * jnp.diag(jnp.ones(4))
+
+        g = metrics.gaussian_implicit_riemannian(lambda q: M, condfun)
+
+        q0 = dict(var=jnp.zeros(4))
+        p0 = self.variant(g.sample_momentum)(self.key, q0)
+        v = self.variant(g.kinetic_energy)(p0, q0)
+        self.assertTrue(jnp.isfinite(v))
+
 
 if __name__ == "__main__":
     absltest.main()