Add pre-conditioning matrix to Barker proposal (#731)
* Draft pre-conditioning matrix in Barker proposal.

This is a first draft of adding the pre-conditioning to the Barker
proposal. This follows Algorithms 4 and 5 in Appendix G of the original
Barker proposal paper. It's somewhat unclear from the paper, but the
separate step size that was already implemented serves as a global
scale for the normal distribution of the proposal. The function
`_compute_acceptance_probability` now takes the transposed square-root
mass matrix and its inverse, and it has been flattened to accommodate
the corresponding matrix multiplications.
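
As a rough illustration of the acceptance computation described above, here is a minimal sketch of the log acceptance ratio for a preconditioned Barker move of the form y = x + C @ w. It is not the code from this commit: the function name `log_acceptance_ratio`, the dense matrix `C`, and the toy `logdensity` are assumptions made for the example, and which square-root factor blackjax passes as `C` is not asserted here. The Gaussian terms and the step size are omitted because they are symmetric in `w` and cancel between the two directions.

```python
import jax
import jax.numpy as jnp


def log_acceptance_ratio(logdensity_fn, x, y, C):
    """Log of pi(y) q(y -> x) / (pi(x) q(x -> y)) for the move y = x + C @ w."""
    logp_x, grad_x = jax.value_and_grad(logdensity_fn)(x)
    logp_y, grad_y = jax.value_and_grad(logdensity_fn)(y)

    # Noise that carries x to y; its negation carries y back to x.
    w_xy = jnp.linalg.solve(C, y - x)
    w_yx = -w_xy

    # Gradients mapped into the preconditioned coordinates.
    c_x = C.T @ grad_x
    c_y = C.T @ grad_y

    # Only the logistic skewing terms remain once the Gaussian factors
    # cancel; log sigmoid(t) = -log(1 + exp(-t)).
    log_q_xy = jax.nn.log_sigmoid(w_xy * c_x)
    log_q_yx = jax.nn.log_sigmoid(w_yx * c_y)
    return logp_y - logp_x + jnp.sum(log_q_yx - log_q_xy)


def logdensity(x):
    # Hypothetical target: a standard normal, up to a constant.
    return -0.5 * jnp.sum(x**2)


C = jnp.array([[1.0, 0.0], [0.5, 2.0]])
x = jnp.array([0.3, -1.2])
y = jnp.array([0.8, 0.4])
forward = log_acceptance_ratio(logdensity, x, y, C)
backward = log_acceptance_ratio(logdensity, y, x, C)
```

Swapping the roles of `x` and `y` negates the result, which is a quick sanity check for any implementation of this ratio.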

* Fix typing of inverse_mass_matrix argument
Fix typing of mass matrix.

* Fix docstrings.

The original docstring of step_size was incorrect: there is no
symplectic integrator.

* Make test for Barker in test_sampling run again

We make this possible by adding an identity preconditioning matrix,
which should make the test run in the same way as before.

* Add test to ensure correctness of precond matrix

We add a new test to barker.py to ensure that our implementation of
the preconditioning matrix is correct. We follow Appendix G of the
paper, which states that Algorithms 4 and 5 (which we implemented)
should be equivalent to rescaling the parameters and the logdensity
in a specific way. We implement both approaches when using the Barker
proposal to infer the mean and sigma of a normal distribution. We
check that, with two different random seeds, the chains produced are
equivalent up to some tolerance.
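
The equivalence this test relies on can be sketched for a single proposal step. The snippet below is an illustration under stated assumptions, not the test from this commit: `barker_proposal`, the matrix `C`, and the toy `logdensity` are hypothetical names, and the preconditioned move is taken to be x' = x + C @ (signed noise) with the gradient mapped through the transpose of `C`. Running the plain proposal on the rescaled density g(u) = f(C u) and mapping back with `C` should give the same point when the same key is used.

```python
import jax
import jax.numpy as jnp


def barker_proposal(key, x, grad, step_size, C):
    """One preconditioned Barker proposal; C = identity gives the plain one."""
    key_z, key_b = jax.random.split(key)
    z = step_size * jax.random.normal(key_z, x.shape)
    # Probability of keeping the sign of each noise coordinate.
    p = jax.nn.sigmoid((C.T @ grad) * z)
    keep = jax.random.bernoulli(key_b, p)
    return x + C @ jnp.where(keep, z, -z)


def logdensity(x):
    # Hypothetical target: a standard normal, up to a constant.
    return -0.5 * jnp.sum(x**2)


C = jnp.array([[1.0, 0.0], [0.5, 2.0]])
u = jnp.array([0.3, -1.2])  # position in rescaled coordinates
x = C @ u                   # same position in original coordinates


def rescaled_logdensity(u):
    return logdensity(C @ u)


key = jax.random.PRNGKey(0)
x_precond = barker_proposal(key, x, jax.grad(logdensity)(x), 0.5, C)
u_plain = barker_proposal(key, u, jax.grad(rescaled_logdensity)(u), 0.5, jnp.eye(2))
```

Comparing `x_precond` with `C @ u_plain` up to a small tolerance is the single-step analogue of comparing whole chains.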

We also patch the original test in this file by adding an identity
mass matrix.

* Fix dimensionality of identity matrix

* Add missing mass matrix in missing tests.

* added option to transpose the matrix when scaling

The option to transpose mass_matrix_sqrt or inv_mass_matrix_sqrt was
necessary for the Barker algorithm as far as I can tell. This has not
been propagated to the Riemannian metric.

* use the metric scaling function in barker

Here we use the new metric.scale function to perform the operations
required by the Barker proposal algorithm, instead of passing around
the mass_matrix_sqrt and inv_mass_matrix_sqrt directly. We also
make the `inverse_mass_matrix` argument optional to avoid breaking
the API.

* update test_sampling with barker api

the mass matrix is now an optional argument in barker.

* update test_barker so it works with metric.scale

* fix tests add trans to scale

* add trans argument to riemannian scaling

* no default

* Update barker.py

Make acceptance function metric agnostic

* Update test_barker.py

Add invariance test

* simplify logic to remove _barker_sample_nd

* fix bug so now everything is tree_mapped in barker

* fix test to not use _barker_sample_nd

* Update blackjax/mcmc/metrics.py

make inv and trans required kwarg with type bool in metric.scale

Co-authored-by: Junpeng Lao <[email protected]>

* Update blackjax/mcmc/metrics.py

lax.cond might not be needed in metric.scale as inv and trans are static kwarg

Co-authored-by: Junpeng Lao <[email protected]>
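
To illustrate the point about `lax.cond`: when `inv` and `trans` are plain Python booleans fixed at the call site, an ordinary `if` is resolved while JAX traces the function, so no `lax.cond` is needed. The sketch below is a simplified stand-in, not the blackjax implementation; `make_scale` and the matrix `L` are hypothetical.

```python
import jax
import jax.numpy as jnp


def make_scale(sqrt_matrix, inv_sqrt_matrix):
    def scale(element, *, inv: bool, trans: bool):
        # Plain Python branching: both flags are known at trace time.
        matrix = inv_sqrt_matrix if inv else sqrt_matrix
        if trans:
            matrix = matrix.T
        return matrix @ element

    return scale


L = jnp.array([[2.0, 0.0], [1.0, 1.0]])
scale = make_scale(L, jnp.linalg.inv(L))


@jax.jit
def round_trip(v):
    # Apply L^T, then undo it with (L^{-1})^T.
    scaled = scale(v, inv=False, trans=True)
    return scale(scaled, inv=True, trans=True)


v = jnp.array([1.0, 3.0])
recovered = round_trip(v)
```

If either flag were instead a traced array, the `if` would fail under `jit`, which is the case `lax.cond` exists for.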

* propagate changes of inv, trans as required kwarg

* fix test metrics

---------

Co-authored-by: Adrien Corenflos <[email protected]>
Co-authored-by: Junpeng Lao <[email protected]>
3 people authored Oct 8, 2024
1 parent 5a25352 commit b107f9f
Showing 5 changed files with 269 additions and 93 deletions.
146 changes: 79 additions & 67 deletions blackjax/mcmc/barker.py
@@ -18,11 +18,13 @@
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
from jax.scipy import stats
from jax.tree_util import tree_leaves, tree_map

import blackjax.mcmc.metrics as metrics
from blackjax.base import SamplingAlgorithm
from blackjax.mcmc.metrics import Metric
from blackjax.mcmc.proposal import static_binomial_sampling
from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey
from blackjax.types import ArrayLikeTree, ArrayTree, Numeric, PRNGKey
from blackjax.util import generate_gaussian_noise

__all__ = ["BarkerState", "BarkerInfo", "init", "build_kernel", "as_top_level_api"]

@@ -81,44 +83,70 @@ def build_kernel():
"""

def _compute_acceptance_probability(
state: BarkerState,
proposal: BarkerState,
) -> float:
state: BarkerState, proposal: BarkerState, metric: Metric
) -> Numeric:
"""Compute the acceptance probability of the Barker's proposal kernel."""

def ratio_proposal_nd(y, x, log_y, log_x):
num = -_log1pexp(-log_y * (x - y))
den = -_log1pexp(-log_x * (y - x))
x = state.position
y = proposal.position
log_x = state.logdensity_grad
log_y = proposal.logdensity_grad

return jnp.sum(num - den)
y_minus_x = jax.tree_util.tree_map(lambda a, b: a - b, y, x)
x_minus_y = jax.tree_util.tree_map(lambda a: -a, y_minus_x)
z_tilde_x_to_y = metric.scale(x, y_minus_x, inv=True, trans=True)
z_tilde_y_to_x = metric.scale(y, x_minus_y, inv=True, trans=True)

ratios_proposals = tree_map(
ratio_proposal_nd,
proposal.position,
state.position,
proposal.logdensity_grad,
state.logdensity_grad,
c_x_to_y = metric.scale(x, log_x, inv=False, trans=True)
c_y_to_x = metric.scale(y, log_y, inv=False, trans=True)

z_tilde_x_to_y_flat, _ = ravel_pytree(z_tilde_x_to_y)
z_tilde_y_to_x_flat, _ = ravel_pytree(z_tilde_y_to_x)

c_x_to_y_flat, _ = ravel_pytree(c_x_to_y)
c_y_to_x_flat, _ = ravel_pytree(c_y_to_x)

num = metric.kinetic_energy(x_minus_y, y) - _log1pexp(
-z_tilde_y_to_x_flat * c_y_to_x_flat
)
ratio_proposal = sum(tree_leaves(ratios_proposals))
denom = metric.kinetic_energy(y_minus_x, x) - _log1pexp(
-z_tilde_x_to_y_flat * c_x_to_y_flat
)

ratio_proposal = jnp.sum(num - denom)

return proposal.logdensity - state.logdensity + ratio_proposal

def kernel(
rng_key: PRNGKey, state: BarkerState, logdensity_fn: Callable, step_size: float
rng_key: PRNGKey,
state: BarkerState,
logdensity_fn: Callable,
step_size: float,
inverse_mass_matrix: metrics.MetricTypes | None = None,
) -> tuple[BarkerState, BarkerInfo]:
"""Generate a new sample with the MALA kernel."""
"""Generate a new sample with the Barker kernel."""
if inverse_mass_matrix is None:
p, _ = ravel_pytree(state.position)
(m,) = p.shape
inverse_mass_matrix = jnp.ones((m,))
metric = metrics.default_metric(inverse_mass_matrix)
grad_fn = jax.value_and_grad(logdensity_fn)

key_sample, key_rmh = jax.random.split(rng_key)

proposed_pos = _barker_sample(
key_sample, state.position, state.logdensity_grad, step_size
key_sample,
state.position,
state.logdensity_grad,
step_size,
metric,
)

proposed_logdensity, proposed_logdensity_grad = grad_fn(proposed_pos)
proposed_state = BarkerState(
proposed_pos, proposed_logdensity, proposed_logdensity_grad
)

log_p_accept = _compute_acceptance_probability(state, proposed_state)
log_p_accept = _compute_acceptance_probability(state, proposed_state, metric)
accepted_state, info = static_binomial_sampling(
key_rmh, log_p_accept, state, proposed_state
)
@@ -131,6 +159,7 @@ def kernel(
def as_top_level_api(
logdensity_fn: Callable,
step_size: float,
inverse_mass_matrix: metrics.MetricTypes | None = None,
) -> SamplingAlgorithm:
"""Implements the (basic) user interface for the Barker's proposal :cite:p:`Livingstone2022Barker` kernel with a
Gaussian base kernel.
@@ -174,7 +203,9 @@ def as_top_level_api(
logdensity_fn
The log-density function we wish to draw samples from.
step_size
The value to use for the step size in the symplectic integrator.
The value of the step_size corresponding to the global scale of the proposal distribution.
inverse_mass_matrix
The inverse mass matrix to use for pre-conditioning (see Appendix G of :cite:p:`Livingstone2022Barker`).
Returns
-------
@@ -189,74 +220,55 @@ def init_fn(position: ArrayLikeTree, rng_key=None):
return init(position, logdensity_fn)

def step_fn(rng_key: PRNGKey, state):
return kernel(rng_key, state, logdensity_fn, step_size)
return kernel(rng_key, state, logdensity_fn, step_size, inverse_mass_matrix)

return SamplingAlgorithm(init_fn, step_fn)


def _barker_sample_nd(key, mean, a, scale):
"""
Sample from a multivariate Barker's proposal distribution. In 1D, this has the following probability density function:
.. math::
p(x; \\mu, a, \\sigma) = 2 \frac{N(x; \\mu, \\sigma^2)}{1 + \\exp(-a (x - \\mu)}
def _generate_bernoulli(
rng_key: PRNGKey, position: ArrayLikeTree, p: ArrayLikeTree
) -> ArrayTree:
pos, unravel_fn = ravel_pytree(position)
p_flat, _ = ravel_pytree(p)
sample = jax.random.bernoulli(rng_key, p=p_flat, shape=pos.shape)
return unravel_fn(sample)

where :math:`N(x; \\mu, \\sigma^2)` is the normal distribution with mean :math:`\\mu` and standard deviation :math:`\\sigma`.
The multivariate Barker's proposal distribution is the product of one-dimensional Barker's proposal distributions.

def _barker_sample(key, mean, a, scale, metric):
r"""
Sample from a multivariate Barker's proposal distribution for PyTrees.
Parameters
----------
key
A PRNG key.
mean
The mean of the normal distribution, an Array. This corresponds to :math:`\\mu` in the equation above.
The mean of the normal distribution, a PyTree. This corresponds to :math:`\mu` in the equation above.
a
The parameter :math:`a` in the equation above, an Array. This is a skewness parameter.
The parameter :math:`a` in the equation above, the same PyTree as `mean`. This is a skewness parameter.
scale
The standard deviation of the normal distribution, a scalar. This corresponds to :math:`\\sigma` in the equation above.
The global scale, a scalar. This corresponds to :math:`\sigma` in the equation above.
It encodes the step size of the proposal.
Returns
-------
A sample from the Barker's multidimensional proposal distribution.
metric
A `metrics.MetricTypes` object encoding the mass matrix information.
"""

key1, key2 = jax.random.split(key)
z = scale * jax.random.normal(key1, shape=mean.shape)

z = generate_gaussian_noise(key1, mean, sigma=scale)
c = metric.scale(mean, a, inv=False, trans=True)

# Sample b=1 with probability p and 0 with probability 1 - p where
# p = 1 / (1 + exp(-a * (z - mean)))
log_p = -_log1pexp(-a * z)
b = jax.random.bernoulli(key2, p=jnp.exp(log_p), shape=mean.shape)

# return mean + z if b == 1 else mean - z
return mean + b * z - (1 - b) * z

log_p = jax.tree_util.tree_map(lambda x, y: -_log1pexp(-x * y), c, z)
p = jax.tree_util.tree_map(lambda x: jnp.exp(x), log_p)
b = _generate_bernoulli(key2, mean, p=p)

def _barker_sample(key, mean, a, scale):
r"""
Sample from a multivariate Barker's proposal distribution for PyTrees.
Parameters
----------
key
A PRNG key.
mean
The mean of the normal distribution, a PyTree. This corresponds to :math:`\mu` in the equation above.
a
The parameter :math:`a` in the equation above, the same PyTree as `mean`. This is a skewness parameter.
scale
The standard deviation of the normal distribution, a scalar. This corresponds to :math:`\sigma` in the equation above.
It encodes the step size of the proposal.
"""
bz = jax.tree_util.tree_map(lambda x, y: x * y - (1 - x) * y, b, z)

flat_mean, unravel_fn = ravel_pytree(mean)
flat_a, _ = ravel_pytree(a)
flat_sample = _barker_sample_nd(key, flat_mean, flat_a, scale)
return unravel_fn(flat_sample)
return jax.tree_util.tree_map(
lambda a, b: a + b, mean, metric.scale(mean, bz, inv=False, trans=False)
)


def _log1pexp(a):
55 changes: 39 additions & 16 deletions blackjax/mcmc/metrics.py
@@ -30,7 +30,6 @@
"""
from typing import Callable, NamedTuple, Optional, Protocol, Union

import jax
import jax.numpy as jnp
import jax.scipy as jscipy
from jax.flatten_util import ravel_pytree
@@ -62,7 +61,12 @@ def __call__(

class Scale(Protocol):
def __call__(
self, position: ArrayLikeTree, element: ArrayLikeTree, inv: ArrayLikeTree
self,
position: ArrayLikeTree,
element: ArrayLikeTree,
*,
inv: bool,
trans: bool,
) -> ArrayLikeTree:
...

@@ -187,7 +191,11 @@ def is_turning(
return turning_at_left | turning_at_right

def scale(
position: ArrayLikeTree, element: ArrayLikeTree, inv: ArrayLikeTree
position: ArrayLikeTree,
element: ArrayLikeTree,
*,
inv: bool,
trans: bool,
) -> ArrayLikeTree:
"""Scale elements by the mass matrix.
@@ -197,10 +205,11 @@ def scale(
The current position. Not used in this metric.
elements
Elements to scale
invs
inv
Whether to scale the elements by the inverse mass matrix or the mass matrix.
If True, the element is scaled by the inverse square root mass matrix, i.e., elem <- (M^{1/2})^{-1} elem.
Same pytree structure as `elements`.
trans
Whether to transpose the mass matrix when scaling.
Returns
-------
@@ -209,11 +218,16 @@ def scale(
"""

ravelled_element, unravel_fn = ravel_pytree(element)
scaled = jax.lax.cond(
inv,
lambda: linear_map(inv_mass_matrix_sqrt, ravelled_element),
lambda: linear_map(mass_matrix_sqrt, ravelled_element),
)

if inv:
left_hand_side_matrix = inv_mass_matrix_sqrt
else:
left_hand_side_matrix = mass_matrix_sqrt
if trans:
left_hand_side_matrix = left_hand_side_matrix.T

scaled = linear_map(left_hand_side_matrix, ravelled_element)

return unravel_fn(scaled)

return Metric(momentum_generator, kinetic_energy, is_turning, scale)
@@ -279,7 +293,11 @@ def is_turning(
# return turning_at_left | turning_at_right

def scale(
position: ArrayLikeTree, element: ArrayLikeTree, inv: ArrayLikeTree
position: ArrayLikeTree,
element: ArrayLikeTree,
*,
inv: bool,
trans: bool,
) -> ArrayLikeTree:
"""Scale elements by the mass matrix.
@@ -298,11 +316,16 @@ def scale(
mass_matrix, is_inv=False
)
ravelled_element, unravel_fn = ravel_pytree(element)
scaled = jax.lax.cond(
inv,
lambda: linear_map(inv_mass_matrix_sqrt, ravelled_element),
lambda: linear_map(mass_matrix_sqrt, ravelled_element),
)

if inv:
left_hand_side_matrix = inv_mass_matrix_sqrt
else:
left_hand_side_matrix = mass_matrix_sqrt
if trans:
left_hand_side_matrix = left_hand_side_matrix.T

scaled = linear_map(left_hand_side_matrix, ravelled_element)

return unravel_fn(scaled)

return Metric(momentum_generator, kinetic_energy, is_turning, scale)

1 comment on commit b107f9f

@github-actions

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'Python Benchmark with pytest-benchmark'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

| Benchmark suite | Current: b107f9f | Previous: 441412a | Ratio |
|-|-|-|-|
| `tests/test_benchmarks.py::test_regression_hmc` | 0.043082865753184714 iter/sec (stddev: 0.058494683116586135) | 0.0983662024930494 iter/sec (stddev: 0.12340149282010071) | 2.28 |

This comment was automatically generated by workflow using github-action-benchmark.
