From cbaff4186245ece789e1391c31acda306a2b868b Mon Sep 17 00:00:00 2001 From: Gilad Turok <36947659+gil2rok@users.noreply.github.com> Date: Wed, 14 Aug 2024 18:11:54 +0000 Subject: [PATCH 01/12] fullrank vi first draft --- blackjax/__init__.py | 7 ++ blackjax/vi/__init__.py | 4 +- blackjax/vi/fullrank_vi.py | 165 +++++++++++++++++++++++++++++++++++ tests/vi/test_fullrank_vi.py | 49 +++++++++++ 4 files changed, 223 insertions(+), 2 deletions(-) create mode 100644 blackjax/vi/fullrank_vi.py create mode 100644 tests/vi/test_fullrank_vi.py diff --git a/blackjax/__init__.py b/blackjax/__init__.py index dfdcfc545..ea07955cd 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -36,6 +36,7 @@ from .smc import adaptive_tempered from .smc import inner_kernel_tuning as _inner_kernel_tuning from .smc import tempered +from .vi import fullrank_vi as _fullrank_vi from .vi import meanfield_vi as _meanfield_vi from .vi import pathfinder as _pathfinder from .vi import schrodinger_follmer as _schrodinger_follmer @@ -131,6 +132,12 @@ def generate_top_level_api_from(module): svgd = generate_top_level_api_from(_svgd) # variational inference +fullrank_vi = GenerateVariationalAPI( + _fullrank_vi.as_top_level_api, + _fullrank_vi.init, + _fullrank_vi.step, + _fullrank_vi.sample +) meanfield_vi = GenerateVariationalAPI( _meanfield_vi.as_top_level_api, _meanfield_vi.init, diff --git a/blackjax/vi/__init__.py b/blackjax/vi/__init__.py index 44fe9760d..e9c1c8c2c 100644 --- a/blackjax/vi/__init__.py +++ b/blackjax/vi/__init__.py @@ -1,3 +1,3 @@ -from . import meanfield_vi, pathfinder, schrodinger_follmer, svgd +from . 
import fullrank_vi, meanfield_vi, pathfinder, schrodinger_follmer, svgd -__all__ = ["pathfinder", "meanfield_vi", "svgd", "schrodinger_follmer"] +__all__ = ["fullrank_vi", "meanfield_vi", "pathfinder", "svgd", "schrodinger_follmer"] diff --git a/blackjax/vi/fullrank_vi.py b/blackjax/vi/fullrank_vi.py new file mode 100644 index 000000000..729f37d3e --- /dev/null +++ b/blackjax/vi/fullrank_vi.py @@ -0,0 +1,165 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Callable, NamedTuple + +import jax +import jax.numpy as jnp +import jax.scipy as jsp +from optax import GradientTransformation, OptState + +from blackjax.base import VIAlgorithm +from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey + +__all__ = [ + "FRVIState", + "FRVIInfo", + "sample", + "generate_fullrank_logdensity", + "step", + "as_top_level_api", +] + + +class FRVIState(NamedTuple): + mu: ArrayTree + rho: ArrayTree + L: ArrayTree + opt_state: OptState + + +class FRVIInfo(NamedTuple): + elbo: float + + +def init( + position: ArrayLikeTree, + optimizer: GradientTransformation, + *optimizer_args, + **optimizer_kwargs, +) -> FRVIState: + """Initialize the full-rank VI state.""" + mu = jax.tree.map(jnp.zeros_like, position) + rho = jax.tree.map(jnp.zeros_like, position) + L = jax.tree.map(lambda x: jnp.zeros((*x.shape, x.shape)), position) + opt_state = optimizer.init((mu, rho, L)) + return FRVIState(mu, rho, L, opt_state) + + +def step( + rng_key: PRNGKey, + state: FRVIState, + logdensity_fn: Callable, + optimizer: GradientTransformation, + num_samples: int = 5, + stl_estimator: bool = True, +) -> tuple[FRVIState, FRVIInfo]: + """Approximate the target density using the full-rank approximation. + + Parameters + ---------- + rng_key + Key for JAX's pseudo-random number generator. + init_state + Initial state of the full-rank approximation. + logdensity_fn + Function that represents the target log-density to approximate. + optimizer + Optax `GradientTransformation` to be used for optimization. + num_samples + The number of samples that are taken from the approximation + at each step to compute the Kullback-Leibler divergence between + the approximation and the target log-density. + stl_estimator + Whether to use stick-the-landing (STL) gradient estimator :cite:p:`roeder2017sticking` for gradient estimation. + The STL estimator has lower gradient variance by removing the score function term + from the gradient. 
It is suggested by :cite:p:`agrawal2020advances` to always keep it in order for better results. + + """ + + parameters = (state.mu, state.rho, state.L) + + def kl_divergence_fn(parameters): + mu, rho, L = parameters + z = _sample(rng_key, mu, rho, L, num_samples) + if stl_estimator: + parameters = jax.tree_map(jax.lax.stop_gradient, (mu, rho, L)) + logq = jax.vmap(generate_fullrank_logdensity(mu, rho, L))(z) + logp = jax.vmap(logdensity_fn)(z) + return (logq - logp).mean() + + elbo, elbo_grad = jax.value_and_grad(kl_divergence_fn)(parameters) + updates, new_opt_state = optimizer.update(elbo_grad, state.opt_state, parameters) + new_parameters = jax.tree.map(lambda p, u: p + u, parameters, updates) + + new_mu, new_rho, new_L = new_parameters + return FRVIState(new_mu, new_rho, new_L, new_opt_state), FRVIInfo(elbo) + + +def sample(rng_key: PRNGKey, state: FRVIState, num_samples: int = 1): + """Sample from the full-rank approximation.""" + return _sample(rng_key, state.mu, state.rho, state.L, num_samples) + + +def as_top_level_api( + logdensity_fn: Callable, + optimizer: GradientTransformation, + num_samples: int = 100, +): + """High-level implementation of Full-Rank Variational Inference. + + Parameters + ---------- + logdensity_fn + A function that represents the log-density function associated with + the distribution we want to sample from. + optimizer + Optax optimizer to use to optimize the ELBO. + num_samples + Number of samples to take at each step to optimize the ELBO. + + Returns + ------- + A ``VIAlgorithm``. 
+ + """ + + def init_fn(position: ArrayLikeTree): + return init(position, optimizer) + + def step_fn(rng_key: PRNGKey, state: FRVIState) -> tuple[FRVIState, FRVIInfo]: + return step(rng_key, state, logdensity_fn, optimizer, num_samples) + + def sample_fn(rng_key: PRNGKey, state: FRVIState, num_samples: int): + return sample(rng_key, state, num_samples) + + return VIAlgorithm(init_fn, step_fn, sample_fn) + + +def _sample(rng_key, mu, rho, L, num_samples): + cholesky = jnp.tril(L, k=-1) + jnp.diag(jnp.exp(L)) + eps = jax.random.normal(rng_key, (num_samples,) + mu.shape) + return mu + eps @ cholesky.T + + +def generate_fullrank_logdensity(mu, rho, L): + cholesky = jnp.tril(L, k=-1) + jnp.diag(jnp.exp(L)) + log_det = 2 * jnp.sum(rho) + const = -0.5 * mu.shape[-1] * jnp.log(2 * jnp.pi) + + def fullrank_logdensity(position): + y = jsp.linalg.solve_triangular(cholesky, position - mu, lower=True) + mahalanobis_dist = jnp.sum(y ** 2, axis=-1) + return const - 0.5 * log_det - 0.5 * mahalanobis_dist + + return fullrank_logdensity \ No newline at end of file diff --git a/tests/vi/test_fullrank_vi.py b/tests/vi/test_fullrank_vi.py new file mode 100644 index 000000000..61cd3ed34 --- /dev/null +++ b/tests/vi/test_fullrank_vi.py @@ -0,0 +1,49 @@ +import chex +import jax +import jax.numpy as jnp +import jax.scipy.stats as stats +import optax +from absl.testing import absltest + +import blackjax + + +class FullRankVITest(chex.TestCase): + def setUp(self): + super().setUp() + self.key = jax.random.PRNGKey(42) + + @chex.variants(with_jit=True, without_jit=True) + def test_recover_posterior(self): + ground_truth = [ + # loc, scale + (2, 4), + (3, 5), + ] + + def logdensity_fn(x): + logpdf = stats.norm.logpdf(x["x_1"], *ground_truth[0]) + stats.norm.logpdf( + x["x_2"], *ground_truth[1] + ) + return jnp.sum(logpdf) + + initial_position = {"x_1": 0.0, "x_2": 0.0} + + num_steps = 50_000 + num_samples = 500 + + optimizer = optax.sgd(1e-2) + frvi = blackjax.fullrank_vi(logdensity_fn, 
optimizer, num_samples) + state = frvi.init(initial_position) + + rng_key = self.key + for i in range(num_steps): + subkey = jax.random.split(rng_key, i) + state, _ = self.variant(frvi.step)(subkey, state) + + loc_1, loc_2 = state.mu["x_1"], state.mu["x_2"] + self.assertAlmostEqual(loc_1, ground_truth[0][0], delta=0.01) + self.assertAlmostEqual(loc_2, ground_truth[1][0], delta=0.01) + +if __name__ == "__main__": + absltest.main() \ No newline at end of file From f80b5932be993a03988c4466788068946ffe332f Mon Sep 17 00:00:00 2001 From: Gilad Turok <36947659+gil2rok@users.noreply.github.com> Date: Fri, 16 Aug 2024 04:02:06 +0000 Subject: [PATCH 02/12] Fix: cholesky factor as flattened PyTree Define `chol_params` as a flattened Cholesky factor PyTree that consists of diagonal elements followed by the off-diagonal elements in row-major order for n = dim * (dim + 1) / 2 elements. The diagonal elements (the first dim entries) are passed through a softplus function to ensure positivity, which is crucial to maintain a valid covariance matrix. This parameterization allows for unconstrained optimization while ensuring the resulting covariance matrix Sigma = CC^T is symmetric and positive definite. The `chol_params` are then reshaped into a lower triangular matrix `chol_factor` using `jnp.tril` and `jnp.diag` functions. 
--- blackjax/__init__.py | 8 ++-- blackjax/vi/fullrank_vi.py | 76 +++++++++++++++++++++++------------- tests/vi/test_fullrank_vi.py | 3 +- 3 files changed, 55 insertions(+), 32 deletions(-) diff --git a/blackjax/__init__.py b/blackjax/__init__.py index ea07955cd..7272634b4 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -133,10 +133,10 @@ def generate_top_level_api_from(module): # variational inference fullrank_vi = GenerateVariationalAPI( - _fullrank_vi.as_top_level_api, - _fullrank_vi.init, - _fullrank_vi.step, - _fullrank_vi.sample + _fullrank_vi.as_top_level_api, + _fullrank_vi.init, + _fullrank_vi.step, + _fullrank_vi.sample, ) meanfield_vi = GenerateVariationalAPI( _meanfield_vi.as_top_level_api, diff --git a/blackjax/vi/fullrank_vi.py b/blackjax/vi/fullrank_vi.py index 729f37d3e..386f29a97 100644 --- a/blackjax/vi/fullrank_vi.py +++ b/blackjax/vi/fullrank_vi.py @@ -33,8 +33,7 @@ class FRVIState(NamedTuple): mu: ArrayTree - rho: ArrayTree - L: ArrayTree + chol_params: ArrayTree # flattened Cholesky factor opt_state: OptState @@ -50,10 +49,10 @@ def init( ) -> FRVIState: """Initialize the full-rank VI state.""" mu = jax.tree.map(jnp.zeros_like, position) - rho = jax.tree.map(jnp.zeros_like, position) - L = jax.tree.map(lambda x: jnp.zeros((*x.shape, x.shape)), position) - opt_state = optimizer.init((mu, rho, L)) - return FRVIState(mu, rho, L, opt_state) + dim = jax.flatten_util.ravel_pytree(mu)[0].shape[0] + chol_params = jax.flatten_util.ravel_pytree(jnp.tril(jnp.eye(dim)))[0] + opt_state = optimizer.init((mu, chol_params)) + return FRVIState(mu, chol_params, opt_state) def step( @@ -87,28 +86,27 @@ def step( """ - parameters = (state.mu, state.rho, state.L) + parameters = (state.mu, state.chol_params) def kl_divergence_fn(parameters): - mu, rho, L = parameters - z = _sample(rng_key, mu, rho, L, num_samples) + mu, chol_params = parameters + z = _sample(rng_key, mu, chol_params, num_samples) if stl_estimator: - parameters = 
jax.tree_map(jax.lax.stop_gradient, (mu, rho, L)) - logq = jax.vmap(generate_fullrank_logdensity(mu, rho, L))(z) + parameters = jax.tree_map(jax.lax.stop_gradient, (mu, chol_params)) + logq = jax.vmap(generate_fullrank_logdensity(mu, chol_params))(z) logp = jax.vmap(logdensity_fn)(z) return (logq - logp).mean() elbo, elbo_grad = jax.value_and_grad(kl_divergence_fn)(parameters) updates, new_opt_state = optimizer.update(elbo_grad, state.opt_state, parameters) new_parameters = jax.tree.map(lambda p, u: p + u, parameters, updates) - - new_mu, new_rho, new_L = new_parameters - return FRVIState(new_mu, new_rho, new_L, new_opt_state), FRVIInfo(elbo) + new_state = FRVIState(new_parameters[0], new_parameters[1], new_opt_state) + return new_state, FRVIInfo(elbo) def sample(rng_key: PRNGKey, state: FRVIState, num_samples: int = 1): """Sample from the full-rank approximation.""" - return _sample(rng_key, state.mu, state.rho, state.L, num_samples) + return _sample(rng_key, state.mu, state.chol_params, num_samples) def as_top_level_api( @@ -146,20 +144,44 @@ def sample_fn(rng_key: PRNGKey, state: FRVIState, num_samples: int): return VIAlgorithm(init_fn, step_fn, sample_fn) -def _sample(rng_key, mu, rho, L, num_samples): - cholesky = jnp.tril(L, k=-1) + jnp.diag(jnp.exp(L)) - eps = jax.random.normal(rng_key, (num_samples,) + mu.shape) - return mu + eps @ cholesky.T +def _unflatten_cholesky(chol_params): + """Construct the Cholesky factor from a flattened vector of cholesky parameters. + + Transforms a flattened vector representation of a lower triangular matrix + into a full Cholesky factor. The input vector contains n = dim * (dim + 1) / 2 + elements, where dim is the dimension of the resulting square matrix. + + The diagonal elements are passed through a softplus function to ensure (numerically + stable) positivity, crucial to maintain a valid covariance matrix parameterization. 
+ + This parameterization allows for unconstrained optimization while ensuring the + resulting covariance matrix Sigma = CC^T is symmetric and positive definite. + """ + n = chol_params.size + dim = int(jnp.sqrt(1 + 8 * n) - 1) // 2 + tril = jnp.zeros((dim, dim)) + tril = tril.at[jnp.tril_indices(dim, k=-1)].set(chol_params[dim:]) + diag = jax.nn.softplus(chol_params[:dim]) + chol_factor = tril + jnp.diag(diag) + return chol_factor + + +def _sample(rng_key, mu, chol_params, num_samples): + mu_flatten, unravel_fn = jax.flatten_util.ravel_pytree(mu) + chol_factor = _unflatten_cholesky(chol_params) + eps = jax.random.normal(rng_key, (num_samples, mu_flatten.shape[0])) + flatten_sample = eps @ chol_factor.T + mu_flatten + return jax.vmap(unravel_fn)(flatten_sample) -def generate_fullrank_logdensity(mu, rho, L): - cholesky = jnp.tril(L, k=-1) + jnp.diag(jnp.exp(L)) - log_det = 2 * jnp.sum(rho) - const = -0.5 * mu.shape[-1] * jnp.log(2 * jnp.pi) +def generate_fullrank_logdensity(mu, chol_params): + mu_flatten, _ = jax.flatten_util.ravel_pytree(mu) + chol_factor = _unflatten_cholesky(chol_params) + cov = chol_factor @ chol_factor.T def fullrank_logdensity(position): - y = jsp.linalg.solve_triangular(cholesky, position - mu, lower=True) - mahalanobis_dist = jnp.sum(y ** 2, axis=-1) - return const - 0.5 * log_det - 0.5 * mahalanobis_dist + position_flatten = jax.flatten_util.ravel_pytree(position)[0] + # TODO: inefficient because of redundant cholesky decomposition + return jsp.stats.multivariate_normal.logpdf(position_flatten, mu_flatten, cov) - return fullrank_logdensity \ No newline at end of file + return fullrank_logdensity diff --git a/tests/vi/test_fullrank_vi.py b/tests/vi/test_fullrank_vi.py index 61cd3ed34..a8b35d131 100644 --- a/tests/vi/test_fullrank_vi.py +++ b/tests/vi/test_fullrank_vi.py @@ -45,5 +45,6 @@ def logdensity_fn(x): self.assertAlmostEqual(loc_1, ground_truth[0][0], delta=0.01) self.assertAlmostEqual(loc_2, ground_truth[1][0], delta=0.01) + if 
__name__ == "__main__": - absltest.main() \ No newline at end of file + absltest.main() From c2b38eb3f13f7ebc809aa1ec95a68f43cc2c243a Mon Sep 17 00:00:00 2001 From: Gilad Turok <36947659+gil2rok@users.noreply.github.com> Date: Fri, 16 Aug 2024 04:15:22 +0000 Subject: [PATCH 03/12] Doc: clarify chol_params order --- blackjax/vi/fullrank_vi.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/blackjax/vi/fullrank_vi.py b/blackjax/vi/fullrank_vi.py index 386f29a97..f21a01f65 100644 --- a/blackjax/vi/fullrank_vi.py +++ b/blackjax/vi/fullrank_vi.py @@ -148,8 +148,9 @@ def _unflatten_cholesky(chol_params): """Construct the Cholesky factor from a flattened vector of cholesky parameters. Transforms a flattened vector representation of a lower triangular matrix - into a full Cholesky factor. The input vector contains n = dim * (dim + 1) / 2 - elements, where dim is the dimension of the resulting square matrix. + into a full Cholesky factor. The input vector contains n = d(d+1)/2 elements + consisting of d diagonal elements followed by n - d off-diagonal elements in + row-major order, where d is the dimension of the matrix. The diagonal elements are passed through a softplus function to ensure (numerically stable) positivity, crucial to maintain a valid covariance matrix parameterization. 
From b13eb12c7b35136d97d932763a0eaec76bd89318 Mon Sep 17 00:00:00 2001 From: Gilad Turok <36947659+gil2rok@users.noreply.github.com> Date: Fri, 16 Aug 2024 04:21:40 +0000 Subject: [PATCH 04/12] Doc: formatting --- blackjax/vi/fullrank_vi.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/blackjax/vi/fullrank_vi.py b/blackjax/vi/fullrank_vi.py index f21a01f65..a6aea5b99 100644 --- a/blackjax/vi/fullrank_vi.py +++ b/blackjax/vi/fullrank_vi.py @@ -33,7 +33,7 @@ class FRVIState(NamedTuple): mu: ArrayTree - chol_params: ArrayTree # flattened Cholesky factor + chol_params: ArrayTree # flattened Cholesky factor opt_state: OptState @@ -148,8 +148,8 @@ def _unflatten_cholesky(chol_params): """Construct the Cholesky factor from a flattened vector of cholesky parameters. Transforms a flattened vector representation of a lower triangular matrix - into a full Cholesky factor. The input vector contains n = d(d+1)/2 elements - consisting of d diagonal elements followed by n - d off-diagonal elements in + into a full Cholesky factor. The input vector contains n = d(d+1)/2 elements + consisting of d diagonal elements followed by n - d off-diagonal elements in row-major order, where d is the dimension of the matrix. 
The diagonal elements are passed through a softplus function to ensure (numerically From 4ea435ce7d35bdf5103d454e6179097b7cb048b8 Mon Sep 17 00:00:00 2001 From: Gilad Turok <36947659+gil2rok@users.noreply.github.com> Date: Fri, 16 Aug 2024 05:31:03 +0000 Subject: [PATCH 05/12] Enh: compute normal log density with cholesky factor --- blackjax/vi/fullrank_vi.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/blackjax/vi/fullrank_vi.py b/blackjax/vi/fullrank_vi.py index a6aea5b99..45d6d051f 100644 --- a/blackjax/vi/fullrank_vi.py +++ b/blackjax/vi/fullrank_vi.py @@ -50,7 +50,7 @@ def init( """Initialize the full-rank VI state.""" mu = jax.tree.map(jnp.zeros_like, position) dim = jax.flatten_util.ravel_pytree(mu)[0].shape[0] - chol_params = jax.flatten_util.ravel_pytree(jnp.tril(jnp.eye(dim)))[0] + chol_params, _ = jax.flatten_util.ravel_pytree(jnp.tril(jnp.eye(dim))) opt_state = optimizer.init((mu, chol_params)) return FRVIState(mu, chol_params, opt_state) @@ -170,7 +170,7 @@ def _unflatten_cholesky(chol_params): def _sample(rng_key, mu, chol_params, num_samples): mu_flatten, unravel_fn = jax.flatten_util.ravel_pytree(mu) chol_factor = _unflatten_cholesky(chol_params) - eps = jax.random.normal(rng_key, (num_samples, mu_flatten.shape[0])) + eps = jax.random.normal(rng_key, (num_samples, mu_flatten.size)) flatten_sample = eps @ chol_factor.T + mu_flatten return jax.vmap(unravel_fn)(flatten_sample) @@ -178,11 +178,14 @@ def _sample(rng_key, mu, chol_params, num_samples): def generate_fullrank_logdensity(mu, chol_params): mu_flatten, _ = jax.flatten_util.ravel_pytree(mu) chol_factor = _unflatten_cholesky(chol_params) - cov = chol_factor @ chol_factor.T + log_det = 2 * jnp.sum(jnp.log(jnp.diag(chol_factor))) + const = -0.5 * mu_flatten.size * jnp.log(2 * jnp.pi) def fullrank_logdensity(position): - position_flatten = jax.flatten_util.ravel_pytree(position)[0] - # TODO: inefficient because of redundant cholesky decomposition - return 
jsp.stats.multivariate_normal.logpdf(position_flatten, mu_flatten, cov) + position_flatten, _ = jax.flatten_util.ravel_pytree(position) + centered_position = position_flatten - mu_flatten + y = jsp.linalg.solve_triangular(chol_factor, centered_position, lower=True) + mahalanobis_dist = jnp.sum(y ** 2) + return const - 0.5 * (log_det + mahalanobis_dist) return fullrank_logdensity From 511988952da0949ba7fee77250a9f3325c58b2f7 Mon Sep 17 00:00:00 2001 From: Gilad Turok <36947659+gil2rok@users.noreply.github.com> Date: Fri, 16 Aug 2024 05:32:29 +0000 Subject: [PATCH 06/12] Doc: formatting --- blackjax/vi/fullrank_vi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/blackjax/vi/fullrank_vi.py b/blackjax/vi/fullrank_vi.py index 45d6d051f..bd56932be 100644 --- a/blackjax/vi/fullrank_vi.py +++ b/blackjax/vi/fullrank_vi.py @@ -185,7 +185,7 @@ def fullrank_logdensity(position): position_flatten, _ = jax.flatten_util.ravel_pytree(position) centered_position = position_flatten - mu_flatten y = jsp.linalg.solve_triangular(chol_factor, centered_position, lower=True) - mahalanobis_dist = jnp.sum(y ** 2) + mahalanobis_dist = jnp.sum(y**2) return const - 0.5 * (log_det + mahalanobis_dist) return fullrank_logdensity From 6b9c002b59a3b99be34cd3b2ff37a541ffa506c7 Mon Sep 17 00:00:00 2001 From: Gilad Turok <36947659+gil2rok@users.noreply.github.com> Date: Fri, 16 Aug 2024 05:37:18 +0000 Subject: [PATCH 07/12] Doc: Clarify Cholesky unflattening --- blackjax/vi/fullrank_vi.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/blackjax/vi/fullrank_vi.py b/blackjax/vi/fullrank_vi.py index bd56932be..d65fadf85 100644 --- a/blackjax/vi/fullrank_vi.py +++ b/blackjax/vi/fullrank_vi.py @@ -145,15 +145,15 @@ def sample_fn(rng_key: PRNGKey, state: FRVIState, num_samples: int): def _unflatten_cholesky(chol_params): - """Construct the Cholesky factor from a flattened vector of cholesky parameters. 
+ """Construct the Cholesky factor from a flattened vector of Cholesky parameters. Transforms a flattened vector representation of a lower triangular matrix into a full Cholesky factor. The input vector contains n = d(d+1)/2 elements - consisting of d diagonal elements followed by n - d off-diagonal elements in + consisting of d diagonal elements followed by n-d off-diagonal elements in row-major order, where d is the dimension of the matrix. The diagonal elements are passed through a softplus function to ensure (numerically - stable) positivity, crucial to maintain a valid covariance matrix parameterization. + stable) positivity, such that the resulting Cholesky factor is positive definite. This parameterization allows for unconstrained optimization while ensuring the resulting covariance matrix Sigma = CC^T is symmetric and positive definite. From 4379a6d82d25126b30570735f16c16d07e557efb Mon Sep 17 00:00:00 2001 From: Gilad Turok <36947659+gil2rok@users.noreply.github.com> Date: Mon, 19 Aug 2024 03:49:03 +0000 Subject: [PATCH 08/12] Fix: Non-jitted full-rank VI works Fix testing bug, add docstrings, and change softplus to exponential when converting `chol_params` to `chol_factor` in `_unflatten_cholesky`. --- blackjax/vi/fullrank_vi.py | 84 +++++++++++++++++++++++++++++------- tests/vi/test_fullrank_vi.py | 6 +-- 2 files changed, 71 insertions(+), 19 deletions(-) diff --git a/blackjax/vi/fullrank_vi.py b/blackjax/vi/fullrank_vi.py index d65fadf85..009de960e 100644 --- a/blackjax/vi/fullrank_vi.py +++ b/blackjax/vi/fullrank_vi.py @@ -14,6 +14,7 @@ from typing import Callable, NamedTuple import jax +import jax.flatten_util import jax.numpy as jnp import jax.scipy as jsp from optax import GradientTransformation, OptState @@ -32,12 +33,32 @@ class FRVIState(NamedTuple): + """State of the full-rank VI algorithm. + + mu: + Mean of the Gaussian approximation. 
+ chol_params: + Flattened Cholesky factor of the Gaussian approximation, used to parameterize + the full-rank covariance matrix. A vector of length d(d+1)/2 for a + d-dimensional Gaussian, containing d diagonal elements (in log space) followed + by lower triangular elements in row-major order. + opt_state: + Optax optimizer state. + + """ + mu: ArrayTree - chol_params: ArrayTree # flattened Cholesky factor + chol_params: ArrayTree opt_state: OptState class FRVIInfo(NamedTuple): + """Extra information of the full-rank VI algorithm. + + elbo: + ELBO of approximation wrt target distribution. + + """ elbo: float @@ -47,10 +68,10 @@ def init( *optimizer_args, **optimizer_kwargs, ) -> FRVIState: - """Initialize the full-rank VI state.""" + """Initialize the full-rank VI state with zero mean and identity covariance.""" mu = jax.tree.map(jnp.zeros_like, position) dim = jax.flatten_util.ravel_pytree(mu)[0].shape[0] - chol_params, _ = jax.flatten_util.ravel_pytree(jnp.tril(jnp.eye(dim))) + chol_params = jnp.zeros(dim * (dim + 1) // 2) opt_state = optimizer.init((mu, chol_params)) return FRVIState(mu, chol_params, opt_state) @@ -63,7 +84,7 @@ def step( num_samples: int = 5, stl_estimator: bool = True, ) -> tuple[FRVIState, FRVIInfo]: - """Approximate the target density using the full-rank approximation. 
+ """Approximate the target density using the full-rank Gaussian approximation Parameters ---------- @@ -92,7 +113,7 @@ def kl_divergence_fn(parameters): mu, chol_params = parameters z = _sample(rng_key, mu, chol_params, num_samples) if stl_estimator: - parameters = jax.tree_map(jax.lax.stop_gradient, (mu, chol_params)) + parameters = jax.tree.map(jax.lax.stop_gradient, (mu, chol_params)) logq = jax.vmap(generate_fullrank_logdensity(mu, chol_params))(z) logp = jax.vmap(logdensity_fn)(z) return (logq - logp).mean() @@ -147,30 +168,61 @@ def sample_fn(rng_key: PRNGKey, state: FRVIState, num_samples: int): def _unflatten_cholesky(chol_params): """Construct the Cholesky factor from a flattened vector of Cholesky parameters. - Transforms a flattened vector representation of a lower triangular matrix - into a full Cholesky factor. The input vector contains n = d(d+1)/2 elements - consisting of d diagonal elements followed by n-d off-diagonal elements in - row-major order, where d is the dimension of the matrix. - - The diagonal elements are passed through a softplus function to ensure (numerically - stable) positivity, such that the resulting Cholesky factor is positive definite. + The Cholesky factor (L) is a lower triangular matrix with positive diagonal + elements used to parameterize the (full-rank) covariance matrix of the Gaussian + approximation as Sigma = LL^T. + + This parameterization allows for (1) efficient sampling and log density evaluation, + and (2) ensuring the covariance matrix is symmetric and positive definite during + (unconconstrained) optimization. + + Transforms a flattened vector representation of the Cholesky factor (`chol_params`) + into its proper lower triangular matrix form (`chol_factor`). It specifically + reshapes the input vector `chol_params` into a lower triangular matrix with zeros + above the diagonal and exponentiates the diagonal elements to ensure positivity. 
- This parameterization allows for unconstrained optimization while ensuring the - resulting covariance matrix Sigma = CC^T is symmetric and positive definite. + Parameters + ---------- + chol_params + Flattened Cholesky factor of the full-rank covariance matrix. + + Returns + ------- + chol_factor + Cholesky factor of the full-rank covariance matrix. """ + n = chol_params.size dim = int(jnp.sqrt(1 + 8 * n) - 1) // 2 tril = jnp.zeros((dim, dim)) tril = tril.at[jnp.tril_indices(dim, k=-1)].set(chol_params[dim:]) - diag = jax.nn.softplus(chol_params[:dim]) + diag = jnp.exp(chol_params[:dim]) # TODO: replace with softplus? chol_factor = tril + jnp.diag(diag) return chol_factor def _sample(rng_key, mu, chol_params, num_samples): + """Sample from the full-rank Gaussian approximation of the target distribution. + + Parameters + ---------- + rng_key + Key for JAX's pseudo-random number generator. + mu + Mean of the Gaussian approximation. + chol_params + Flattened Cholesky factor of the Gaussian approximation. + num_samples + Number of samples to draw. + + Returns + ------- + Samples drawn from the full-rank Gaussian approximation. 
+ + """ mu_flatten, unravel_fn = jax.flatten_util.ravel_pytree(mu) chol_factor = _unflatten_cholesky(chol_params) - eps = jax.random.normal(rng_key, (num_samples, mu_flatten.size)) + eps = jax.random.normal(rng_key, (num_samples,) + mu_flatten.shape) flatten_sample = eps @ chol_factor.T + mu_flatten return jax.vmap(unravel_fn)(flatten_sample) diff --git a/tests/vi/test_fullrank_vi.py b/tests/vi/test_fullrank_vi.py index a8b35d131..411e0c8d0 100644 --- a/tests/vi/test_fullrank_vi.py +++ b/tests/vi/test_fullrank_vi.py @@ -11,9 +11,9 @@ class FullRankVITest(chex.TestCase): def setUp(self): super().setUp() - self.key = jax.random.PRNGKey(42) + self.key = jax.random.key(42) - @chex.variants(with_jit=True, without_jit=True) + @chex.variants(with_jit=True, without_jit=False) def test_recover_posterior(self): ground_truth = [ # loc, scale @@ -38,7 +38,7 @@ def logdensity_fn(x): rng_key = self.key for i in range(num_steps): - subkey = jax.random.split(rng_key, i) + subkey = jax.random.fold_in(rng_key, i) state, _ = self.variant(frvi.step)(subkey, state) loc_1, loc_2 = state.mu["x_1"], state.mu["x_2"] From bb08e1f364cda6a4938cf00517cc6a96c324a3b8 Mon Sep 17 00:00:00 2001 From: Gilad Turok <36947659+gil2rok@users.noreply.github.com> Date: Mon, 19 Aug 2024 03:58:35 +0000 Subject: [PATCH 09/12] Doc: formatting --- blackjax/vi/fullrank_vi.py | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/blackjax/vi/fullrank_vi.py b/blackjax/vi/fullrank_vi.py index 009de960e..4ae17fcf6 100644 --- a/blackjax/vi/fullrank_vi.py +++ b/blackjax/vi/fullrank_vi.py @@ -39,12 +39,12 @@ class FRVIState(NamedTuple): Mean of the Gaussian approximation. chol_params: Flattened Cholesky factor of the Gaussian approximation, used to parameterize - the full-rank covariance matrix. A vector of length d(d+1)/2 for a - d-dimensional Gaussian, containing d diagonal elements (in log space) followed + the full-rank covariance matrix. 
A vector of length d(d+1)/2 for a + d-dimensional Gaussian, containing d diagonal elements (in log space) followed by lower triangular elements in row-major order. opt_state: Optax optimizer state. - + """ mu: ArrayTree @@ -54,11 +54,12 @@ class FRVIState(NamedTuple): class FRVIInfo(NamedTuple): """Extra information of the full-rank VI algorithm. - + elbo: ELBO of approximation wrt target distribution. """ + elbo: float @@ -168,42 +169,42 @@ def sample_fn(rng_key: PRNGKey, state: FRVIState, num_samples: int): def _unflatten_cholesky(chol_params): """Construct the Cholesky factor from a flattened vector of Cholesky parameters. - The Cholesky factor (L) is a lower triangular matrix with positive diagonal - elements used to parameterize the (full-rank) covariance matrix of the Gaussian + The Cholesky factor (L) is a lower triangular matrix with positive diagonal + elements used to parameterize the (full-rank) covariance matrix of the Gaussian approximation as Sigma = LL^T. - + This parameterization allows for (1) efficient sampling and log density evaluation, - and (2) ensuring the covariance matrix is symmetric and positive definite during + and (2) ensuring the covariance matrix is symmetric and positive definite during (unconconstrained) optimization. - + Transforms a flattened vector representation of the Cholesky factor (`chol_params`) - into its proper lower triangular matrix form (`chol_factor`). It specifically - reshapes the input vector `chol_params` into a lower triangular matrix with zeros + into its proper lower triangular matrix form (`chol_factor`). It specifically + reshapes the input vector `chol_params` into a lower triangular matrix with zeros above the diagonal and exponentiates the diagonal elements to ensure positivity. Parameters ---------- chol_params Flattened Cholesky factor of the full-rank covariance matrix. - + Returns ------- chol_factor Cholesky factor of the full-rank covariance matrix. 
""" - + n = chol_params.size dim = int(jnp.sqrt(1 + 8 * n) - 1) // 2 tril = jnp.zeros((dim, dim)) tril = tril.at[jnp.tril_indices(dim, k=-1)].set(chol_params[dim:]) - diag = jnp.exp(chol_params[:dim]) # TODO: replace with softplus? + diag = jnp.exp(chol_params[:dim]) # TODO: replace with softplus? chol_factor = tril + jnp.diag(diag) return chol_factor def _sample(rng_key, mu, chol_params, num_samples): """Sample from the full-rank Gaussian approximation of the target distribution. - + Parameters ---------- rng_key @@ -214,7 +215,7 @@ def _sample(rng_key, mu, chol_params, num_samples): Flattened Cholesky factor of the Gaussian approximation. num_samples Number of samples to draw. - + Returns ------- Samples drawn from the full-rank Gaussian approximation. From ced702fe55954a8cab8024ba943b7f6cab92a2d6 Mon Sep 17 00:00:00 2001 From: Gilad Turok <36947659+gil2rok@users.noreply.github.com> Date: Mon, 19 Aug 2024 21:13:04 +0000 Subject: [PATCH 10/12] Fix: Full-rank VI compatible with JIT compilation Refactor `_unflatten_cholesky()` function to take `dim` argument instead of inferring it (dynamically) from the `chol_params` input vector. This avoids JIT compilation issues. Also update docstrings. --- blackjax/vi/fullrank_vi.py | 47 ++++++++++++++++++++++++++------------ 1 file changed, 33 insertions(+), 14 deletions(-) diff --git a/blackjax/vi/fullrank_vi.py b/blackjax/vi/fullrank_vi.py index 4ae17fcf6..5e4e23990 100644 --- a/blackjax/vi/fullrank_vi.py +++ b/blackjax/vi/fullrank_vi.py @@ -85,7 +85,7 @@ def step( num_samples: int = 5, stl_estimator: bool = True, ) -> tuple[FRVIState, FRVIInfo]: - """Approximate the target density using the full-rank Gaussian approximation + """Approximate the target density using the full-rank Gaussian approximation. 
Parameters ---------- @@ -166,35 +166,36 @@ def sample_fn(rng_key: PRNGKey, state: FRVIState, num_samples: int): return VIAlgorithm(init_fn, step_fn, sample_fn) -def _unflatten_cholesky(chol_params): +def _unflatten_cholesky(chol_params, dim): """Construct the Cholesky factor from a flattened vector of Cholesky parameters. + Transforms a flattened vector representation of the Cholesky factor (`chol_params`) + into its proper lower triangular matrix form (`chol_factor`). It specifically + reshapes the input vector `chol_params` into a lower triangular matrix with zeros + above the diagonal and exponentiates the diagonal elements to ensure positivity. + The Cholesky factor (L) is a lower triangular matrix with positive diagonal - elements used to parameterize the (full-rank) covariance matrix of the Gaussian + elements used to parameterize the full-rank covariance matrix of the Gaussian approximation as Sigma = LL^T. This parameterization allows for (1) efficient sampling and log density evaluation, and (2) ensuring the covariance matrix is symmetric and positive definite during (unconconstrained) optimization. - Transforms a flattened vector representation of the Cholesky factor (`chol_params`) - into its proper lower triangular matrix form (`chol_factor`). It specifically - reshapes the input vector `chol_params` into a lower triangular matrix with zeros - above the diagonal and exponentiates the diagonal elements to ensure positivity. - Parameters ---------- chol_params Flattened Cholesky factor of the full-rank covariance matrix. + dim + Dimensionality of the Gaussian distribution. Returns ------- chol_factor Cholesky factor of the full-rank covariance matrix. + """ - n = chol_params.size - dim = int(jnp.sqrt(1 + 8 * n) - 1) // 2 tril = jnp.zeros((dim, dim)) tril = tril.at[jnp.tril_indices(dim, k=-1)].set(chol_params[dim:]) diag = jnp.exp(chol_params[:dim]) # TODO: replace with softplus? 
@@ -221,18 +222,36 @@ def _sample(rng_key, mu, chol_params, num_samples): Samples drawn from the full-rank Gaussian approximation. """ + mu_flatten, unravel_fn = jax.flatten_util.ravel_pytree(mu) - chol_factor = _unflatten_cholesky(chol_params) - eps = jax.random.normal(rng_key, (num_samples,) + mu_flatten.shape) + dim = mu_flatten.size + chol_factor = _unflatten_cholesky(chol_params, dim) + eps = jax.random.normal(rng_key, (num_samples,) + (dim,)) flatten_sample = eps @ chol_factor.T + mu_flatten return jax.vmap(unravel_fn)(flatten_sample) def generate_fullrank_logdensity(mu, chol_params): + """Generate the log-density function of a full-rank Gaussian distribution. + + Parameters + ---------- + mu + Mean of the Gaussian distribution. + chol_params + Flattened Cholesky factor of the Gaussian distribution. + + Returns + ------- + A function that computes the log-density of the full-rank Gaussian distribution. + + """ + mu_flatten, _ = jax.flatten_util.ravel_pytree(mu) - chol_factor = _unflatten_cholesky(chol_params) + dim = mu_flatten.size + chol_factor = _unflatten_cholesky(chol_params, dim) log_det = 2 * jnp.sum(jnp.log(jnp.diag(chol_factor))) - const = -0.5 * mu_flatten.size * jnp.log(2 * jnp.pi) + const = -0.5 * dim * jnp.log(2 * jnp.pi) def fullrank_logdensity(position): position_flatten, _ = jax.flatten_util.ravel_pytree(position) From 26da046980839d5114c2c4e524c5efed4b03cdfb Mon Sep 17 00:00:00 2001 From: Gilad Turok <36947659+gil2rok@users.noreply.github.com> Date: Mon, 19 Aug 2024 21:16:56 +0000 Subject: [PATCH 11/12] Doc: formatting --- blackjax/vi/fullrank_vi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/blackjax/vi/fullrank_vi.py b/blackjax/vi/fullrank_vi.py index 5e4e23990..321c1fa5c 100644 --- a/blackjax/vi/fullrank_vi.py +++ b/blackjax/vi/fullrank_vi.py @@ -222,7 +222,7 @@ def _sample(rng_key, mu, chol_params, num_samples): Samples drawn from the full-rank Gaussian approximation. 
""" - + mu_flatten, unravel_fn = jax.flatten_util.ravel_pytree(mu) dim = mu_flatten.size chol_factor = _unflatten_cholesky(chol_params, dim) From 4b4534f4e5c78bc61da7f275eb51264d1b6fb3f9 Mon Sep 17 00:00:00 2001 From: Gilad Turok <36947659+gil2rok@users.noreply.github.com> Date: Wed, 11 Sep 2024 21:46:51 +0000 Subject: [PATCH 12/12] Tests: Check full-rank covariance matrix Add assert statements that verify full-rank VI recovers the true, full-rank covariance matrix. --- tests/vi/test_fullrank_vi.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/vi/test_fullrank_vi.py b/tests/vi/test_fullrank_vi.py index 411e0c8d0..1cb486c32 100644 --- a/tests/vi/test_fullrank_vi.py +++ b/tests/vi/test_fullrank_vi.py @@ -13,7 +13,6 @@ def setUp(self): super().setUp() self.key = jax.random.key(42) - @chex.variants(with_jit=True, without_jit=False) def test_recover_posterior(self): ground_truth = [ # loc, scale @@ -39,11 +38,15 @@ def logdensity_fn(x): rng_key = self.key for i in range(num_steps): subkey = jax.random.fold_in(rng_key, i) - state, _ = self.variant(frvi.step)(subkey, state) + state, _ = jax.jit(frvi.step)(subkey, state) loc_1, loc_2 = state.mu["x_1"], state.mu["x_2"] + chol_factor = state.chol_params + scale_1, scale_2 = jnp.exp(chol_factor[0]), jnp.exp(chol_factor[1]) self.assertAlmostEqual(loc_1, ground_truth[0][0], delta=0.01) + self.assertAlmostEqual(scale_1, ground_truth[0][1], delta=0.01) self.assertAlmostEqual(loc_2, ground_truth[1][0], delta=0.01) + self.assertAlmostEqual(scale_2, ground_truth[1][1], delta=0.01) if __name__ == "__main__":