Feat: Implement Full-Rank VI #720

Draft: gil2rok wants to merge 12 commits into base: main
7 changes: 7 additions & 0 deletions blackjax/__init__.py
@@ -36,6 +36,7 @@
 from .smc import adaptive_tempered
 from .smc import inner_kernel_tuning as _inner_kernel_tuning
 from .smc import tempered
+from .vi import fullrank_vi as _fullrank_vi
 from .vi import meanfield_vi as _meanfield_vi
 from .vi import pathfinder as _pathfinder
 from .vi import schrodinger_follmer as _schrodinger_follmer
@@ -131,6 +132,12 @@ def generate_top_level_api_from(module):
 svgd = generate_top_level_api_from(_svgd)

 # variational inference
+fullrank_vi = GenerateVariationalAPI(
+    _fullrank_vi.as_top_level_api,
+    _fullrank_vi.init,
+    _fullrank_vi.step,
+    _fullrank_vi.sample,
+)
 meanfield_vi = GenerateVariationalAPI(
     _meanfield_vi.as_top_level_api,
     _meanfield_vi.init,
4 changes: 2 additions & 2 deletions blackjax/vi/__init__.py
@@ -1,3 +1,3 @@
-from . import meanfield_vi, pathfinder, schrodinger_follmer, svgd
+from . import fullrank_vi, meanfield_vi, pathfinder, schrodinger_follmer, svgd

-__all__ = ["pathfinder", "meanfield_vi", "svgd", "schrodinger_follmer"]
+__all__ = ["fullrank_vi", "meanfield_vi", "pathfinder", "svgd", "schrodinger_follmer"]
165 changes: 165 additions & 0 deletions blackjax/vi/fullrank_vi.py
@@ -0,0 +1,165 @@
# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp
import jax.scipy as jsp
from optax import GradientTransformation, OptState

from blackjax.base import VIAlgorithm
from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey

__all__ = [
    "FRVIState",
    "FRVIInfo",
    "sample",
    "generate_fullrank_logdensity",
    "step",
    "as_top_level_api",
]


class FRVIState(NamedTuple):
    mu: ArrayTree  # mean of the Gaussian approximation
    rho: ArrayTree  # log of the diagonal of the Cholesky factor
    L: ArrayTree  # strictly lower-triangular part of the Cholesky factor
    opt_state: OptState


class FRVIInfo(NamedTuple):
    elbo: float  # value of the minimized objective (the negative ELBO estimate)


def init(
    position: ArrayLikeTree,
    optimizer: GradientTransformation,
    *optimizer_args,
    **optimizer_kwargs,
) -> FRVIState:
    """Initialize the full-rank VI state."""
    mu = jax.tree.map(jnp.zeros_like, position)
    rho = jax.tree.map(jnp.zeros_like, position)
    L = jax.tree.map(lambda x: jnp.zeros((*x.shape, *x.shape)), position)
    opt_state = optimizer.init((mu, rho, L))
    return FRVIState(mu, rho, L, opt_state)


def step(
    rng_key: PRNGKey,
    state: FRVIState,
    logdensity_fn: Callable,
    optimizer: GradientTransformation,
    num_samples: int = 5,
    stl_estimator: bool = True,
) -> tuple[FRVIState, FRVIInfo]:
    """Approximate the target density using the full-rank approximation.

    Parameters
    ----------
    rng_key
        Key for JAX's pseudo-random number generator.
    state
        Current state of the full-rank approximation.
    logdensity_fn
        Function that represents the target log-density to approximate.
    optimizer
        Optax `GradientTransformation` to be used for optimization.
    num_samples
        The number of samples drawn from the approximation at each step to
        estimate the Kullback-Leibler divergence between the approximation
        and the target log-density.
    stl_estimator
        Whether to use the stick-the-landing (STL) gradient estimator
        :cite:p:`roeder2017sticking`. The STL estimator has lower gradient
        variance because it removes the score function term from the
        gradient; :cite:p:`agrawal2020advances` suggest always using it.

    """

    parameters = (state.mu, state.rho, state.L)

    def kl_divergence_fn(parameters):
        mu, rho, L = parameters
        z = _sample(rng_key, mu, rho, L, num_samples)
        if stl_estimator:
            # Stop gradients through the variational parameters inside log q
            # so that only the reparameterized sampling path contributes.
            mu, rho, L = jax.tree.map(jax.lax.stop_gradient, (mu, rho, L))
        logq = jax.vmap(generate_fullrank_logdensity(mu, rho, L))(z)
        logp = jax.vmap(logdensity_fn)(z)
        return (logq - logp).mean()

    elbo, elbo_grad = jax.value_and_grad(kl_divergence_fn)(parameters)
    updates, new_opt_state = optimizer.update(elbo_grad, state.opt_state, parameters)
    new_parameters = jax.tree.map(lambda p, u: p + u, parameters, updates)

    new_mu, new_rho, new_L = new_parameters
    return FRVIState(new_mu, new_rho, new_L, new_opt_state), FRVIInfo(elbo)


def sample(rng_key: PRNGKey, state: FRVIState, num_samples: int = 1):
    """Sample from the full-rank approximation."""
    return _sample(rng_key, state.mu, state.rho, state.L, num_samples)


def as_top_level_api(
    logdensity_fn: Callable,
    optimizer: GradientTransformation,
    num_samples: int = 100,
):
    """High-level implementation of Full-Rank Variational Inference.

    Parameters
    ----------
    logdensity_fn
        A function that represents the log-density function associated with
        the distribution we want to sample from.
    optimizer
        Optax optimizer to use to optimize the ELBO.
    num_samples
        Number of samples to take at each step to optimize the ELBO.

    Returns
    -------
    A ``VIAlgorithm``.

    """

    def init_fn(position: ArrayLikeTree):
        return init(position, optimizer)

    def step_fn(rng_key: PRNGKey, state: FRVIState) -> tuple[FRVIState, FRVIInfo]:
        return step(rng_key, state, logdensity_fn, optimizer, num_samples)

    def sample_fn(rng_key: PRNGKey, state: FRVIState, num_samples: int):
        return sample(rng_key, state, num_samples)

    return VIAlgorithm(init_fn, step_fn, sample_fn)


def _sample(rng_key, mu, rho, L, num_samples):
    # Cholesky factor: strict lower triangle from L, positive diagonal exp(rho).
    cholesky = jnp.tril(L, k=-1) + jnp.diag(jnp.exp(rho))
    # Reparameterization trick: z = mu + C @ eps with eps ~ N(0, I).
    eps = jax.random.normal(rng_key, (num_samples,) + mu.shape)
    return mu + eps @ cholesky.T


def generate_fullrank_logdensity(mu, rho, L):
    cholesky = jnp.tril(L, k=-1) + jnp.diag(jnp.exp(rho))
    # log|Sigma| = 2 log|C| = 2 * sum(rho), since diag(C) = exp(rho).
    log_det = 2 * jnp.sum(rho)
    const = -0.5 * mu.shape[-1] * jnp.log(2 * jnp.pi)

    def fullrank_logdensity(position):
        # Mahalanobis distance via a triangular solve, avoiding an explicit
        # inverse of the covariance matrix.
        y = jsp.linalg.solve_triangular(cholesky, position - mu, lower=True)
        mahalanobis_dist = jnp.sum(y**2, axis=-1)
        return const - 0.5 * log_det - 0.5 * mahalanobis_dist

    return fullrank_logdensity
Member commented:

Why not use multivariate_normal.logpdf from JAX?

gil2rok (Contributor, Author) commented on Aug 15, 2024:

I wanted to avoid computing the inverse and log determinant of the covariance matrix $\Sigma$ by using the Cholesky factor when computing the logpdf.

Does JAX's multivariate normal take Cholesky factors as input? I want to avoid computing the covariance matrix $\Sigma = C C^T$ and then passing it to JAX's multivariate normal, which separates it back into the Cholesky factor $C$.

From https://jax.readthedocs.io/en/latest/_autosummary/jax.random.multivariate_normal.html it appears the multivariate normal only accepts the covariance as a dense matrix!

See jax-ml/jax#11386. Thoughts on the tradeoff between readability (with JAX's multivariate normal) and speed (custom implementation)?
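
For concreteness, a quick numerical check (a sketch, not part of this PR) that the Cholesky-based log-density above agrees with jax.scipy.stats.multivariate_normal.logpdf once the covariance is densified; all names and values below are illustrative:

import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jax.scipy.stats import multivariate_normal

dim = 3
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
mu = jax.random.normal(k1, (dim,))
rho = jnp.array([0.1, -0.3, 0.2])      # log of the Cholesky diagonal
L = jax.random.normal(k2, (dim, dim))  # only the strict lower triangle is used

# Cholesky factor with positive diagonal, as in generate_fullrank_logdensity.
chol = jnp.tril(L, k=-1) + jnp.diag(jnp.exp(rho))
x = jnp.ones(dim)

# Custom route: triangular solve, no dense covariance and no explicit inverse.
y = jsp.linalg.solve_triangular(chol, x - mu, lower=True)
logq_custom = -0.5 * dim * jnp.log(2 * jnp.pi) - jnp.sum(rho) - 0.5 * jnp.sum(y**2)

# Readable route: densify the covariance and call the stock logpdf.
logq_stock = multivariate_normal.logpdf(x, mu, chol @ chol.T)

assert jnp.allclose(logq_custom, logq_stock, atol=1e-5)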

49 changes: 49 additions & 0 deletions tests/vi/test_fullrank_vi.py
@@ -0,0 +1,49 @@
import chex
import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
import optax
from absl.testing import absltest

import blackjax


class FullRankVITest(chex.TestCase):
    def setUp(self):
        super().setUp()
        self.key = jax.random.PRNGKey(42)

    @chex.variants(with_jit=True, without_jit=True)
    def test_recover_posterior(self):
        ground_truth = [
            # loc, scale
            (2, 4),
            (3, 5),
        ]

        def logdensity_fn(x):
            logpdf = stats.norm.logpdf(x["x_1"], *ground_truth[0]) + stats.norm.logpdf(
                x["x_2"], *ground_truth[1]
            )
            return jnp.sum(logpdf)

        initial_position = {"x_1": 0.0, "x_2": 0.0}

        num_steps = 50_000
        num_samples = 500

        optimizer = optax.sgd(1e-2)
        frvi = blackjax.fullrank_vi(logdensity_fn, optimizer, num_samples)
        state = frvi.init(initial_position)

        rng_key = self.key
        for _ in range(num_steps):
            # Split off a fresh subkey each step instead of reusing the key.
            rng_key, subkey = jax.random.split(rng_key)
            state, _ = self.variant(frvi.step)(subkey, state)

        loc_1, loc_2 = state.mu["x_1"], state.mu["x_2"]
        self.assertAlmostEqual(loc_1, ground_truth[0][0], delta=0.01)
        self.assertAlmostEqual(loc_2, ground_truth[1][0], delta=0.01)


if __name__ == "__main__":
    absltest.main()
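
A side note on the training loop above: the per-step Python loop can be folded into jax.lax.scan so the whole optimization compiles once. A minimal sketch, assuming the frvi object from the test (the helper name run_frvi is illustrative):

import jax

def run_frvi(rng_key, frvi, initial_state, num_steps):
    # One VI update per scan iteration: carry the state, stack the ELBO trace.
    def one_step(state, key):
        state, info = frvi.step(key, state)
        return state, info.elbo

    keys = jax.random.split(rng_key, num_steps)
    final_state, elbo_trace = jax.lax.scan(one_step, initial_state, keys)
    return final_state, elbo_trace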