Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement first version of IWAE with DReG #2605

Draft
wants to merge 2 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/source/inference_algos.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ ELBO
:show-inheritance:
:member-order: bysource

.. automodule:: pyro.infer.traceiwae_elbo
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

.. automodule:: pyro.infer.tracetmc_elbo
:members:
:undoc-members:
Expand Down
2 changes: 2 additions & 0 deletions pyro/infer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from pyro.infer.trace_tail_adaptive_elbo import TraceTailAdaptive_ELBO
from pyro.infer.traceenum_elbo import JitTraceEnum_ELBO, TraceEnum_ELBO
from pyro.infer.tracegraph_elbo import JitTraceGraph_ELBO, TraceGraph_ELBO
from pyro.infer.traceiwae_elbo import TraceIWAE_ELBO
from pyro.infer.tracetmc_elbo import TraceTMC_ELBO
from pyro.infer.util import enable_validation, is_validation_enabled

Expand Down Expand Up @@ -54,6 +55,7 @@
"TraceTMC_ELBO",
"TraceEnum_ELBO",
"TraceGraph_ELBO",
"TraceIWAE_ELBO",
"TraceMeanField_ELBO",
"TracePosterior",
"TracePredictive",
Expand Down
126 changes: 126 additions & 0 deletions pyro/infer/traceiwae_elbo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import math

import pyro.poutine as poutine
from pyro.distributions.util import detach
from pyro.infer.elbo import ELBO
from pyro.infer.util import is_validation_enabled, torch_item
from pyro.poutine.replay_messenger import ReplayMessenger
from pyro.poutine.util import prune_subsample_sites
from pyro.util import (
check_if_enumerated,
check_model_guide_match,
check_site_shape,
warn_if_nan,
)


class DetachReplayMessenger(ReplayMessenger):
    """
    Like :class:`~pyro.poutine.replay_messenger.ReplayMessenger`, but detaches
    every replayed sample value from the autograd graph, so that gradients do
    not flow from the model's log-density back through the guide's samples.
    """

    def _pyro_sample(self, msg):
        super()._pyro_sample(msg)
        if msg["name"] not in self.trace:
            # Not a replayed site; leave the freshly drawn value attached.
            return
        msg["value"] = msg["value"].detach()


class TraceIWAE_ELBO(ELBO):
    """
    A trace implementation of importance weighted autoencoder (IWAE) SVI
    using the doubly reparameterized gradient (DReG) estimator [1].

    This requires ``vectorize_particles=True`` and a fully reparameterized
    guide (all guide sites must support ``rsample``).

    **References**

    [1] G. Tucker, D. Lawson, S. Gu, C.J. Maddison (2018)
        Doubly Reparameterized Gradient Estimators for Monte Carlo Objectives
        https://arxiv.org/abs/1810.04152
    """

    def _get_trace(self, *args, **kwargs):
        # This estimator needs a *pair* of traces with selectively detached
        # values, so the single-trace ELBO entry point is unsupported.
        raise ValueError("Use _get_importance_trace() instead")

    def _get_importance_trace(self, model, guide, args, kwargs):
        """
        Run the guide, then replay the model against *detached* guide samples,
        and return the pruned ``(model_trace, guide_trace)`` pair.

        Model sample values are detached (via :class:`DetachReplayMessenger`)
        while the guide trace retains attached values; the estimator below
        relies on having both versions of each latent.
        """
        assert self.vectorize_particles
        if self.max_plate_nesting == math.inf:
            self._guess_max_plate_nesting(model, guide, args, kwargs)
        # Broadcast over an outermost particle plate of size num_particles.
        model = self._vectorized_num_particles(model)
        guide = self._vectorized_num_particles(guide)

        guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
        with poutine.trace() as tr:
            with DetachReplayMessenger(trace=guide_trace):
                model(*args, **kwargs)
        model_trace = tr.trace
        if is_validation_enabled():
            check_model_guide_match(model_trace, guide_trace, self.max_plate_nesting)

        guide_trace = prune_subsample_sites(guide_trace)
        model_trace = prune_subsample_sites(model_trace)

        # Note we can avoid computing log_prob and score_parts unless validating.
        if is_validation_enabled():
            model_trace.compute_log_prob()
            guide_trace.compute_log_prob()
            for site in model_trace.nodes.values():
                if site["type"] == "sample":
                    check_site_shape(site, self.max_plate_nesting)
            for site in guide_trace.nodes.values():
                if site["type"] == "sample":
                    check_site_shape(site, self.max_plate_nesting)
            check_if_enumerated(guide_trace)

        return model_trace, guide_trace

    def loss(self, model, guide, *args, **kwargs):
        """
        :returns: an estimate of the loss (the negative IWAE ELBO)
        :rtype: float
        """
        loss = self.differentiable_loss(model, guide, *args, **kwargs)
        return torch_item(loss)

    def loss_and_grads(self, model, guide, *args, **kwargs):
        """
        Computes the loss, accumulates gradients into parameters via
        ``.backward()``, and returns the loss as a float.
        """
        loss = self.differentiable_loss(model, guide, *args, **kwargs)
        loss.backward()
        return torch_item(loss)

    def differentiable_loss(self, model, guide, *args, **kwargs):
        """
        Computes the surrogate negative IWAE ELBO as a differentiable tensor,
        using the DReG estimator so that gradients w.r.t. guide parameters
        flow only through the reparameterized samples.
        """
        model_trace, guide_trace = self._get_importance_trace(model, guide, args, kwargs)

        # The following computation follows Sec. 8.3 Eqn. (12) of [1].
        # Three versions of the per-particle log importance weight are
        # accumulated, differing only in where gradients are stopped:
        log_w_bar = 0.    # all gradients stopped
        log_w_hat = 0.    # gradients stopped wrt distribution parameters
        log_p_tilde = 0.  # gradients stopped wrt reparameterized z

        def particle_sum(x):
            "sum out everything but the particle plate dimension"
            assert x.size(0) == self.num_particles
            return x.reshape(self.num_particles, -1).sum(-1)

        for name, site in model_trace.nodes.items():
            if site["type"] == "sample":
                fn = site["fn"]
                # Model values were detached during replay; for latent sites
                # the attached value lives in the guide trace.
                z_detach = site["value"]
                if name in guide_trace:
                    z = guide_trace.nodes[name]["value"]
                else:
                    z = z_detach

                log_p = particle_sum(detach(fn).log_prob(z))
                log_w_bar = log_w_bar + log_p.detach()
                log_w_hat = log_w_hat + log_p
                log_p_tilde = log_p_tilde + particle_sum(fn.log_prob(z_detach))

        for name, site in guide_trace.nodes.items():
            if site["type"] == "sample":
                fn = site["fn"]
                z = site["value"]

                log_q = particle_sum(detach(fn).log_prob(z))
                log_w_bar = log_w_bar - log_q.detach()
                log_w_hat = log_w_hat - log_q

        # Self-normalized importance weights (gradients fully stopped).
        log_W_bar = log_w_bar.logsumexp(0)
        weight_bar = (log_w_bar - log_W_bar).exp()
        surrogate_elbo = weight_bar.dot(log_p_tilde) + weight_bar.pow(2).dot(log_w_hat)
        elbo = log_W_bar - math.log(self.num_particles)
        # Value equals -elbo; gradient equals -grad(surrogate_elbo).
        loss = -elbo + surrogate_elbo.detach() - surrogate_elbo

        warn_if_nan(loss, "loss")
        return loss
46 changes: 45 additions & 1 deletion tests/infer/test_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,17 @@
Trace_ELBO,
TraceEnum_ELBO,
TraceGraph_ELBO,
TraceIWAE_ELBO,
TraceMeanField_ELBO,
config_enumerate,
)
from pyro.optim import Adam
from tests.common import assert_equal, xfail_if_not_implemented, xfail_param
from tests.common import (
assert_close,
assert_equal,
xfail_if_not_implemented,
xfail_param,
)

logger = logging.getLogger(__name__)

Expand All @@ -47,13 +53,16 @@ def DiffTrace_ELBO(*args, **kwargs):
(TraceMeanField_ELBO, False),
(TraceEnum_ELBO, False),
(TraceEnum_ELBO, True),
(TraceIWAE_ELBO, False),
])
def test_subsample_gradient(Elbo, reparameterized, has_rsample, subsample, local_samples, scale):
pyro.clear_param_store()
data = torch.tensor([-0.5, 2.0])
subsample_size = 1 if subsample else len(data)
precision = 0.06 * scale
Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal
if Elbo is TraceIWAE_ELBO and not (reparameterized and has_rsample):
pytest.skip("not implemented")

def model(subsample):
with pyro.plate("data", len(data), subsample_size, subsample) as ind:
Expand Down Expand Up @@ -102,6 +111,41 @@ def guide(subsample):
assert_equal(actual_grads, expected_grads, prec=precision)


@pytest.mark.parametrize("Elbo", [
    Trace_ELBO,
    TraceEnum_ELBO,
    TraceIWAE_ELBO,
])
def test_zero_gradient(Elbo):
    """
    When the guide equals the exact posterior, the ELBO gradient w.r.t. the
    guide parameters should vanish; check each estimator produces near-zero
    gradients at that optimum (up to Monte Carlo noise over 10000 particles).
    """
    data = torch.tensor([0.0, 2.0, 2.0])

    def model(data):
        # Conjugate normal-normal model: z ~ N(0, 1), x_i ~ N(z, 1).
        z = pyro.sample("z", dist.Normal(0, 1))
        with pyro.plate("data", len(data)):
            pyro.sample("x", dist.Normal(z, 1), obs=data)

    # This guide should be the true posterior.
    # Analytically: precision = 1 + len(data) = 4, so scale = 0.5 and
    # loc = sum(data) / 4 = 1.0; the params are initialized at the optimum.
    def guide(data):
        loc = pyro.param("loc", lambda: torch.tensor(1.0))
        scale = pyro.param("scale", lambda: torch.tensor(0.5))
        pyro.sample("z", dist.Normal(loc, scale))

    elbo = Elbo(max_plate_nesting=1,
                num_particles=10000,
                vectorize_particles=True,
                strict_enumeration_warning=False)

    # All estimators should agree on the loss value itself.
    loss = elbo.differentiable_loss(model, guide, data)
    logger.info('loss = {}'.format(loss.item()))
    assert_close(loss.item(), 5.44996, atol=0.001)

    # At the optimum the gradient w.r.t. loc and scale should be ~zero.
    params = dict(pyro.get_param_store().named_parameters())
    grads = torch.autograd.grad(loss, params.values())
    for name, grad in zip(params, grads):
        logger.info('grad {} = {}'.format(name, grad))
        assert_close(grad, torch.zeros_like(grad), atol=0.01)


@pytest.mark.parametrize("reparameterized", [True, False], ids=["reparam", "nonreparam"])
@pytest.mark.parametrize("Elbo", [Trace_ELBO, DiffTrace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO])
def test_plate(Elbo, reparameterized):
Expand Down