diff --git a/docs/source/inference_algos.rst b/docs/source/inference_algos.rst
index 153028d1f7..d578d73260 100644
--- a/docs/source/inference_algos.rst
+++ b/docs/source/inference_algos.rst
@@ -51,6 +51,12 @@ ELBO
     :show-inheritance:
     :member-order: bysource
 
+.. automodule:: pyro.infer.traceiwae_elbo
+    :members:
+    :undoc-members:
+    :show-inheritance:
+    :member-order: bysource
+
 .. automodule:: pyro.infer.tracetmc_elbo
     :members:
     :undoc-members:
diff --git a/pyro/infer/__init__.py b/pyro/infer/__init__.py
index 5485b7476e..2d7dce2586 100644
--- a/pyro/infer/__init__.py
+++ b/pyro/infer/__init__.py
@@ -23,6 +23,7 @@
 from pyro.infer.trace_tail_adaptive_elbo import TraceTailAdaptive_ELBO
 from pyro.infer.traceenum_elbo import JitTraceEnum_ELBO, TraceEnum_ELBO
 from pyro.infer.tracegraph_elbo import JitTraceGraph_ELBO, TraceGraph_ELBO
+from pyro.infer.traceiwae_elbo import TraceIWAE_ELBO
 from pyro.infer.tracetmc_elbo import TraceTMC_ELBO
 from pyro.infer.util import enable_validation, is_validation_enabled
 
@@ -54,6 +55,7 @@
     "TraceTMC_ELBO",
     "TraceEnum_ELBO",
     "TraceGraph_ELBO",
+    "TraceIWAE_ELBO",
     "TraceMeanField_ELBO",
     "TracePosterior",
     "TracePredictive",
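For context, here is a minimal sketch of driving the new estimator through pyro.infer.SVI once this export lands. The toy model, guide, and hyperparameters are illustrative and appear nowhere in this diff:

    import torch
    from torch.distributions import constraints

    import pyro
    import pyro.distributions as dist
    from pyro.infer import SVI, TraceIWAE_ELBO
    from pyro.optim import Adam

    data = torch.tensor([0.0, 2.0, 2.0])

    def model(data):
        z = pyro.sample("z", dist.Normal(0., 1.))
        with pyro.plate("data", len(data)):
            pyro.sample("x", dist.Normal(z, 1.), obs=data)

    def guide(data):
        loc = pyro.param("loc", lambda: torch.tensor(0.0))
        scale = pyro.param("scale", lambda: torch.tensor(1.0),
                           constraint=constraints.positive)
        pyro.sample("z", dist.Normal(loc, scale))  # fully reparameterized

    # TraceIWAE_ELBO (added below) asserts vectorize_particles=True; the
    # particles serve as the importance samples of the IWAE bound.
    elbo = TraceIWAE_ELBO(num_particles=10, vectorize_particles=True,
                          max_plate_nesting=1)
    svi = SVI(model, guide, Adam({"lr": 0.01}), loss=elbo)
    for step in range(1000):
        svi.step(data)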
diff --git a/pyro/infer/traceiwae_elbo.py b/pyro/infer/traceiwae_elbo.py
new file mode 100644
index 0000000000..d05706d150
--- /dev/null
+++ b/pyro/infer/traceiwae_elbo.py
@@ -0,0 +1,126 @@
+# Copyright Contributors to the Pyro project.
+# SPDX-License-Identifier: Apache-2.0
+
+import math
+
+import pyro.poutine as poutine
+from pyro.distributions.util import detach
+from pyro.infer.elbo import ELBO
+from pyro.infer.util import is_validation_enabled, torch_item
+from pyro.poutine.replay_messenger import ReplayMessenger
+from pyro.poutine.util import prune_subsample_sites
+from pyro.util import (
+    check_if_enumerated,
+    check_model_guide_match,
+    check_site_shape,
+    warn_if_nan,
+)
+
+
+class DetachReplayMessenger(ReplayMessenger):
+    """
+    Replays sample values from the guide trace into the model, detaching
+    each replayed value so that gradients do not flow from the model's
+    log-densities back into the reparameterized samples.
+    """
+    def _pyro_sample(self, msg):
+        super()._pyro_sample(msg)
+        if msg["name"] in self.trace:
+            msg["value"] = msg["value"].detach()
+
+
+class TraceIWAE_ELBO(ELBO):
+    """
+    A trace implementation of importance-weighted ELBO-based SVI, using the
+    doubly reparameterized gradient (DReG) estimator [1]. This requires
+    ``vectorize_particles=True`` and a fully reparameterized guide.
+
+    **References**
+
+    [1] G. Tucker, D. Lawson, S. Gu, C. J. Maddison (2018)
+        Doubly Reparameterized Gradient Estimators for Monte Carlo Objectives
+        https://arxiv.org/abs/1810.04152
+    """
+    def _get_trace(self, *args, **kwargs):
+        raise ValueError("Use _get_importance_trace() instead")
+
+    def _get_importance_trace(self, model, guide, args, kwargs):
+        assert self.vectorize_particles
+        if self.max_plate_nesting == math.inf:
+            self._guess_max_plate_nesting(model, guide, args, kwargs)
+        model = self._vectorized_num_particles(model)
+        guide = self._vectorized_num_particles(guide)
+
+        guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
+        with poutine.trace() as tr:
+            with DetachReplayMessenger(trace=guide_trace):
+                model(*args, **kwargs)
+        model_trace = tr.trace
+        if is_validation_enabled():
+            check_model_guide_match(model_trace, guide_trace, self.max_plate_nesting)
+
+        guide_trace = prune_subsample_sites(guide_trace)
+        model_trace = prune_subsample_sites(model_trace)
+
+        # Note: log_prob and score_parts need to be computed only when
+        # validating; differentiable_loss() recomputes log-densities with
+        # the required gradient stopping.
+        if is_validation_enabled():
+            model_trace.compute_log_prob()
+            guide_trace.compute_log_prob()
+            for site in model_trace.nodes.values():
+                if site["type"] == "sample":
+                    check_site_shape(site, self.max_plate_nesting)
+            for site in guide_trace.nodes.values():
+                if site["type"] == "sample":
+                    check_site_shape(site, self.max_plate_nesting)
+            check_if_enumerated(guide_trace)
+
+        return model_trace, guide_trace
+
+    def loss(self, model, guide, *args, **kwargs):
+        loss = self.differentiable_loss(model, guide, *args, **kwargs)
+        return torch_item(loss)
+
+    def loss_and_grads(self, model, guide, *args, **kwargs):
+        loss = self.differentiable_loss(model, guide, *args, **kwargs)
+        loss.backward()
+        return torch_item(loss)
+
+    def differentiable_loss(self, model, guide, *args, **kwargs):
+        model_trace, guide_trace = self._get_importance_trace(model, guide, args, kwargs)
+
+        # The following computation follows Sec. 8.3 Eqn. (12) of [1].
+        log_w_bar = 0.    # all gradients stopped
+        log_w_hat = 0.    # gradients stopped wrt distribution parameters
+        log_p_tilde = 0.  # gradients stopped wrt reparameterized z
+
+        def particle_sum(x):
+            """Sum out everything but the particle plate dimension."""
+            assert x.size(0) == self.num_particles
+            return x.reshape(self.num_particles, -1).sum(-1)
+
+        for name, site in model_trace.nodes.items():
+            if site["type"] == "sample":
+                fn = site["fn"]
+                z_detach = site["value"]
+                if name in guide_trace:
+                    z = guide_trace.nodes[name]["value"]
+                else:
+                    z = z_detach
+
+                log_p = particle_sum(detach(fn).log_prob(z))
+                log_w_bar = log_w_bar + log_p.detach()
+                log_w_hat = log_w_hat + log_p
+                log_p_tilde = log_p_tilde + particle_sum(fn.log_prob(z_detach))
+
+        for name, site in guide_trace.nodes.items():
+            if site["type"] == "sample":
+                fn = site["fn"]
+                z = site["value"]
+
+                log_q = particle_sum(detach(fn).log_prob(z))
+                log_w_bar = log_w_bar - log_q.detach()
+                log_w_hat = log_w_hat - log_q
+
+        log_W_bar = log_w_bar.logsumexp(0)
+        weight_bar = (log_w_bar - log_W_bar).exp()
+        surrogate_elbo = weight_bar.dot(log_p_tilde) + weight_bar.pow(2).dot(log_w_hat)
+        elbo = log_W_bar - math.log(self.num_particles)
+        loss = -elbo + surrogate_elbo.detach() - surrogate_elbo
+
+        warn_if_nan(loss, "loss")
+        return loss
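One detail worth noting in differentiable_loss() above: log_W_bar is built entirely from detached terms, so -elbo carries the reported value but no gradient, while surrogate_elbo.detach() - surrogate_elbo is numerically zero yet contributes the gradient of -surrogate_elbo. A minimal self-contained sketch of this stop-gradient idiom in plain PyTorch (x, reported, and surrogate are illustrative names, not from the diff):

    import torch

    x = torch.tensor(1.5, requires_grad=True)
    reported = (x ** 2).detach()  # the value we want the loss to take; no grad path
    surrogate = torch.sin(x)      # the term whose (negated) gradient we want

    loss = reported + surrogate.detach() - surrogate
    loss.backward()

    assert loss.item() == reported.item()                 # value comes from `reported`
    assert torch.allclose(x.grad, -torch.cos(x).detach()) # grad is -d(surrogate)/dx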
diff --git a/tests/infer/test_gradient.py b/tests/infer/test_gradient.py
index 45ca0d0932..c65eb6ca56 100644
--- a/tests/infer/test_gradient.py
+++ b/tests/infer/test_gradient.py
@@ -22,11 +22,17 @@
     Trace_ELBO,
     TraceEnum_ELBO,
     TraceGraph_ELBO,
+    TraceIWAE_ELBO,
     TraceMeanField_ELBO,
     config_enumerate,
 )
 from pyro.optim import Adam
-from tests.common import assert_equal, xfail_if_not_implemented, xfail_param
+from tests.common import (
+    assert_close,
+    assert_equal,
+    xfail_if_not_implemented,
+    xfail_param,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -47,6 +53,7 @@ def DiffTrace_ELBO(*args, **kwargs):
     (TraceMeanField_ELBO, False),
     (TraceEnum_ELBO, False),
     (TraceEnum_ELBO, True),
+    (TraceIWAE_ELBO, False),
 ])
 def test_subsample_gradient(Elbo, reparameterized, has_rsample, subsample, local_samples, scale):
     pyro.clear_param_store()
@@ -54,6 +61,8 @@ def test_subsample_gradient(Elbo, reparameterized, has_rsample, subsample, local
     subsample_size = 1 if subsample else len(data)
     precision = 0.06 * scale
     Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal
+    if Elbo is TraceIWAE_ELBO and not (reparameterized and has_rsample):
+        pytest.skip("not implemented")
 
     def model(subsample):
         with pyro.plate("data", len(data), subsample_size, subsample) as ind:
@@ -102,6 +111,41 @@ def guide(subsample):
     assert_equal(actual_grads, expected_grads, prec=precision)
 
 
+@pytest.mark.parametrize("Elbo", [
+    Trace_ELBO,
+    TraceEnum_ELBO,
+    TraceIWAE_ELBO,
+])
+def test_zero_gradient(Elbo):
+    data = torch.tensor([0.0, 2.0, 2.0])
+
+    def model(data):
+        z = pyro.sample("z", dist.Normal(0, 1))
+        with pyro.plate("data", len(data)):
+            pyro.sample("x", dist.Normal(z, 1), obs=data)
+
+    # This guide is the true posterior, so all gradients should be zero.
+    def guide(data):
+        loc = pyro.param("loc", lambda: torch.tensor(1.0))
+        scale = pyro.param("scale", lambda: torch.tensor(0.5))
+        pyro.sample("z", dist.Normal(loc, scale))
+
+    elbo = Elbo(max_plate_nesting=1,
+                num_particles=10000,
+                vectorize_particles=True,
+                strict_enumeration_warning=False)
+
+    loss = elbo.differentiable_loss(model, guide, data)
+    logger.info('loss = {}'.format(loss.item()))
+    assert_close(loss.item(), 5.44996, atol=0.001)
+
+    params = dict(pyro.get_param_store().named_parameters())
+    grads = torch.autograd.grad(loss, params.values())
+    for name, grad in zip(params, grads):
+        logger.info('grad {} = {}'.format(name, grad))
+        assert_close(grad, torch.zeros_like(grad), atol=0.01)
+
+
 @pytest.mark.parametrize("reparameterized", [True, False], ids=["reparam", "nonreparam"])
 @pytest.mark.parametrize("Elbo", [Trace_ELBO, DiffTrace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO])
 def test_plate(Elbo, reparameterized):
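The zero-gradient test above relies on the fact that when the guide is the exact posterior, every importance weight equals the marginal likelihood p(x), so the DReG gradient vanishes per sample rather than merely in expectation (here loc=1.0, scale=0.5 is the exact posterior, and -log p(data) = 5.44996). A standalone sketch of the same computation in plain PyTorch; since this toy model has no trainable parameters, the weight_bar.dot(log_p_tilde) term is omitted, and all names are illustrative:

    import math
    import torch

    # Toy model from the test: z ~ Normal(0, 1), x_i ~ Normal(z, 1).
    loc = torch.tensor(1.0, requires_grad=True)    # exact posterior mean
    scale = torch.tensor(0.5, requires_grad=True)  # exact posterior scale
    data = torch.tensor([0.0, 2.0, 2.0])
    K = 10000  # importance samples, matching num_particles in the test

    z = torch.distributions.Normal(loc, scale).rsample((K,))  # reparameterized

    # log p(x, z): gradients reach loc, scale only through z (pathwise).
    log_p = torch.distributions.Normal(0., 1.).log_prob(z) \
        + torch.distributions.Normal(z.unsqueeze(-1), 1.).log_prob(data).sum(-1)

    # log q(z) with distribution parameters detached, mirroring
    # DetachReplayMessenger plus detach(fn): gradients still flow through z.
    log_q_hat = torch.distributions.Normal(loc.detach(), scale.detach()).log_prob(z)

    log_w_hat = log_p - log_q_hat   # gradients stopped wrt parameters
    log_w_bar = log_w_hat.detach()  # all gradients stopped
    w_bar = (log_w_bar - log_w_bar.logsumexp(0)).exp()  # normalized weights

    elbo = log_w_bar.logsumexp(0) - math.log(K)  # IWAE bound, = log p(x) here
    surrogate = w_bar.pow(2).dot(log_w_hat)      # squared-weight DReG surrogate
    grad_loc, grad_scale = torch.autograd.grad(surrogate, (loc, scale))
    print(elbo.item(), grad_loc.item(), grad_scale.item())  # ~ -5.44996, ~0, ~0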