diff --git a/edward/__init__.py b/edward/__init__.py index dbf102fdf..ad1517c33 100644 --- a/edward/__init__.py +++ b/edward/__init__.py @@ -14,16 +14,12 @@ bigan_inference, complete_conditional, gan_inference, - implicit_klqp, klpq, klqp, - reparameterization_klqp, - reparameterization_kl_klqp, - reparameterization_entropy_klqp, - score_klqp, - score_kl_klqp, - score_entropy_klqp, - score_rb_klqp, + klqp_implicit, + klqp_reparameterization, + klqp_reparameterization_kl, + klqp_score, laplace, map, wake_sleep, @@ -31,7 +27,7 @@ ) # from edward.inferences import MonteCarlo, HMC, MetropolisHastings, SGLD, SGHMC, Gibbs from edward.models import RandomVariable, Trace -from edward.util import copy, dot, \ +from edward.util import dot, \ get_ancestors, get_blanket, get_children, get_control_variate_coef, \ get_descendants, get_parents, get_siblings, get_variables, \ is_independent, Progbar, random_variables, rbf, \ @@ -53,7 +49,6 @@ 'bigan_inference', 'complete_conditional', 'gan_inference', - 'implicit_klqp', 'MonteCarlo', 'HMC', 'MetropolisHastings', @@ -61,13 +56,10 @@ 'SGHMC', 'klpq', 'klqp', - 'reparameterization_klqp', - 'reparameterization_kl_klqp', - 'reparameterization_entropy_klqp', - 'score_klqp', - 'score_kl_klqp', - 'score_entropy_klqp', - 'score_rb_klqp', + 'klqp_implicit', + 'klqp_reparameterization', + 'klqp_reparameterization_kl', + 'klqp_score', 'laplace', 'map', 'wake_sleep', @@ -75,7 +67,6 @@ 'Gibbs', 'RandomVariable', 'Trace', - 'copy', 'dot', 'get_ancestors', 'get_blanket', diff --git a/edward/inferences/__init__.py b/edward/inferences/__init__.py index 10c664cb3..709e6214f 100644 --- a/edward/inferences/__init__.py +++ b/edward/inferences/__init__.py @@ -9,10 +9,10 @@ from edward.inferences.gan_inference import * # from edward.inferences.gibbs import * # from edward.inferences.hmc import * -from edward.inferences.implicit_klqp import * from edward.inferences.inference import * from edward.inferences.klpq import * from edward.inferences.klqp import 
* +from edward.inferences.klqp_implicit import * from edward.inferences.laplace import * from edward.inferences.map import * # from edward.inferences.metropolis_hastings import * @@ -28,18 +28,14 @@ 'bigan_inference', 'complete_conditional', 'gan_inference', - 'implicit_klqp', 'Gibbs', 'HMC', 'klpq', 'klqp', - 'reparameterization_klqp', - 'reparameterization_kl_klqp', - 'reparameterization_entropy_klqp', - 'score_klqp', - 'score_kl_klqp', - 'score_entropy_klqp', - 'score_rb_klqp', + 'klqp_implicit', + 'klqp_reparameterization', + 'klqp_reparameterization_kl', + 'klqp_score', 'laplace', 'map', 'MetropolisHastings', diff --git a/edward/inferences/bigan_inference.py b/edward/inferences/bigan_inference.py index 7a440a547..7aedf6d9b 100644 --- a/edward/inferences/bigan_inference.py +++ b/edward/inferences/bigan_inference.py @@ -5,13 +5,12 @@ import six import tensorflow as tf -from edward.inferences.inference import (check_and_maybe_build_data, - check_and_maybe_build_latent_vars, transform, check_and_maybe_build_dict, check_and_maybe_build_var_list) +from edward.models import Trace +from edward.inferences.inference import call_function_up_to_args -def bigan_inference(latent_vars=None, data=None, discriminator=None, - auto_transform=True, scale=None, var_list=None, - collections=None): +def bigan_inference(model, variational, discriminator, align_data, + align_latent, collections=None, *args, **kwargs): """Adversarially Learned Inference [@dumuolin2017adversarially] or Bidirectional Generative Adversarial Networks [@donahue2017adversarial] for joint learning of generator and inference networks. @@ -44,20 +43,23 @@ def bigan_inference(latent_vars=None, data=None, discriminator=None, zf = gen_latent(x_ph) inference = ed.BiGANInference({z_ph: zf}, {xf: x_ph}, discriminator) ``` + + `align_latent` must only align one random variable in `model` and + `variational`. `model` must return the generated data. 
""" - if not callable(discriminator): - raise TypeError("discriminator must be a callable function.") - latent_vars = check_and_maybe_build_latent_vars(latent_vars) - data = check_and_maybe_build_data(data) - latent_vars, _ = transform(latent_vars, auto_transform) - scale = check_and_maybe_build_dict(scale) - var_list = check_and_maybe_build_var_list(var_list, latent_vars, data) + with Trace() as posterior_trace: + call_function_up_to_args(variational, *args, **kwargs) + with Trace() as model_trace: + x_fake = call_function_up_to_args(model, *args, **kwargs) - x_true = list(six.itervalues(self.data))[0] - x_fake = list(six.iterkeys(self.data))[0] + x_true = align_data(x_fake.name) - z_true = list(six.iterkeys(self.latent_vars))[0] - z_fake = list(six.itervalues(self.latent_vars))[0] + for name, node in six.iteritems(model_trace): + aligned = align_latent(name) + if aligned != name: + z_true = node.value + z_fake = posterior_trace[aligned].value + break with tf.variable_scope("Disc"): # xtzf := x_true, z_fake @@ -80,14 +82,4 @@ def bigan_inference(latent_vars=None, data=None, discriminator=None, loss_d = tf.reduce_mean(loss_d) + tf.reduce_sum(reg_terms_d) loss = tf.reduce_mean(loss) + tf.reduce_sum(reg_terms) - - var_list_d = tf.get_collection( - tf.GraphKeys.TRAINABLE_VARIABLES, scope="Disc") - var_list = tf.get_collection( - tf.GraphKeys.TRAINABLE_VARIABLES, scope="Gen") - - grads_d = tf.gradients(loss_d, var_list_d) - grads = tf.gradients(loss, var_list) - grads_and_vars_d = list(zip(grads_d, var_list_d)) - grads_and_vars = list(zip(grads, var_list)) - return loss, grads_and_vars, loss_d, grads_and_vars_d + return loss, loss_d diff --git a/edward/inferences/conjugacy/conjugacy.py b/edward/inferences/conjugacy/conjugacy.py index a7987b0e5..7623ea7b9 100644 --- a/edward/inferences/conjugacy/conjugacy.py +++ b/edward/inferences/conjugacy/conjugacy.py @@ -12,7 +12,7 @@ from edward.inferences.conjugacy.simplify \ import symbolic_suff_stat, full_simplify, 
expr_contains, reconstruct_expr from edward.models.random_variables import * -from edward.util import copy, get_blanket +from edward.util import get_blanket def mvn_diag_from_natural_params(p1, p2): diff --git a/edward/inferences/gan_inference.py b/edward/inferences/gan_inference.py index a1b979733..6346aeabd 100644 --- a/edward/inferences/gan_inference.py +++ b/edward/inferences/gan_inference.py @@ -5,12 +5,12 @@ import six import tensorflow as tf -from edward.inferences.inference import (check_and_maybe_build_data, - transform, check_and_maybe_build_dict, check_and_maybe_build_var_list) +from edward.models import Trace +from edward.inferences.inference import call_function_up_to_args -def gan_inference(data=None, discriminator=None, - scale=None, var_list=None, collections=None): +def gan_inference(model, discriminator, align_data, + collections=None, *args, **kwargs): """Parameter estimation with GAN-style training [@goodfellow2014generative]. @@ -55,18 +55,11 @@ def gan_inference(data=None, discriminator=None, Function (with parameters) to discriminate samples. It should output logit probabilities (real-valued) and not probabilities in $[0, 1]$. - var_list: list of tf.Variable, optional. - List of TensorFlow variables to optimize over (in the generative - model). Default is all trainable variables that `data` depends on. + + `model` must return the generated data. 
""" - if not callable(discriminator): - raise TypeError("discriminator must be a callable function.") - data = check_and_maybe_build_data(data) - scale = check_and_maybe_build_dict(scale) - var_list = check_and_maybe_build_var_list(var_list, {}, data) - - x_true = list(six.itervalues(data))[0] - x_fake = list(six.iterkeys(data))[0] + x_fake = call_function_up_to_args(model, *args, **kwargs) + x_true = align_data(x_fake.name) with tf.variable_scope("Disc"): d_true = discriminator(x_true) @@ -90,14 +83,4 @@ def gan_inference(data=None, discriminator=None, labels=tf.ones_like(d_fake), logits=d_fake) loss_d = tf.reduce_mean(loss_d) + tf.reduce_sum(reg_terms_d) loss = tf.reduce_mean(loss) + tf.reduce_sum(reg_terms) - - var_list_d = tf.get_collection( - tf.GraphKeys.TRAINABLE_VARIABLES, scope="Disc") - if var_list is None: - var_list = [v for v in tf.trainable_variables() if v not in var_list_d] - - grads_d = tf.gradients(loss_d, var_list_d) - grads = tf.gradients(loss, var_list) - grads_and_vars_d = list(zip(grads_d, var_list_d)) - grads_and_vars = list(zip(grads, var_list)) - return loss, grads_and_vars, loss_d, grads_and_vars_d + return loss, loss_d diff --git a/edward/inferences/hmc.py b/edward/inferences/hmc.py index accf59669..61122d3bb 100644 --- a/edward/inferences/hmc.py +++ b/edward/inferences/hmc.py @@ -8,7 +8,6 @@ from collections import OrderedDict from edward.inferences.monte_carlo import MonteCarlo from edward.models import RandomVariable -from edward.util import copy try: from edward.models import Normal, Uniform diff --git a/edward/inferences/inference.py b/edward/inferences/inference.py index 24f8238c4..3023c39ef 100644 --- a/edward/inferences/inference.py +++ b/edward/inferences/inference.py @@ -30,16 +30,18 @@ This file is a collection of functions shared across inference algorithms, used for the following: -+ input checking and default constructors -+ programmatic docstrings ++ (TODO move elsewhere?) 
call f up to args ++ a "make intercept" factory + automated transforms -+ summaries -+ variable scoping ++ programmatic docstrings + train() -+ for a subset of algs, optimizer and Monte Carlo stuff (TBA). Other files provide functions to help produce the train (and post-training) ops. + +We do no input checking and assume the idiom of duck typing. (Although +sometimes because TensorFlow is statically typed, we check input +types. But we typically try to defer to Python.) """ from __future__ import absolute_import from __future__ import division @@ -52,244 +54,72 @@ from datetime import datetime from edward.models import RandomVariable -from edward.util import get_session, get_variables, Progbar +from edward.util import get_variables, Progbar from edward.util import transform as _transform -from tensorflow.contrib.distributions import bijectors - +tfb = tf.contrib.distributions.bijectors -def check_and_maybe_build_data(data): - """Check that the data dictionary passed during inference and - criticism is valid. - Args: - data: dict. - Data dictionary which binds observed variables (of type - `RandomVariable` or `tf.Tensor`) to their realizations (of - type `tf.Tensor`). It can also bind placeholders (of type - `tf.Tensor`) used in the model to their realizations; and - prior latent variables (of type `RandomVariable`) to posterior - latent variables (of type `RandomVariable`). - """ - sess = get_session() - if data is None: - data = {} - elif not isinstance(data, dict): - raise TypeError("data must have type dict.") - - for key, value in six.iteritems(data): - if isinstance(key, tf.Tensor) and "Placeholder" in key.op.type: - if isinstance(value, RandomVariable): - raise TypeError("The value of a feed cannot be a ed.RandomVariable " - "object. " - "Acceptable feed values include Python scalars, " - "strings, lists, numpy ndarrays, or TensorHandles.") - elif isinstance(value, tf.Tensor): - raise TypeError("The value of a feed cannot be a tf.Tensor object. 
" - "Acceptable feed values include Python scalars, " - "strings, lists, numpy ndarrays, or TensorHandles.") - elif isinstance(key, (RandomVariable, tf.Tensor)): - if isinstance(value, (RandomVariable, tf.Tensor)): - if not key.shape.is_compatible_with(value.shape): - raise TypeError("Key-value pair in data does not have same " - "shape: {}, {}".format(key.shape, value.shape)) - elif key.dtype != value.dtype: - raise TypeError("Key-value pair in data does not have same " - "dtype: {}, {}".format(key.dtype, value.dtype)) - elif isinstance(value, (float, list, int, np.ndarray, np.number, str)): - if not key.shape.is_compatible_with(np.shape(value)): - raise TypeError("Key-value pair in data does not have same " - "shape: {}, {}".format(key.shape, np.shape(value))) - elif isinstance(value, (np.ndarray, np.number)) and \ - not np.issubdtype(value.dtype, np.float) and \ - not np.issubdtype(value.dtype, np.int) and \ - not np.issubdtype(value.dtype, np.str): - raise TypeError("Data value has an invalid dtype: " - "{}".format(value.dtype)) - else: - raise TypeError("Data value has an invalid type: " - "{}".format(type(value))) +def call_function_up_to_args(f, *args, **kwargs): + import inspect + if hasattr(f, "_func"): # make_template() + argspec = inspect.getargspec(f._func) + else: + argspec = inspect.getargspec(f) + num_kwargs = len(argspec.defaults) if argspec.defaults is not None else 0 + num_args = len(argspec.args) - num_kwargs + if num_args > 0: + return f(args[:num_args], **kwargs) + elif num_kwargs > 0: + return f(**kwargs) + return f() + + +def make_intercept(trace, align_data, align_latent, args, kwargs): + def _intercept(f, *fargs, **fkwargs): + """Set model's sample values to variational distribution's and data.""" + name = fkwargs.get('name', None) + key = align_data(name) + if isinstance(key, int): + fkwargs['value'] = args[key] + elif kwargs.get(key, None) is not None: + fkwargs['value'] = kwargs.get(key) else: - raise TypeError("Data key has an invalid 
type: {}".format(type(key))) - - processed_data = {} - for key, value in six.iteritems(data): - if isinstance(key, tf.Tensor) and "Placeholder" in key.op.type: - processed_data[key] = value - elif isinstance(key, (RandomVariable, tf.Tensor)): - if isinstance(value, (RandomVariable, tf.Tensor)): - processed_data[key] = value - elif isinstance(value, (float, list, int, np.ndarray, np.number, str)): - # If value is a Python type, store it in the graph. - # Assign its placeholder with the key's data type. - with tf.variable_scope(None, default_name="data"): - ph = tf.placeholder(key.dtype, np.shape(value)) - var = tf.Variable(ph, trainable=False, collections=[]) - sess.run(var.initializer, {ph: value}) - processed_data[key] = var - return processed_data - - -def check_and_maybe_build_latent_vars(latent_vars): - """Check that the latent variable dictionary passed during inference and - criticism is valid. - - Args: - latent_vars: dict. - Collection of latent variables (of type `RandomVariable` or - `tf.Tensor`) to perform inference on. Each random variable is - binded to another random variable; the latter will infer the - former conditional on data. 
- """ - if latent_vars is None: - latent_vars = {} - elif not isinstance(latent_vars, dict): - raise TypeError("latent_vars must have type dict.") - - for key, value in six.iteritems(latent_vars): - if not isinstance(key, (RandomVariable, tf.Tensor)): - raise TypeError("Latent variable key has an invalid type: " - "{}".format(type(key))) - elif not isinstance(value, (RandomVariable, tf.Tensor)): - raise TypeError("Latent variable value has an invalid type: " - "{}".format(type(value))) - elif not key.shape.is_compatible_with(value.shape): - raise TypeError("Key-value pair in latent_vars does not have same " - "shape: {}, {}".format(key.shape, value.shape)) - elif key.dtype != value.dtype: - raise TypeError("Key-value pair in latent_vars does not have same " - "dtype: {}, {}".format(key.dtype, value.dtype)) - return latent_vars - - -def check_and_maybe_build_dict(x): - if x is None: - x = {} - elif not isinstance(x, dict): - raise TypeError("x must be dict; got {}".format(type(x).__name__)) - return x - - -def check_and_maybe_build_var_list(var_list, latent_vars, data): + qz = trace[align_latent(name)].value + fkwargs['value'] = qz.value + # if auto_transform and 'qz' in locals(): + # # TODO for generation to work, must output original dist. to + # keep around TD? must maintain another stack to write to as a + # side-effect (or augment the original stack). + # return transform(f, qz, *fargs, **fkwargs) + return f(*fargs, **fkwargs) + return _intercept + + +def transform(f, qz, *args, **kwargs): + """Transform prior -> unconstrained -> q's constraint. + + When using in VI, we keep variational distribution on its original + space (for sake of implementing only one intercepting function). """ - Returns: - List of TensorFlow variables to optimize over. Default is all - trainable variables that `latent_vars` and `data` depend on, - excluding those that are only used in conditionals in `data`. - """ - # Traverse random variable graphs to get default list of variables. 
- if var_list is None: - var_list = set() - trainables = tf.trainable_variables() - for z, qz in six.iteritems(latent_vars): - var_list.update(get_variables(z, collection=trainables)) - var_list.update(get_variables(qz, collection=trainables)) - - for x, qx in six.iteritems(data): - if isinstance(x, RandomVariable) and \ - not isinstance(qx, RandomVariable): - var_list.update(get_variables(x, collection=trainables)) - - var_list = list(var_list) - return var_list - - -def transform(latent_vars, auto_transform=True): - """ - Args: - auto_transform: bool, optional. - Whether to automatically transform continuous latent variables - of unequal support to be on the unconstrained space. It is - only applied if the argument is `True`, the latent variable - pair are `ed.RandomVariable`s with the `support` attribute, - the supports are both continuous and unequal. - """ - # map from original latent vars to unconstrained versions - if auto_transform: - latent_vars_temp = latent_vars.copy() - # latent_vars maps original latent vars to constrained Q's. - # latent_vars_unconstrained maps unconstrained vars to unconstrained Q's. 
- latent_vars = {} - latent_vars_unconstrained = {} - for z, qz in six.iteritems(latent_vars_temp): - if hasattr(z, 'support') and hasattr(qz, 'support') and \ - z.support != qz.support and qz.support != 'point': - - # transform z to an unconstrained space - z_unconstrained = _transform(z) - - # make sure we also have a qz that covers the unconstrained space - if qz.support == "points": - qz_unconstrained = qz - else: - qz_unconstrained = _transform(qz) - latent_vars_unconstrained[z_unconstrained] = qz_unconstrained - - # additionally construct the transformation of qz - # back into the original constrained space - if z_unconstrained != z: - qz_constrained = _transform( - qz_unconstrained, bijectors.Invert(z_unconstrained.bijector)) - - try: # attempt to pushforward the params of Empirical distributions - qz_constrained.params = z_unconstrained.bijector.inverse( - qz_unconstrained.params) - except: # qz_unconstrained is not an Empirical distribution - pass - - else: - qz_constrained = qz_unconstrained - - latent_vars[z] = qz_constrained - else: - latent_vars[z] = qz - latent_vars_unconstrained[z] = qz + # TODO deal with f or qz being 'point' or 'points' + if (not hasattr(f, 'support') or not hasattr(qz, 'support') or + f.support == qz.support): + return f(*args, **kwargs) + value = kwargs.pop('value') + kwargs['value'] = 0.0 # to avoid sampling; TODO follow sample shape + rv = f(*args, **kwargs) + # Take shortcuts in logic if p or q are already unconstrained. 
+ if qz.support in ('real', 'multivariate_real'): + return _transform(rv, value=value) + if rv.support in ('real', 'multivariate_real'): + rv_unconstrained = rv else: - latent_vars_unconstrained = None - return latent_vars, latent_vars_unconstrained - - -def summary_variables(latent_vars=None, data=None, variables=None, - *args, **kwargs): - # Note: to use summary_key, set - # collections=[tf.get_default_graph().unique_name("summaries")] - # TODO include in TensorBoard tutorial - """Log variables to TensorBoard. - - For each variable in `variables`, forms a `tf.summary.scalar` if - the variable has scalar shape; otherwise forms a `tf.summary.histogram`. - - Args: - variables: list, optional. - Specifies the list of variables to log after each `n_print` - steps. If None, will log all variables. If `[]`, no variables - will be logged. - """ - if variables is None: - variables = [] - for key in six.iterkeys(data): - variables += get_variables(key) - - for key, value in six.iteritems(latent_vars): - variables += get_variables(key) - variables += get_variables(value) - - variables = set(variables) - - for var in variables: - # replace colons which are an invalid character - var_name = var.name.replace(':', '/') - # Log all scalars. - if len(var.shape) == 0: - tf.summary.scalar("parameter/{}".format(var_name), - var, *args, **kwargs) - elif len(var.shape) == 1 and var.shape[0] == 1: - tf.summary.scalar("parameter/{}".format(var_name), - var[0], *args, **kwargs) - else: - # If var is multi-dimensional, log a histogram of its values. 
- tf.summary.histogram("parameter/{}".format(var_name), - var, *args, **kwargs) + rv_unconstrained = _transform(rv, value=0.0) + unconstrained_to_constrained = tfb.Invert(_transform(qz).bijector) + return _transform(rv_unconstrained, + unconstrained_to_constrained, + value=value) def train(train_op, summary_key=None, n_iter=1000, n_print=None, @@ -351,7 +181,7 @@ def train(train_op, summary_key=None, n_iter=1000, n_print=None, if summary_key is not None: # TODO should run() also add summaries; or should user call - # summary_variables() manually? + # _summary_variables() manually? summarize = tf.summary.merge_all(key=summary_key) if log_timestamp: logdir = os.path.expanduser(logdir) @@ -382,7 +212,7 @@ def train(train_op, summary_key=None, n_iter=1000, n_print=None, threads = tf.train.start_queue_runners(coord=coord) for _ in range(n_iter): - info_dict = update(progbar, n_print, summarize, + info_dict = _update(progbar, n_print, summarize, train_writer, debug, op_check, train_op, *args, **kwargs) @@ -400,8 +230,52 @@ def train(train_op, summary_key=None, n_iter=1000, n_print=None, coord.request_stop() coord.join(threads) -def optimize(loss, grads_and_vars, collections=None, var_list=None, - optimizer=None, use_prettytensor=False, global_step=None): + +def _summary_variables(latent_vars=None, data=None, variables=None, + *args, **kwargs): + # Note: to use summary_key, set + # collections=[tf.get_default_graph().unique_name("summaries")] + # TODO include in TensorBoard tutorial + """Log variables to TensorBoard. + + For each variable in `variables`, forms a `tf.summary.scalar` if + the variable has scalar shape; otherwise forms a `tf.summary.histogram`. + + Args: + variables: list, optional. + Specifies the list of variables to log after each `n_print` + steps. If None, will log all variables. If `[]`, no variables + will be logged. 
+ """ + if variables is None: + variables = [] + for key in six.iterkeys(data): + variables += get_variables(key) + + for key, value in six.iteritems(latent_vars): + variables += get_variables(key) + variables += get_variables(value) + + variables = set(variables) + + for var in variables: + # replace colons which are an invalid character + var_name = var.name.replace(':', '/') + # Log all scalars. + if len(var.shape) == 0: + tf.summary.scalar("parameter/{}".format(var_name), + var, *args, **kwargs) + elif len(var.shape) == 1 and var.shape[0] == 1: + tf.summary.scalar("parameter/{}".format(var_name), + var[0], *args, **kwargs) + else: + # If var is multi-dimensional, log a histogram of its values. + tf.summary.histogram("parameter/{}".format(var_name), + var, *args, **kwargs) + + +def _optimize(loss, grads_and_vars, collections=None, var_list=None, + optimizer=None, use_prettytensor=False, global_step=None): """Build optimizer and its train op applied to loss or grads_and_vars. @@ -481,8 +355,8 @@ def optimize(loss, grads_and_vars, collections=None, var_list=None, return train_op -def update(progbar, n_print, summarize=None, train_writer=None, - debug=False, op_check=None, *args, **kwargs): +def _update(progbar, n_print, summarize=None, train_writer=None, + debug=False, op_check=None, *args, **kwargs): """Run one iteration of optimization. Args: @@ -529,8 +403,8 @@ def update(progbar, n_print, summarize=None, train_writer=None, # TODO within run(), use this for gan_inference, wgan_inference, # implicit_klqp, bigan_inference -def update(train_op, train_op_d, n_print, summarize=None, train_writer=None, - debug=False, op_check=None, variables=None, *args, **kwargs): +def _update(train_op, train_op_d, n_print, summarize=None, train_writer=None, + debug=False, op_check=None, variables=None, *args, **kwargs): """Run one iteration of optimization. 
Args: @@ -581,8 +455,8 @@ def update(train_op, train_op_d, n_print, summarize=None, train_writer=None, return dict(zip(kwargs_temp.keys(), values)) # TODO within run(), use this for wgan_inference -def update(clip_op, variables=None, *args, **kwargs): - info_dict = gan_inference.update(variables=variables, *args, **kwargs) +def _update(clip_op, variables=None, *args, **kwargs): + info_dict = gan_inference._update(variables=variables, *args, **kwargs) sess = get_session() if clip_op is not None and variables in (None, "Disc"): diff --git a/edward/inferences/klpq.py b/edward/inferences/klpq.py index 797612a74..b72681b3c 100644 --- a/edward/inferences/klpq.py +++ b/edward/inferences/klpq.py @@ -5,10 +5,9 @@ import six import tensorflow as tf -from edward.inferences.inference import (check_and_maybe_build_data, - check_and_maybe_build_latent_vars, transform, check_and_maybe_build_dict, check_and_maybe_build_var_list) -from edward.models import RandomVariable -from edward.util import copy, get_descendants +from edward.models import Trace +from edward.inferences.inference import (call_function_up_to_args, + make_intercept) try: from edward.models import Normal @@ -16,8 +15,9 @@ raise ImportError("{0}. 
Your TensorFlow version is not supported.".format(e)) -def klpq(latent_vars=None, data=None, n_samples=1, - auto_transform=True, scale=None, var_list=None, collections=None): +def klpq(model, variational, align_latent, align_data, + scale=lambda name: 1.0, n_samples=1, auto_transform=True, + collections=None, *args, **kwargs): """Variational inference with the KL divergence $\\text{KL}( p(z \mid x) \| q(z) ).$ @@ -91,65 +91,30 @@ def klpq(latent_vars=None, data=None, n_samples=1, $- \sum_{s=1}^S [ w_{\\text{norm}}(z^s; \lambda) \\nabla_{\lambda} \log q(z^s; \lambda) ].$ """ - if isinstance(latent_vars, list): - with tf.variable_scope(None, default_name="posterior"): - latent_vars_dict = {} - continuous = \ - ('01', 'nonnegative', 'simplex', 'real', 'multivariate_real') - for z in latent_vars: - if not hasattr(z, 'support') or z.support not in continuous: - raise AttributeError( - "Random variable {} is not continuous or a random " - "variable with supported continuous support.".format(z)) - batch_event_shape = z.batch_shape.concatenate(z.event_shape) - loc = tf.Variable(tf.random_normal(batch_event_shape)) - scale = tf.nn.softplus( - tf.Variable(tf.random_normal(batch_event_shape))) - latent_vars_dict[z] = Normal(loc=loc, scale=scale) - latent_vars = latent_vars_dict - del latent_vars_dict - latent_vars = check_and_maybe_build_latent_vars(latent_vars) - data = check_and_maybe_build_data(data) - latent_vars, _ = transform(latent_vars, auto_transform) - scale = check_and_maybe_build_dict(scale) - var_list = check_and_maybe_build_var_list(var_list, latent_vars, data) - p_log_prob = [0.0] * n_samples q_log_prob = [0.0] * n_samples - base_scope = tf.get_default_graph().unique_name("inference") + '/' for s in range(n_samples): - # Form dictionary in order to replace conditioning on prior or - # observed variable with conditioning on a specific value. 
- scope = base_scope + tf.get_default_graph().unique_name("sample") - dict_swap = {} - for x, qx in six.iteritems(data): - if isinstance(x, RandomVariable): - if isinstance(qx, RandomVariable): - qx_copy = copy(qx, scope=scope) - dict_swap[x] = qx_copy.value - else: - dict_swap[x] = qx - - for z, qz in six.iteritems(latent_vars): - # Copy q(z) to obtain new set of posterior samples. - qz_copy = copy(qz, scope=scope) - dict_swap[z] = qz_copy.value - q_log_prob[s] += tf.reduce_sum( - qz_copy.log_prob(tf.stop_gradient(dict_swap[z]))) - - for z in six.iterkeys(latent_vars): - z_copy = copy(z, dict_swap, scope=scope) - p_log_prob[s] += tf.reduce_sum(z_copy.log_prob(dict_swap[z])) - - for x in six.iterkeys(data): - if isinstance(x, RandomVariable): - x_copy = copy(x, dict_swap, scope=scope) - p_log_prob[s] += tf.reduce_sum(x_copy.log_prob(dict_swap[x])) + with Trace() as posterior_trace: + call_function_up_to_args(variational, *args, **kwargs) + intercept = make_intercept( + posterior_trace, align_data, align_latent, args, kwargs) + with Trace(intercept=intercept) as model_trace: + call_function_up_to_args(model, *args, **kwargs) + + for name, node in six.iteritems(model_trace): + rv = node.value + scale_factor = scale(name) + p_log_prob[s] += tf.reduce_sum( + scale_factor * rv.log_prob(tf.stop_gradient(rv.value))) + posterior_node = posterior_trace.get(align_latent(name), None) + if posterior_node is not None: + qz = posterior_node.value + q_log_prob[s] += tf.reduce_sum( + scale_factor * qz.log_prob(tf.stop_gradient(qz.value))) p_log_prob = tf.stack(p_log_prob) q_log_prob = tf.stack(q_log_prob) reg_penalty = tf.reduce_sum(tf.losses.get_regularization_losses()) - if collections is not None: tf.summary.scalar("loss/p_log_prob", tf.reduce_mean(p_log_prob), collections=collections) @@ -158,18 +123,11 @@ def klpq(latent_vars=None, data=None, n_samples=1, tf.summary.scalar("loss/reg_penalty", reg_penalty, collections=collections) - log_w = p_log_prob - q_log_prob + log_w = 
p_log_prob - tf.stop_gradient(q_log_prob) log_w_norm = log_w - tf.reduce_logsumexp(log_w) w_norm = tf.exp(log_w_norm) - loss = tf.reduce_sum(w_norm * log_w) - reg_penalty - - q_rvs = list(six.itervalues(latent_vars)) - q_vars = [v for v in var_list - if len(get_descendants(tf.convert_to_tensor(v), q_rvs)) != 0] - q_grads = tf.gradients( - -(tf.reduce_sum(q_log_prob * tf.stop_gradient(w_norm)) - reg_penalty), - q_vars) - p_vars = [v for v in var_list if v not in q_vars] - p_grads = tf.gradients(-loss, p_vars) - grads_and_vars = list(zip(q_grads, q_vars)) + list(zip(p_grads, p_vars)) - return loss, grads_and_vars + loss = -tf.reduce_sum(w_norm * log_w) + reg_penalty + # Model parameter gradients will backprop into loss. Variational + # parameter gradients will backprop into reg_penalty and last term. + surrogate_loss = loss + tf.reduce_sum(q_log_prob * tf.stop_gradient(w_norm)) + return loss, surrogate_loss diff --git a/edward/inferences/klqp.py b/edward/inferences/klqp.py index 8bdf3a674..b191a7c9f 100644 --- a/edward/inferences/klqp.py +++ b/edward/inferences/klqp.py @@ -5,10 +5,9 @@ import six import tensorflow as tf -from edward.inferences.inference import (check_and_maybe_build_data, - check_and_maybe_build_latent_vars, transform, check_and_maybe_build_dict, check_and_maybe_build_var_list) -from edward.models import RandomVariable -from edward.util import copy, get_descendants +from edward.models import Trace +from edward.inferences.inference import (call_function_up_to_args, + make_intercept) try: from edward.models import Normal @@ -16,9 +15,12 @@ except Exception as e: raise ImportError("{0}. 
Your TensorFlow version is not supported.".format(e)) +tfd = tf.contrib.distributions - -def klqp(latent_vars=None, data=None, n_samples=1, kl_scaling=None, - auto_transform=True, scale=None, var_list=None, summary_key=None): + +def klqp(model, variational, align_latent, align_data, + scale=lambda name: 1.0, n_samples=1, kl_scaling=lambda name: 1.0, + auto_transform=True, collections=None, *args, **kwargs): """Variational inference with the KL divergence $\\text{KL}( q(z; \lambda) \| p(z \mid x) ).$ @@ -27,20 +29,38 @@ def klqp(latent_vars=None, data=None, n_samples=1, kl_scaling=None, variety of black box inference techniques. Args: - latent_vars: list of RandomVariable or - dict of RandomVariable to RandomVariable. - Collection of random variables to perform inference on. If - list, each random variable will be implictly optimized using a - `Normal` random variable that is defined internally with a + model: function whose inputs are a subset of `args` (e.g., for + discriminative). Output is not used. + TODO auto_transform docstring + Collection of random variables to perform inference on. + If list, each random variable will be implicitly optimized using + a `Normal` random variable that is defined internally with a free parameter per location and scale and is initialized using - standard normal draws. The random variables to approximate - must be continuous. + standard normal draws. The random variables to approximate must + be continuous. + variational: function whose inputs are a subset of `args` (e.g., + for amortized). Output is not used. + align_latent: function of string, aligning `model` latent + variables with `variational`. It takes a model variable's name + as input and returns a string, indexing `variational`'s trace; + else identity. + align_data: function of string, aligning `model` observed + variables with data. It takes a model variable's name as input + and returns an integer, indexing `args`; else identity. 
+ scale: function of string, aligning `model` observed + variables with scale factors. It takes a model variable's name + as input and returns a scale factor; else 1.0. The scale + factor's shape must be broadcastable; it is multiplied + element-wise to the random variable. For example, this is useful + for mini-batch scaling when inferring global variables, or + applying masks on a random variable. n_samples: int, optional. Number of samples from variational model for calculating stochastic gradients. - kl_scaling: dict of RandomVariable to tf.Tensor, optional. - Provides option to scale terms when using ELBO with KL divergence. - If the KL divergence terms are + kl_scaling: function of string, aligning `model` latent + variables with KL scale factors. This provides option to scale + terms when using ELBO with KL divergence. If the KL divergence + terms are $\\alpha_p \mathbb{E}_{q(z\mid x, \lambda)} [ \log q(z\mid x, \lambda) - \log p(z)],$ @@ -48,12 +68,8 @@ def klqp(latent_vars=None, data=None, n_samples=1, kl_scaling=None, then pass {$p(z)$: $\\alpha_p$} as `kl_scaling`, where $\\alpha_p$ is a tensor. Its shape must be broadcastable; it is multiplied element-wise to the batchwise KL terms. - scale: dict of RandomVariable to tf.Tensor, optional. - A tensor to dict computation for any random variable that it is - binded to. Its shape must be broadcastable; it is multiplied - element-wise to the random variable. For example, this is useful - for mini-batch scaling when inferring global variables, or - applying masks on a random variable. + args: data inputs. It is passed at compile-time in Graph + mode or runtime in Eager mode. #### Notes @@ -102,470 +118,115 @@ def klqp(latent_vars=None, data=None, n_samples=1, kl_scaling=None, where the KL term is computed analytically [@kingma2014auto]. We compute this automatically when $p(z)$ and $q(z; \lambda)$ are Normal. 
- """ - if isinstance(latent_vars, list): - with tf.variable_scope(None, default_name="posterior"): - latent_vars_dict = {} - continuous = \ - ('01', 'nonnegative', 'simplex', 'real', 'multivariate_real') - for z in latent_vars: - if not hasattr(z, 'support') or z.support not in continuous: - raise AttributeError( - "Random variable {} is not continuous or a random " - "variable with supported continuous support.".format(z)) - batch_event_shape = z.batch_shape.concatenate(z.event_shape) - loc = tf.Variable(tf.random_normal(batch_event_shape)) - scale = tf.nn.softplus( - tf.Variable(tf.random_normal(batch_event_shape))) - latent_vars_dict[z] = Normal(loc=loc, scale=scale) - latent_vars = latent_vars_dict - del latent_vars_dict - latent_vars = check_and_maybe_build_latent_vars(latent_vars) - data = check_and_maybe_build_data(data) - latent_vars, _ = transform(latent_vars, auto_transform) - kl_scaling = check_and_maybe_build_dict(kl_scaling) - scale = check_and_maybe_build_dict(scale) - var_list = check_and_maybe_build_var_list(var_list, latent_vars, data) - - is_reparameterizable = all([ - rv.reparameterization_type == - tf.contrib.distributions.FULLY_REPARAMETERIZED - for rv in six.itervalues(latent_vars)]) - is_analytic_kl = all([isinstance(z, Normal) and isinstance(qz, Normal) - for z, qz in six.iteritems(latent_vars)]) - if not is_analytic_kl and kl_scaling: - raise TypeError("kl_scaling must be None when using non-analytic KL term") - if is_reparameterizable: - if is_analytic_kl: - return build_reparam_kl_loss_and_gradients( - latent_vars, data, var_list, - scale, n_samples, kl_scaling, summary_key) - # elif is_analytic_entropy: - # return build_reparam_entropy_loss_and_gradients(...) - else: - return build_reparam_loss_and_gradients( - latent_vars, data, var_list, - scale, n_samples, summary_key) - else: - # Prefer Rao-Blackwellization over analytic KL. Unknown what - # would happen stability-wise if the two are combined. 
- # if is_analytic_kl: - # return build_score_kl_loss_and_gradients(...) - # Analytic entropies may lead to problems around - # convergence; for now it is deactivated. - # elif is_analytic_entropy: - # return build_score_entropy_loss_and_gradients(...) - # else: - return build_score_rb_loss_and_gradients( - latent_vars, data, var_list, - scale, n_samples, summary_key) - - -def reparameterization_klqp( - latent_vars=None, data=None, n_samples=1, - auto_transform=True, scale=None, var_list=None, summary_key=None): - """Variational inference with the KL divergence - - $\\text{KL}( q(z; \lambda) \| p(z \mid x) ).$ - - This class minimizes the objective using the reparameterization - gradient. - - Args: - latent_vars: list of RandomVariable or - dict of RandomVariable to RandomVariable. - Collection of random variables to perform inference on. If - list, each random variable will be implictly optimized using a - `Normal` random variable that is defined internally with a - free parameter per location and scale and is initialized using - standard normal draws. The random variables to approximate - must be continuous. - n_samples: int, optional. - Number of samples from variational model for calculating - stochastic gradients. - The objective function also adds to itself a summation over all - tensors in the `REGULARIZATION_LOSSES` collection. 
- """ - if isinstance(latent_vars, list): - with tf.variable_scope(None, default_name="posterior"): - latent_vars_dict = {} - continuous = \ - ('01', 'nonnegative', 'simplex', 'real', 'multivariate_real') - for z in latent_vars: - if not hasattr(z, 'support') or z.support not in continuous: - raise AttributeError( - "Random variable {} is not continuous or a random " - "variable with supported continuous support.".format(z)) - batch_event_shape = z.batch_shape.concatenate(z.event_shape) - loc = tf.Variable(tf.random_normal(batch_event_shape)) - scale = tf.nn.softplus( - tf.Variable(tf.random_normal(batch_event_shape))) - latent_vars_dict[z] = Normal(loc=loc, scale=scale) - latent_vars = latent_vars_dict - del latent_vars_dict - latent_vars = check_and_maybe_build_latent_vars(latent_vars) - data = check_and_maybe_build_data(data) - latent_vars, _ = transform(latent_vars, auto_transform) - scale = check_and_maybe_build_dict(scale) - var_list = check_and_maybe_build_var_list(var_list, latent_vars, data) - - return build_reparam_loss_and_gradients( - latent_vars, data, var_list, - scale, n_samples, summary_key) - - -def reparameterization_kl_klqp( - latent_vars=None, data=None, n_samples=1, kl_scaling=None, - auto_transform=True, scale=None, var_list=None, summary_key=None): - """Variational inference with the KL divergence - - $\\text{KL}( q(z; \lambda) \| p(z \mid x) ).$ + This class minimizes the objective using the score function gradient + and Rao-Blackwellization [@ranganath2014black]. - This class minimizes the objective using the reparameterization - gradient and an analytic KL term. + Computed by sampling from $q(z;\lambda)$ and evaluating the + expectation using Monte Carlo sampling and Rao-Blackwellization. - Args: - latent_vars: list of RandomVariable or - dict of RandomVariable to RandomVariable. - Collection of random variables to perform inference on.
If - list, each random variable will be implictly optimized using a - `Normal` random variable that is defined internally with a - free parameter per location and scale and is initialized using - standard normal draws. The random variables to approximate - must be continuous. - n_samples: int, optional. - Number of samples from variational model for calculating - stochastic gradients. - kl_scaling: dict of RandomVariable to tf.Tensor, optional. - Provides option to scale terms when using ELBO with KL divergence. - If the KL divergence terms are + The implementation takes the surrogate loss approach. See + @schulman2015stochastic; @ruiz2016generalized; @ritchie2016deep. - $\\alpha_p \mathbb{E}_{q(z\mid x, \lambda)} [ - \log q(z\mid x, \lambda) - \log p(z)],$ + #### Notes - then pass {$p(z)$: $\\alpha_p$} as `kl_scaling`, - where $\\alpha_p$ is a tensor. Its shape must be broadcastable; - it is multiplied element-wise to the batchwise KL terms. + Current Rao-Blackwellization is limited to Rao-Blackwellizing across + stochastic nodes in the computation graph. It does not + Rao-Blackwellize within a node such as when a node represents + multiple random variables via non-scalar batch shape. The objective function also adds to itself a summation over all tensors in the `REGULARIZATION_LOSSES` collection. 
""" - if isinstance(latent_vars, list): - with tf.variable_scope(None, default_name="posterior"): - latent_vars_dict = {} - continuous = \ - ('01', 'nonnegative', 'simplex', 'real', 'multivariate_real') - for z in latent_vars: - if not hasattr(z, 'support') or z.support not in continuous: - raise AttributeError( - "Random variable {} is not continuous or a random " - "variable with supported continuous support.".format(z)) - batch_event_shape = z.batch_shape.concatenate(z.event_shape) - loc = tf.Variable(tf.random_normal(batch_event_shape)) - scale = tf.nn.softplus( - tf.Variable(tf.random_normal(batch_event_shape))) - latent_vars_dict[z] = Normal(loc=loc, scale=scale) - latent_vars = latent_vars_dict - del latent_vars_dict - latent_vars = check_and_maybe_build_latent_vars(latent_vars) - data = check_and_maybe_build_data(data) - latent_vars, _ = transform(latent_vars, auto_transform) - kl_scaling = check_and_maybe_build_dict(kl_scaling) - scale = check_and_maybe_build_dict(scale) - var_list = check_and_maybe_build_var_list(var_list, latent_vars, data) - - return build_reparam_kl_loss_and_gradients( - latent_vars, data, var_list, - scale, n_samples, kl_scaling, summary_key) - - -def reparameterization_entropy_klqp( - latent_vars=None, data=None, n_samples=1, - auto_transform=True, scale=None, var_list=None, summary_key=None): - """Variational inference with the KL divergence - - $\\text{KL}( q(z; \lambda) \| p(z \mid x) ).$ + # TODO baseline, control variate, learnable baseline + # TODO deal with new parameters defined at runtime + # TODO implement kl after rb + # is_analytic_kl = all([isinstance(z, Normal) and isinstance(qz, Normal) + # for z, qz in six.iteritems(latent_vars)]) + # kl_penalty = 0.0 + # for name, node in six.iteritems(model_trace): + # rv = node.value + # posterior_node = posterior_trace.get(align_latent(name), None) + # if posterior_node is not None: + # qz = posterior_node.value + # kl_penalty += tf.reduce_sum(kl_scaling(name) * kl_divergence(qz, 
rv)) + p_log_prob = [None] * n_samples + q_log_prob = [None] * n_samples + surrogate_loss = [None] * n_samples + for s in range(n_samples): + with Trace() as posterior_trace: + call_function_up_to_args(variational, *args, **kwargs) + intercept = make_intercept( + posterior_trace, align_data, align_latent, args, kwargs) + with Trace(intercept=intercept) as model_trace: + call_function_up_to_args(model, *args, **kwargs) + + # Collect key-value pairs of (rv, rv's (scaled) log prob). + p_dict = {} + q_dict = {} + inverse_align_latent = {} + for name, node in six.iteritems(model_trace): + rv = node.value + scale_factor = scale(name) + p_dict[rv] = tf.reduce_sum(scale_factor * rv.log_prob(rv.value)) + posterior_node = posterior_trace.get(align_latent(name), None) + if posterior_node is not None: + qz = posterior_node.value + q_dict[qz] = tf.reduce_sum(scale_factor * qz.log_prob(qz.value)) + inverse_align_latent[qz] = rv + + # Build surrogate loss. + scaled_q_log_prob = 0.0 + for qz, log_prob in six.iteritems(q_dict): + if qz.reparameterization_type == tfd.FULLY_REPARAMETERIZED: + scale_factor = 1.0 + else: + scale_factor = 0.0 + for rv in qz.get_blanket(q_rvs) + [qz]: + scale_factor += q_dict[rv] + scale_factor -= p_dict[inverse_align_latent[qz]] + scaled_q_log_prob += scale_factor * log_prob + + p_log_prob_s = tf.reduce_sum(list(six.itervalues(p_dict))) + p_log_prob[s] = p_log_prob_s + q_log_prob[s] = tf.reduce_sum(list(six.itervalues(q_dict))) + surrogate_loss[s] = scaled_q_log_prob - p_log_prob_s - This class minimizes the objective using the reparameterization - gradient and an analytic entropy term. + p_log_prob = tf.reduce_mean(p_log_prob) + q_log_prob = tf.reduce_mean(q_log_prob) + surrogate_loss = tf.reduce_mean(surrogate_loss) - Args: - latent_vars: list of RandomVariable or - dict of RandomVariable to RandomVariable. - Collection of random variables to perform inference on. 
If - list, each random variable will be implictly optimized using a - `Normal` random variable that is defined internally with a - free parameter per location and scale and is initialized using - standard normal draws. The random variables to approximate - must be continuous. - n_samples: int, optional. - Number of samples from variational model for calculating - stochastic gradients. + reg_penalty = tf.reduce_sum(tf.losses.get_regularization_losses()) + surrogate_loss += reg_penalty - The objective function also adds to itself a summation over all - tensors in the `REGULARIZATION_LOSSES` collection. - """ - if isinstance(latent_vars, list): - with tf.variable_scope(None, default_name="posterior"): - latent_vars_dict = {} - continuous = \ - ('01', 'nonnegative', 'simplex', 'real', 'multivariate_real') - for z in latent_vars: - if not hasattr(z, 'support') or z.support not in continuous: - raise AttributeError( - "Random variable {} is not continuous or a random " - "variable with supported continuous support.".format(z)) - batch_event_shape = z.batch_shape.concatenate(z.event_shape) - loc = tf.Variable(tf.random_normal(batch_event_shape)) - scale = tf.nn.softplus( - tf.Variable(tf.random_normal(batch_event_shape))) - latent_vars_dict[z] = Normal(loc=loc, scale=scale) - latent_vars = latent_vars_dict - del latent_vars_dict - latent_vars = check_and_maybe_build_latent_vars(latent_vars) - data = check_and_maybe_build_data(data) - latent_vars, _ = transform(latent_vars, auto_transform) - scale = check_and_maybe_build_dict(scale) - var_list = check_and_maybe_build_var_list(var_list, latent_vars, data) + if collections is not None: + tf.summary.scalar("loss/p_log_prob", p_log_prob, + collections=collections) + tf.summary.scalar("loss/q_log_prob", q_log_prob, + collections=collections) + tf.summary.scalar("loss/reg_penalty", reg_penalty, + collections=collections) - return build_reparam_entropy_loss_and_gradients( - latent_vars, data, var_list, - scale, n_samples, 
summary_key) + loss = q_log_prob - p_log_prob + reg_penalty + return loss, surrogate_loss -def score_klqp( - latent_vars=None, data=None, n_samples=1, - auto_transform=True, scale=None, var_list=None, summary_key=None): +def klqp_reparameterization( + model, variational, align_latent, align_data, + scale=lambda name: 1.0, n_samples=1, auto_transform=True, + collections=None, *args, **kwargs): """Variational inference with the KL divergence $\\text{KL}( q(z; \lambda) \| p(z \mid x) ).$ - This class minimizes the objective using the score function + This class minimizes the objective using the reparameterization gradient. - Args: - latent_vars: list of RandomVariable or - dict of RandomVariable to RandomVariable. - Collection of random variables to perform inference on. If - list, each random variable will be implictly optimized using a - `Normal` random variable that is defined internally with a - free parameter per location and scale and is initialized using - standard normal draws. The random variables to approximate - must be continuous. - The objective function also adds to itself a summation over all tensors in the `REGULARIZATION_LOSSES` collection. 
- """ - if isinstance(latent_vars, list): - with tf.variable_scope(None, default_name="posterior"): - latent_vars_dict = {} - continuous = \ - ('01', 'nonnegative', 'simplex', 'real', 'multivariate_real') - for z in latent_vars: - if not hasattr(z, 'support') or z.support not in continuous: - raise AttributeError( - "Random variable {} is not continuous or a random " - "variable with supported continuous support.".format(z)) - batch_event_shape = z.batch_shape.concatenate(z.event_shape) - loc = tf.Variable(tf.random_normal(batch_event_shape)) - scale = tf.nn.softplus( - tf.Variable(tf.random_normal(batch_event_shape))) - latent_vars_dict[z] = Normal(loc=loc, scale=scale) - latent_vars = latent_vars_dict - del latent_vars_dict - latent_vars = check_and_maybe_build_latent_vars(latent_vars) - data = check_and_maybe_build_data(data) - latent_vars, _ = transform(latent_vars, auto_transform) - scale = check_and_maybe_build_dict(scale) - var_list = check_and_maybe_build_var_list(var_list, latent_vars, data) - - return build_score_loss_and_gradients( - latent_vars, data, var_list, - scale, n_samples, summary_key) - - -def score_kl_klqp( - latent_vars=None, data=None, n_samples=1, kl_scaling=None, - auto_transform=True, scale=None, var_list=None, summary_key=None): - """Variational inference with the KL divergence - - $\\text{KL}( q(z; \lambda) \| p(z \mid x) ).$ - - This class minimizes the objective using the score function gradient - and an analytic KL term. - - Args: - latent_vars: list of RandomVariable or - dict of RandomVariable to RandomVariable. - Collection of random variables to perform inference on. If - list, each random variable will be implictly optimized using a - `Normal` random variable that is defined internally with a - free parameter per location and scale and is initialized using - standard normal draws. The random variables to approximate - must be continuous. - n_samples: int, optional. 
- Number of samples from variational model for calculating - stochastic gradients. - kl_scaling: dict of RandomVariable to tf.Tensor, optional. - Provides option to scale terms when using ELBO with KL divergence. - If the KL divergence terms are - - $\\alpha_p \mathbb{E}_{q(z\mid x, \lambda)} [ - \log q(z\mid x, \lambda) - \log p(z)],$ - then pass {$p(z)$: $\\alpha_p$} as `kl_scaling`, - where $\\alpha_p$ is a tensor. Its shape must be broadcastable; - it is multiplied element-wise to the batchwise KL terms. - - The objective function also adds to itself a summation over all - tensors in the `REGULARIZATION_LOSSES` collection. - """ - if isinstance(latent_vars, list): - with tf.variable_scope(None, default_name="posterior"): - latent_vars_dict = {} - continuous = \ - ('01', 'nonnegative', 'simplex', 'real', 'multivariate_real') - for z in latent_vars: - if not hasattr(z, 'support') or z.support not in continuous: - raise AttributeError( - "Random variable {} is not continuous or a random " - "variable with supported continuous support.".format(z)) - batch_event_shape = z.batch_shape.concatenate(z.event_shape) - loc = tf.Variable(tf.random_normal(batch_event_shape)) - scale = tf.nn.softplus( - tf.Variable(tf.random_normal(batch_event_shape))) - latent_vars_dict[z] = Normal(loc=loc, scale=scale) - latent_vars = latent_vars_dict - del latent_vars_dict - latent_vars = check_and_maybe_build_latent_vars(latent_vars) - data = check_and_maybe_build_data(data) - latent_vars, _ = transform(latent_vars, auto_transform) - kl_scaling = check_and_maybe_build_dict(kl_scaling) - scale = check_and_maybe_build_dict(scale) - var_list = check_and_maybe_build_var_list(var_list, latent_vars, data) - - return build_score_kl_loss_and_gradients( - latent_vars, data, var_list, - scale, n_samples, kl_scaling, summary_key) - - -def score_entropy_klqp( - latent_vars=None, data=None, n_samples=1, - auto_transform=True, scale=None, var_list=None, summary_key=None): - """Variational inference 
with the KL divergence - - $\\text{KL}( q(z; \lambda) \| p(z \mid x) ).$ - - This class minimizes the objective using the score function gradient - and an analytic entropy term. - - Args: - latent_vars: list of RandomVariable or - dict of RandomVariable to RandomVariable. - Collection of random variables to perform inference on. If - list, each random variable will be implictly optimized using a - `Normal` random variable that is defined internally with a - free parameter per location and scale and is initialized using - standard normal draws. The random variables to approximate - must be continuous. - n_samples: int, optional. - Number of samples from variational model for calculating - stochastic gradients. - - The objective function also adds to itself a summation over all - tensors in the `REGULARIZATION_LOSSES` collection. - """ - if isinstance(latent_vars, list): - with tf.variable_scope(None, default_name="posterior"): - latent_vars_dict = {} - continuous = \ - ('01', 'nonnegative', 'simplex', 'real', 'multivariate_real') - for z in latent_vars: - if not hasattr(z, 'support') or z.support not in continuous: - raise AttributeError( - "Random variable {} is not continuous or a random " - "variable with supported continuous support.".format(z)) - batch_event_shape = z.batch_shape.concatenate(z.event_shape) - loc = tf.Variable(tf.random_normal(batch_event_shape)) - scale = tf.nn.softplus( - tf.Variable(tf.random_normal(batch_event_shape))) - latent_vars_dict[z] = Normal(loc=loc, scale=scale) - latent_vars = latent_vars_dict - del latent_vars_dict - latent_vars = check_and_maybe_build_latent_vars(latent_vars) - data = check_and_maybe_build_data(data) - latent_vars, _ = transform(latent_vars, auto_transform) - scale = check_and_maybe_build_dict(scale) - var_list = check_and_maybe_build_var_list(var_list, latent_vars, data) - - return build_score_entropy_loss_and_gradients( - latent_vars, data, var_list, - scale, n_samples, summary_key) - - -def score_rb_klqp( - 
latent_vars=None, data=None, n_samples=1, - auto_transform=True, scale=None, var_list=None, summary_key=None): - """Variational inference with the KL divergence - - $\\text{KL}( q(z; \lambda) \| p(z \mid x) ).$ - - This class minimizes the objective using the score function gradient - and Rao-Blackwellization. - - Args: - latent_vars: list of RandomVariable or - dict of RandomVariable to RandomVariable. - Collection of random variables to perform inference on. If - list, each random variable will be implictly optimized using a - `Normal` random variable that is defined internally with a - free parameter per location and scale and is initialized using - standard normal draws. The random variables to approximate - must be continuous. - - #### Notes - - Current Rao-Blackwellization is limited to Rao-Blackwellizing across - stochastic nodes in the computation graph. It does not - Rao-Blackwellize within a node such as when a node represents - multiple random variables via non-scalar batch shape. - - The objective function also adds to itself a summation over all - tensors in the `REGULARIZATION_LOSSES` collection. 
- """ - if isinstance(latent_vars, list): - with tf.variable_scope(None, default_name="posterior"): - latent_vars_dict = {} - continuous = \ - ('01', 'nonnegative', 'simplex', 'real', 'multivariate_real') - for z in latent_vars: - if not hasattr(z, 'support') or z.support not in continuous: - raise AttributeError( - "Random variable {} is not continuous or a random " - "variable with supported continuous support.".format(z)) - batch_event_shape = z.batch_shape.concatenate(z.event_shape) - loc = tf.Variable(tf.random_normal(batch_event_shape)) - scale = tf.nn.softplus( - tf.Variable(tf.random_normal(batch_event_shape))) - latent_vars_dict[z] = Normal(loc=loc, scale=scale) - latent_vars = latent_vars_dict - del latent_vars_dict - latent_vars = check_and_maybe_build_latent_vars(latent_vars) - data = check_and_maybe_build_data(data) - latent_vars, _ = transform(latent_vars, auto_transform) - scale = check_and_maybe_build_dict(scale) - var_list = check_and_maybe_build_var_list(var_list, latent_vars, data) - - return build_score_rb_loss_and_gradients( - latent_vars, data, var_list, - scale, n_samples, summary_key) - - -def build_reparam_loss_and_gradients( - latent_vars, data, var_list, scale, n_samples, summary_key): - """Build loss function. Its automatic differentiation - is a stochastic gradient of + Build loss function equal to KL(q||p) up to a constant. Its + automatic differentiation is a stochastic gradient of $-\\text{ELBO} = -\mathbb{E}_{q(z; \lambda)} [ \log p(x, z) - \log q(z; \lambda) ]$ @@ -574,63 +235,61 @@ def build_reparam_loss_and_gradients( Computed by sampling from $q(z;\lambda)$ and evaluating the expectation using Monte Carlo sampling. + + Note if user defines constrained posterior, then auto_transform + can do inference on real-valued; then test time user can use + constrained. If user defines unconstrained posterior, then how to + work with constrained at test time? For now, user must manually + write the bijectors according to transform. 
""" p_log_prob = [0.0] * n_samples q_log_prob = [0.0] * n_samples - base_scope = tf.get_default_graph().unique_name("inference") + '/' for s in range(n_samples): - # Form dictionary in order to replace conditioning on prior or - # observed variable with conditioning on a specific value. - scope = base_scope + tf.get_default_graph().unique_name("sample") - dict_swap = {} - for x, qx in six.iteritems(data): - if isinstance(x, RandomVariable): - if isinstance(qx, RandomVariable): - qx_copy = copy(qx, scope=scope) - dict_swap[x] = qx_copy.value - else: - dict_swap[x] = qx - - for z, qz in six.iteritems(latent_vars): - # Copy q(z) to obtain new set of posterior samples. - qz_copy = copy(qz, scope=scope) - dict_swap[z] = qz_copy.value - q_log_prob[s] += tf.reduce_sum( - scale.get(z, 1.0) * qz_copy.log_prob(dict_swap[z])) - - for z in six.iterkeys(latent_vars): - z_copy = copy(z, dict_swap, scope=scope) - p_log_prob[s] += tf.reduce_sum( - scale.get(z, 1.0) * z_copy.log_prob(dict_swap[z])) - - for x in six.iterkeys(data): - if isinstance(x, RandomVariable): - x_copy = copy(x, dict_swap, scope=scope) - p_log_prob[s] += tf.reduce_sum( - scale.get(x, 1.0) * x_copy.log_prob(dict_swap[x])) + with Trace() as posterior_trace: + call_function_up_to_args(variational, *args, **kwargs) + intercept = make_intercept( + posterior_trace, align_data, align_latent, args, kwargs) + with Trace(intercept=intercept) as model_trace: + call_function_up_to_args(model, *args, **kwargs) + + for name, node in six.iteritems(model_trace): + rv = node.value + scale_factor = scale(name) + p_log_prob[s] += tf.reduce_sum(scale_factor * rv.log_prob(rv.value)) + posterior_node = posterior_trace.get(align_latent(name), None) + if posterior_node is not None: + qz = posterior_node.value + q_log_prob[s] += tf.reduce_sum(scale_factor * qz.log_prob(qz.value)) p_log_prob = tf.reduce_mean(p_log_prob) q_log_prob = tf.reduce_mean(q_log_prob) reg_penalty = tf.reduce_sum(tf.losses.get_regularization_losses()) - - if 
summary_key is not None: + if collections is not None: tf.summary.scalar("loss/p_log_prob", p_log_prob, - collections=[summary_key]) + collections=collections) tf.summary.scalar("loss/q_log_prob", q_log_prob, - collections=[summary_key]) + collections=collections) tf.summary.scalar("loss/reg_penalty", reg_penalty, - collections=[summary_key]) + collections=collections) + loss = q_log_prob - p_log_prob + reg_penalty + return loss + - loss = -(p_log_prob - q_log_prob - reg_penalty) +def klqp_reparameterization_kl( + model, variational, align_latent, align_data, + scale=lambda name: 1.0, n_samples=1, kl_scaling=lambda name: 1.0, + auto_transform=True, collections=None, *args, **kwargs): + """Variational inference with the KL divergence - grads = tf.gradients(loss, var_list) - grads_and_vars = list(zip(grads, var_list)) - return loss, grads_and_vars + $\\text{KL}( q(z; \lambda) \| p(z \mid x) ).$ + This class minimizes the objective using the reparameterization + gradient and an analytic KL term. -def build_reparam_kl_loss_and_gradients( - latent_vars, data, var_list, scale, n_samples, kl_scaling, summary_key): - """Build loss function. Its automatic differentiation + The objective function also adds to itself a summation over all + tensors in the `REGULARIZATION_LOSSES` collection. + + Build loss function. Its automatic differentiation is a stochastic gradient of .. math:: @@ -646,437 +305,114 @@ def build_reparam_kl_loss_and_gradients( expectation using Monte Carlo sampling. """ p_log_lik = [0.0] * n_samples - base_scope = tf.get_default_graph().unique_name("inference") + '/' for s in range(n_samples): - # Form dictionary in order to replace conditioning on prior or - # observed variable with conditioning on a specific value. 
- scope = base_scope + tf.get_default_graph().unique_name("sample") - dict_swap = {} - for x, qx in six.iteritems(data): - if isinstance(x, RandomVariable): - if isinstance(qx, RandomVariable): - qx_copy = copy(qx, scope=scope) - dict_swap[x] = qx_copy.value - else: - dict_swap[x] = qx - - for z, qz in six.iteritems(latent_vars): - # Copy q(z) to obtain new set of posterior samples. - qz_copy = copy(qz, scope=scope) - dict_swap[z] = qz_copy.value - - for x in six.iterkeys(data): - if isinstance(x, RandomVariable): - x_copy = copy(x, dict_swap, scope=scope) - p_log_lik[s] += tf.reduce_sum( - scale.get(x, 1.0) * x_copy.log_prob(dict_swap[x])) + with Trace() as posterior_trace: + call_function_up_to_args(variational, *args, **kwargs) + intercept = make_intercept( + posterior_trace, align_data, align_latent, args, kwargs) + with Trace(intercept=intercept) as model_trace: + call_function_up_to_args(model, *args, **kwargs) + + for name, node in six.iteritems(model_trace): + rv = node.value + scale_factor = scale(name) + p_log_lik[s] += tf.reduce_sum(scale_factor * rv.log_prob(rv.value)) p_log_lik = tf.reduce_mean(p_log_lik) - kl_penalty = tf.reduce_sum([ - tf.reduce_sum(kl_scaling.get(z, 1.0) * kl_divergence(qz, z)) - for z, qz in six.iteritems(latent_vars)]) + kl_penalty = 0.0 + for name, node in six.iteritems(model_trace): + rv = node.value + posterior_node = posterior_trace.get(align_latent(name), None) + if posterior_node is not None: + qz = posterior_node.value + kl_penalty += tf.reduce_sum(kl_scaling(name) * kl_divergence(qz, rv)) reg_penalty = tf.reduce_sum(tf.losses.get_regularization_losses()) - - if summary_key is not None: + if collections is not None: tf.summary.scalar("loss/p_log_lik", p_log_lik, - collections=[summary_key]) + collections=collections) tf.summary.scalar("loss/kl_penalty", kl_penalty, - collections=[summary_key]) + collections=collections) tf.summary.scalar("loss/reg_penalty", reg_penalty, - collections=[summary_key]) - - loss = -(p_log_lik - 
kl_penalty - reg_penalty) + collections=collections) + loss = -p_log_lik + kl_penalty + reg_penalty + return loss - grads = tf.gradients(loss, var_list) - grads_and_vars = list(zip(grads, var_list)) - return loss, grads_and_vars +def klqp_score( + model, variational, align_latent, align_data, + scale=lambda name: 1.0, n_samples=1, auto_transform=True, + collections=None, *args, **kwargs): + """Variational inference with the KL divergence -def build_reparam_entropy_loss_and_gradients( - latent_vars, data, var_list, scale, n_samples, summary_key): - """Build loss function. Its automatic differentiation - is a stochastic gradient of - - $-\\text{ELBO} = -( \mathbb{E}_{q(z; \lambda)} [ \log p(x , z) ] - + \mathbb{H}(q(z; \lambda)) )$ - - based on the reparameterization trick [@kingma2014auto]. - - It assumes the entropy is analytic. - - Computed by sampling from $q(z;\lambda)$ and evaluating the - expectation using Monte Carlo sampling. - """ - p_log_prob = [0.0] * n_samples - base_scope = tf.get_default_graph().unique_name("inference") + '/' - for s in range(n_samples): - # Form dictionary in order to replace conditioning on prior or - # observed variable with conditioning on a specific value. - scope = base_scope + tf.get_default_graph().unique_name("sample") - dict_swap = {} - for x, qx in six.iteritems(data): - if isinstance(x, RandomVariable): - if isinstance(qx, RandomVariable): - qx_copy = copy(qx, scope=scope) - dict_swap[x] = qx_copy.value - else: - dict_swap[x] = qx - - for z, qz in six.iteritems(latent_vars): - # Copy q(z) to obtain new set of posterior samples. 
- qz_copy = copy(qz, scope=scope) - dict_swap[z] = qz_copy.value - - for z in six.iterkeys(latent_vars): - z_copy = copy(z, dict_swap, scope=scope) - p_log_prob[s] += tf.reduce_sum( - scale.get(z, 1.0) * z_copy.log_prob(dict_swap[z])) - - for x in six.iterkeys(data): - if isinstance(x, RandomVariable): - x_copy = copy(x, dict_swap, scope=scope) - p_log_prob[s] += tf.reduce_sum( - scale.get(x, 1.0) * x_copy.log_prob(dict_swap[x])) - - p_log_prob = tf.reduce_mean(p_log_prob) - - q_entropy = tf.reduce_sum([ - tf.reduce_sum(qz.entropy()) - for z, qz in six.iteritems(latent_vars)]) - - reg_penalty = tf.reduce_sum(tf.losses.get_regularization_losses()) - - if summary_key is not None: - tf.summary.scalar("loss/p_log_prob", p_log_prob, - collections=[summary_key]) - tf.summary.scalar("loss/q_entropy", q_entropy, - collections=[summary_key]) - tf.summary.scalar("loss/reg_penalty", reg_penalty, - collections=[summary_key]) - - loss = -(p_log_prob + q_entropy - reg_penalty) - - grads = tf.gradients(loss, var_list) - grads_and_vars = list(zip(grads, var_list)) - return loss, grads_and_vars - - -def build_score_loss_and_gradients( - latent_vars, data, var_list, scale, n_samples, summary_key): - """Build loss function and gradients based on the score function - estimator [@paisley2012variational]. - - Computed by sampling from $q(z;\lambda)$ and evaluating the - expectation using Monte Carlo sampling. - """ - p_log_prob = [0.0] * n_samples - q_log_prob = [0.0] * n_samples - base_scope = tf.get_default_graph().unique_name("inference") + '/' - for s in range(n_samples): - # Form dictionary in order to replace conditioning on prior or - # observed variable with conditioning on a specific value. 
- scope = base_scope + tf.get_default_graph().unique_name("sample") - dict_swap = {} - for x, qx in six.iteritems(data): - if isinstance(x, RandomVariable): - if isinstance(qx, RandomVariable): - qx_copy = copy(qx, scope=scope) - dict_swap[x] = qx_copy.value - else: - dict_swap[x] = qx - - for z, qz in six.iteritems(latent_vars): - # Copy q(z) to obtain new set of posterior samples. - qz_copy = copy(qz, scope=scope) - dict_swap[z] = qz_copy.value - q_log_prob[s] += tf.reduce_sum( - scale.get(z, 1.0) * - qz_copy.log_prob(tf.stop_gradient(dict_swap[z]))) - - for z in six.iterkeys(latent_vars): - z_copy = copy(z, dict_swap, scope=scope) - p_log_prob[s] += tf.reduce_sum( - scale.get(z, 1.0) * z_copy.log_prob(dict_swap[z])) - - for x in six.iterkeys(data): - if isinstance(x, RandomVariable): - x_copy = copy(x, dict_swap, scope=scope) - p_log_prob[s] += tf.reduce_sum( - scale.get(x, 1.0) * x_copy.log_prob(dict_swap[x])) - - p_log_prob = tf.stack(p_log_prob) - q_log_prob = tf.stack(q_log_prob) - reg_penalty = tf.reduce_sum(tf.losses.get_regularization_losses()) - - if summary_key is not None: - tf.summary.scalar("loss/p_log_prob", tf.reduce_mean(p_log_prob), - collections=[summary_key]) - tf.summary.scalar("loss/q_log_prob", tf.reduce_mean(q_log_prob), - collections=[summary_key]) - tf.summary.scalar("loss/reg_penalty", reg_penalty, - collections=[summary_key]) - - losses = p_log_prob - q_log_prob - loss = -(tf.reduce_mean(losses) - reg_penalty) - - q_rvs = list(six.itervalues(latent_vars)) - q_vars = [v for v in var_list - if len(get_descendants(tf.convert_to_tensor(v), q_rvs)) != 0] - q_grads = tf.gradients( - -(tf.reduce_mean(q_log_prob * tf.stop_gradient(losses)) - reg_penalty), - q_vars) - p_vars = [v for v in var_list if v not in q_vars] - p_grads = tf.gradients(loss, p_vars) - grads_and_vars = list(zip(q_grads, q_vars)) + list(zip(p_grads, p_vars)) - return loss, grads_and_vars - + $\\text{KL}( q(z; \lambda) \| p(z \mid x) ).$ -def 
build_score_kl_loss_and_gradients( - latent_vars, data, var_list, scale, n_samples, kl_scaling, summary_key): - """Build loss function and gradients based on the score function - estimator [@paisley2012variational]. + This class minimizes the objective using the score function + gradient. - It assumes the KL is analytic. + Build loss function equal to KL(q||p) up to a constant. It + returns a surrogate loss function whose automatic differentiation + is based on the score function estimator [@paisley2012variational]. Computed by sampling from $q(z;\lambda)$ and evaluating the expectation using Monte Carlo sampling. - """ - p_log_lik = [0.0] * n_samples - q_log_prob = [0.0] * n_samples - base_scope = tf.get_default_graph().unique_name("inference") + '/' - for s in range(n_samples): - # Form dictionary in order to replace conditioning on prior or - # observed variable with conditioning on a specific value. - scope = base_scope + tf.get_default_graph().unique_name("sample") - dict_swap = {} - for x, qx in six.iteritems(data): - if isinstance(x, RandomVariable): - if isinstance(qx, RandomVariable): - qx_copy = copy(qx, scope=scope) - dict_swap[x] = qx_copy.value - else: - dict_swap[x] = qx - - for z, qz in six.iteritems(latent_vars): - # Copy q(z) to obtain new set of posterior samples. 
- qz_copy = copy(qz, scope=scope) - dict_swap[z] = qz_copy.value - q_log_prob[s] += tf.reduce_sum( - scale.get(z, 1.0) * - qz_copy.log_prob(tf.stop_gradient(dict_swap[z]))) - - for x in six.iterkeys(data): - if isinstance(x, RandomVariable): - x_copy = copy(x, dict_swap, scope=scope) - p_log_lik[s] += tf.reduce_sum( - scale.get(x, 1.0) * x_copy.log_prob(dict_swap[x])) - - p_log_lik = tf.stack(p_log_lik) - q_log_prob = tf.stack(q_log_prob) - - kl_penalty = tf.reduce_sum([ - tf.reduce_sum(kl_scaling.get(z, 1.0) * kl_divergence(qz, z)) - for z, qz in six.iteritems(latent_vars)]) - - reg_penalty = tf.reduce_sum(tf.losses.get_regularization_losses()) - - if summary_key is not None: - tf.summary.scalar("loss/p_log_lik", tf.reduce_mean(p_log_lik), - collections=[summary_key]) - tf.summary.scalar("loss/kl_penalty", kl_penalty, - collections=[summary_key]) - tf.summary.scalar("loss/reg_penalty", reg_penalty, - collections=[summary_key]) - - loss = -(tf.reduce_mean(p_log_lik) - kl_penalty - reg_penalty) - q_rvs = list(six.itervalues(latent_vars)) - q_vars = [v for v in var_list - if len(get_descendants(tf.convert_to_tensor(v), q_rvs)) != 0] - q_grads = tf.gradients( - -(tf.reduce_mean(q_log_prob * tf.stop_gradient(p_log_lik)) - kl_penalty - - reg_penalty), - q_vars) - p_vars = [v for v in var_list if v not in q_vars] - p_grads = tf.gradients(loss, p_vars) - grads_and_vars = list(zip(q_grads, q_vars)) + list(zip(p_grads, p_vars)) - return loss, grads_and_vars - - -def build_score_entropy_loss_and_gradients( - latent_vars, data, var_list, scale, n_samples, summary_key): - """Build loss function and gradients based on the score function - estimator [@paisley2012variational]. - - It assumes the entropy is analytic. - - Computed by sampling from $q(z;\lambda)$ and evaluating the - expectation using Monte Carlo sampling. + The objective function also adds to itself a summation over all + tensors in the `REGULARIZATION_LOSSES` collection. 
""" p_log_prob = [0.0] * n_samples q_log_prob = [0.0] * n_samples - base_scope = tf.get_default_graph().unique_name("inference") + '/' for s in range(n_samples): - # Form dictionary in order to replace conditioning on prior or - # observed variable with conditioning on a specific value. - scope = base_scope + tf.get_default_graph().unique_name("sample") - dict_swap = {} - for x, qx in six.iteritems(data): - if isinstance(x, RandomVariable): - if isinstance(qx, RandomVariable): - qx_copy = copy(qx, scope=scope) - dict_swap[x] = qx_copy.value - else: - dict_swap[x] = qx - - for z, qz in six.iteritems(latent_vars): - # Copy q(z) to obtain new set of posterior samples. - qz_copy = copy(qz, scope=scope) - dict_swap[z] = qz_copy.value - q_log_prob[s] += tf.reduce_sum( - scale.get(z, 1.0) * - qz_copy.log_prob(tf.stop_gradient(dict_swap[z]))) - - for z in six.iterkeys(latent_vars): - z_copy = copy(z, dict_swap, scope=scope) - p_log_prob[s] += tf.reduce_sum( - scale.get(z, 1.0) * z_copy.log_prob(dict_swap[z])) - - for x in six.iterkeys(data): - if isinstance(x, RandomVariable): - x_copy = copy(x, dict_swap, scope=scope) - p_log_prob[s] += tf.reduce_sum( - scale.get(x, 1.0) * x_copy.log_prob(dict_swap[x])) + with Trace() as posterior_trace: + call_function_up_to_args(variational, *args, **kwargs) + intercept = make_intercept( + posterior_trace, align_data, align_latent, args, kwargs) + with Trace(intercept=intercept) as model_trace: + call_function_up_to_args(model, *args, **kwargs) + + for name, node in six.iteritems(model_trace): + rv = node.value + scale_factor = scale(name) + p_log_prob[s] += tf.reduce_sum(scale_factor * rv.log_prob(rv.value)) + posterior_node = posterior_trace.get(align_latent(name), None) + if posterior_node is not None: + qz = posterior_node.value + q_log_prob[s] += tf.reduce_sum( + scale_factor * qz.log_prob(tf.stop_gradient(qz.value))) p_log_prob = tf.stack(p_log_prob) q_log_prob = tf.stack(q_log_prob) - - q_entropy = tf.reduce_sum([ - 
tf.reduce_sum(qz.entropy()) - for z, qz in six.iteritems(latent_vars)]) - reg_penalty = tf.reduce_sum(tf.losses.get_regularization_losses()) - - if summary_key is not None: + if collections is not None: tf.summary.scalar("loss/p_log_prob", tf.reduce_mean(p_log_prob), - collections=[summary_key]) + collections=collections) tf.summary.scalar("loss/q_log_prob", tf.reduce_mean(q_log_prob), - collections=[summary_key]) - tf.summary.scalar("loss/q_entropy", q_entropy, - collections=[summary_key]) + collections=collections) tf.summary.scalar("loss/reg_penalty", reg_penalty, - collections=[summary_key]) + collections=collections) + losses = q_log_prob - p_log_prob + loss = tf.reduce_mean(losses) + reg_penalty + surrogate_loss = (tf.reduce_mean(q_log_prob * tf.stop_gradient(losses)) + + reg_penalty) + return loss, surrogate_loss - loss = -(tf.reduce_mean(p_log_prob) + q_entropy - reg_penalty) - q_rvs = list(six.itervalues(latent_vars)) - q_vars = [v for v in var_list - if len(get_descendants(tf.convert_to_tensor(v), q_rvs)) != 0] - q_grads = tf.gradients( - -(tf.reduce_mean(q_log_prob * tf.stop_gradient(p_log_prob)) + - q_entropy - reg_penalty), - q_vars) - p_vars = [v for v in var_list if v not in q_vars] - p_grads = tf.gradients(loss, p_vars) - grads_and_vars = list(zip(q_grads, q_vars)) + list(zip(p_grads, p_vars)) - return loss, grads_and_vars - - -def build_score_rb_loss_and_gradients( - latent_vars, data, var_list, scale, n_samples, summary_key): - """Build loss function and gradients based on the score function - estimator [@paisley2012variational] and Rao-Blackwellization - [@ranganath2014black]. - - Computed by sampling from :math:`q(z;\lambda)` and evaluating the - expectation using Monte Carlo sampling and Rao-Blackwellization. - """ - # Build tensors for loss and gradient calculations. There is one set - # for each sample from the variational distribution. 
- p_log_probs = [{}] * n_samples - q_log_probs = [{}] * n_samples - base_scope = tf.get_default_graph().unique_name("inference") + '/' - for s in range(n_samples): - # Form dictionary in order to replace conditioning on prior or - # observed variable with conditioning on a specific value. - scope = base_scope + tf.get_default_graph().unique_name("sample") - dict_swap = {} - for x, qx in six.iteritems(data): - if isinstance(x, RandomVariable): - if isinstance(qx, RandomVariable): - qx_copy = copy(qx, scope=scope) - dict_swap[x] = qx_copy.value - else: - dict_swap[x] = qx - - for z, qz in six.iteritems(latent_vars): - # Copy q(z) to obtain new set of posterior samples. - qz_copy = copy(qz, scope=scope) - dict_swap[z] = qz_copy.value - q_log_probs[s][qz] = tf.reduce_sum( - scale.get(z, 1.0) * - qz_copy.log_prob(tf.stop_gradient(dict_swap[z]))) - - for z in six.iterkeys(latent_vars): - z_copy = copy(z, dict_swap, scope=scope) - p_log_probs[s][z] = tf.reduce_sum( - scale.get(z, 1.0) * z_copy.log_prob(dict_swap[z])) - - for x in six.iterkeys(data): - if isinstance(x, RandomVariable): - x_copy = copy(x, dict_swap, scope=scope) - p_log_probs[s][x] = tf.reduce_sum( - scale.get(x, 1.0) * x_copy.log_prob(dict_swap[x])) - - # Take gradients of Rao-Blackwellized loss for each variational parameter. - p_rvs = list(six.iterkeys(latent_vars)) + \ - [x for x in six.iterkeys(data) if isinstance(x, RandomVariable)] - q_rvs = list(six.itervalues(latent_vars)) - reverse_latent_vars = {v: k for k, v in six.iteritems(latent_vars)} - grads = [] - grads_vars = [] - for var in var_list: - # Get all variational factors depending on the parameter. - descendants = get_descendants(tf.convert_to_tensor(var), q_rvs) - if len(descendants) == 0: - continue # skip if not a variational parameter - # Get p and q's Markov blanket wrt these latent variables. 
- var_p_rvs = set() - for qz in descendants: - z = reverse_latent_vars[qz] - var_p_rvs.update(z.get_blanket(p_rvs) + [z]) - - var_q_rvs = set() - for qz in descendants: - var_q_rvs.update(qz.get_blanket(q_rvs) + [qz]) - - pi_log_prob = [0.0] * n_samples - qi_log_prob = [0.0] * n_samples - for s in range(n_samples): - pi_log_prob[s] = tf.reduce_sum([p_log_probs[s][rv] for rv in var_p_rvs]) - qi_log_prob[s] = tf.reduce_sum([q_log_probs[s][rv] for rv in var_q_rvs]) - - pi_log_prob = tf.stack(pi_log_prob) - qi_log_prob = tf.stack(qi_log_prob) - grad = tf.gradients( - -tf.reduce_mean(qi_log_prob * - tf.stop_gradient(pi_log_prob - qi_log_prob)) + - tf.reduce_sum(tf.losses.get_regularization_losses()), - var) - grads.extend(grad) - grads_vars.append(var) - - # Take gradients of total loss function for model parameters. - loss = -(tf.reduce_mean([tf.reduce_sum(list(six.itervalues(p_log_prob))) - for p_log_prob in p_log_probs]) - - tf.reduce_mean([tf.reduce_sum(list(six.itervalues(q_log_prob))) - for q_log_prob in q_log_probs]) - - tf.reduce_sum(tf.losses.get_regularization_losses())) - model_vars = [v for v in var_list if v not in grads_vars] - model_grads = tf.gradients(loss, model_vars) - grads.extend(model_grads) - grads_vars.extend(model_vars) - grads_and_vars = list(zip(grads, grads_vars)) - return loss, grads_and_vars +def _default_constructor(latent_vars): + if isinstance(latent_vars, list): + with tf.variable_scope(None, default_name="posterior"): + latent_vars_dict = {} + continuous = \ + ('01', 'nonnegative', 'simplex', 'real', 'multivariate_real') + for z in latent_vars: + if not hasattr(z, 'support') or z.support not in continuous: + raise AttributeError( + "Random variable {} is not continuous or a random " + "variable with supported continuous support.".format(z)) + batch_event_shape = z.batch_shape.concatenate(z.event_shape) + loc = tf.Variable(tf.random_normal(batch_event_shape)) + scale = tf.nn.softplus( + tf.Variable(tf.random_normal(batch_event_shape))) 
+ latent_vars_dict[z] = Normal(loc=loc, scale=scale) + latent_vars = latent_vars_dict + return latent_vars diff --git a/edward/inferences/implicit_klqp.py b/edward/inferences/klqp_implicit.py similarity index 59% rename from edward/inferences/implicit_klqp.py rename to edward/inferences/klqp_implicit.py index 9dd5cb66a..e420cf7b8 100644 --- a/edward/inferences/implicit_klqp.py +++ b/edward/inferences/klqp_implicit.py @@ -5,15 +5,15 @@ import six import tensorflow as tf -from edward.inferences.inference import (check_and_maybe_build_data, - check_and_maybe_build_latent_vars, transform, check_and_maybe_build_dict, check_and_maybe_build_var_list) -from edward.models import RandomVariable -from edward.util import copy +from edward.models import Trace +from edward.inferences.inference import (call_function_up_to_args, + make_intercept) -def implicit_klqp(latent_vars=None, data=None, discriminator=None, - global_vars=None, ratio_loss='log', - auto_transform=True, scale=None, var_list=None, collections=None): +def klqp_implicit(model, variational, discriminator, align_latent, + align_latent_global, align_data, ratio_log='log', + scale=lambda name: 1.0, auto_transform=True, + collections=None, *args, **kwargs): """Variational inference with implicit probabilistic models [@tran2017deep]. @@ -44,14 +44,6 @@ def implicit_klqp(latent_vars=None, data=None, discriminator=None, Note the type for `discriminator`'s output changes when one passes in the `scale` argument to `initialize()`. - + If `scale` has at most one item, then `discriminator` - outputs a tensor whose multiplication with that element is - broadcastable. (For example, the output is a tensor and the single - scale factor is a scalar.) - + If `scale` has more than one item, then in order to scale - its corresponding output, `discriminator` must output a - dictionary of same size and keys as `scale`. - The objective function also adds to itself a summation over all tensors in the `REGULARIZATION_LOSSES` collection. 
""" @@ -121,106 +113,98 @@ def implicit_klqp(latent_vars=None, data=None, discriminator=None, function for q's as well, and an additional loop. we opt not to because it complicates the code; + analytic KL/swapping out the penalty term for the globals. + + align_latent aligns all global and local latents; + align_global_latent only aligns global latents. """ - if not callable(discriminator): - raise TypeError("discriminator must be a callable function.") if callable(ratio_loss): ratio_loss = ratio_loss elif ratio_loss == 'log': - ratio_loss = log_loss + ratio_loss = _log_loss elif ratio_loss == 'hinge': - ratio_loss = hinge_loss + ratio_loss = _hinge_loss else: raise ValueError('Ratio loss not found:', ratio_loss) - latent_vars = check_and_maybe_build_latent_vars(latent_vars) - data = check_and_maybe_build_data(data) - global_vars = check_and_maybe_build_latent_vars(global_vars) - latent_vars, _ = transform(latent_vars, auto_transform) - scale = check_and_maybe_build_dict(scale) - var_list = check_and_maybe_build_var_list(var_list, latent_vars, data) + + with Trace() as posterior_trace: + call_function_up_to_args(variational, *args, **kwargs) + global_intercept = make_intercept( + posterior_trace, align_data, align_latent_global, args, kwargs) + with Trace(intercept=global_intercept) as model_trace: + # Intercept model's global latent variables and set to posterior + # samples (but not its locals). + call_function_up_to_args(model, *args, **kwargs) # Collect tensors used in calculation of losses. - scope = tf.get_default_graph().unique_name("inference") - qbeta_sample = {} pbeta_log_prob = 0.0 qbeta_log_prob = 0.0 - for beta, qbeta in six.iteritems(global_vars): - # Draw a sample beta' ~ q(beta) and calculate - # log p(beta') and log q(beta'). 
- qbeta_sample[beta] = qbeta.value - pbeta_log_prob += tf.reduce_sum(beta.log_prob(qbeta_sample[beta])) - qbeta_log_prob += tf.reduce_sum(qbeta.log_prob(qbeta_sample[beta])) - + qbeta_sample = {} pz_sample = {} qz_sample = {} - for z, qz in six.iteritems(latent_vars): - if z not in global_vars: - # Copy local variables p(z), q(z) to draw samples - # z' ~ p(z | beta'), z' ~ q(z | beta'). - pz_copy = copy(z, dict_swap=qbeta_sample, scope=scope) - pz_sample[z] = pz_copy.value - qz_sample[z] = qz.value - - # Collect x' ~ p(x | z', beta') and x' ~ q(x). - dict_swap = qbeta_sample.copy() - dict_swap.update(qz_sample) x_psample = {} x_qsample = {} - for x, x_data in six.iteritems(data): - if isinstance(x, tf.Tensor): - if "Placeholder" not in x.op.type: - # Copy p(x | z, beta) to get draw p(x | z', beta'). - x_copy = copy(x, dict_swap=dict_swap, scope=scope) - x_psample[x] = x_copy - x_qsample[x] = x_data - elif isinstance(x, RandomVariable): - # Copy p(x | z, beta) to get draw p(x | z', beta'). - x_copy = copy(x, dict_swap=dict_swap, scope=scope) - x_psample[x] = x_copy.value - x_qsample[x] = x_data + for name, node in six.iteritems(model_trace): + # Calculate log p(beta') and log q(beta'). + posterior_node = posterior_trace.get(align_latent_global(name), None) + if posterior_node is not None: + pbeta = node.value + qbeta = posterior_node.value + scale_factor = scale(name) + pbeta_log_prob += tf.reduce_sum( + scale_factor * pbeta.log_prob(pbeta.value)) + qbeta_log_prob += tf.reduce_sum( + scale_factor * qbeta.log_prob(qbeta.value)) + qbeta_sample[pbeta] = qbeta.value + else: + # TODO This assumes implicit variables are tf.Tensors existing + # on the Trace stack. 
+ posterior_node = posterior_trace.get(align_latent(name), None) + if posterior_node is not None: + pz = node.value + qz = posterior_node.value + pz_sample[pz] = pz + qz_sample[qz] = qz + else: + key = align_data(name) + if isinstance(key, int): + data_node = args[key] + elif kwargs.get(key, None) is not None: + data_node = kwargs.get(key) + px = node.value + qx = data_node.value + x_psample[px] = px + x_qsample[qx] = qx + # Collect x' ~ p(x | z', beta') and x' ~ q(x). with tf.variable_scope("Disc"): + # TODO For now, this assumes the discriminator automagically knows + # how to index the dictionaries and computes some forward pass on + # them (which can vary across executions). Dictionaries should be + # improved to be more idiomatic. r_psample = discriminator(x_psample, pz_sample, qbeta_sample) with tf.variable_scope("Disc", reuse=True): r_qsample = discriminator(x_qsample, qz_sample, qbeta_sample) # Form ratio loss and ratio estimator. - if len(scale) <= 1: - loss_d = tf.reduce_mean(ratio_loss(r_psample, r_qsample)) - scale = list(six.itervalues(scale)) - scale = scale[0] if scale else 1.0 - scaled_ratio = tf.reduce_sum(scale * r_qsample) - else: - loss_d = [tf.reduce_mean(ratio_loss(r_psample[key], r_qsample[key])) - for key in six.iterkeys(scale)] - loss_d = tf.reduce_sum(loss_d) - scaled_ratio = [tf.reduce_sum(scale[key] * r_qsample[key]) - for key in six.iterkeys(scale)] - scaled_ratio = tf.reduce_sum(scaled_ratio) + loss_d = 0.0 + scaled_ratio = 0.0 + for key, value in six.iteritems(r_qsample): + loss_d += tf.reduce_mean(ratio_loss(r_psample[key], value)) + scaled_ratio += tf.reduce_sum(scale(key) * value) reg_terms_d = tf.losses.get_regularization_losses(scope="Disc") reg_terms_all = tf.losses.get_regularization_losses() reg_terms = [r for r in reg_terms_all if r not in reg_terms_d] # Form variational objective. 
- loss = -(pbeta_log_prob - qbeta_log_prob + scaled_ratio - - tf.reduce_sum(reg_terms)) + loss = (qbeta_log_prob - pbeta_log_prob - scaled_ratio + + tf.reduce_sum(reg_terms))) loss_d = loss_d + tf.reduce_sum(reg_terms_d) - - var_list_d = tf.get_collection( - tf.GraphKeys.TRAINABLE_VARIABLES, scope="Disc") - if var_list is None: - var_list = [v for v in tf.trainable_variables() if v not in var_list_d] - - grads = tf.gradients(loss, var_list) - grads_d = tf.gradients(loss_d, var_list_d) - grads_and_vars = list(zip(grads, var_list)) - grads_and_vars_d = list(zip(grads_d, var_list_d)) - return loss, grads_and_vars, loss_d, grads_and_vars_d + return loss, loss_d -def log_loss(psample, qsample): +def _log_loss(psample, qsample): """Point-wise log loss.""" loss = tf.nn.sigmoid_cross_entropy_with_logits( labels=tf.ones_like(psample), logits=psample) + \ @@ -229,7 +213,7 @@ def log_loss(psample, qsample): return loss -def hinge_loss(psample, qsample): +def _hinge_loss(psample, qsample): """Point-wise hinge loss.""" loss = tf.nn.relu(1.0 - psample) + tf.nn.relu(1.0 + qsample) return loss diff --git a/edward/inferences/laplace.py b/edward/inferences/laplace.py index 47a42b13b..8ec318003 100644 --- a/edward/inferences/laplace.py +++ b/edward/inferences/laplace.py @@ -5,10 +5,9 @@ import six import tensorflow as tf +from edward.inferences.inference import call_function_up_to_args from edward.inferences.map import map -from edward.models import PointMass, RandomVariable from edward.util import get_variables -from edward.util import copy, transform try: from edward.models import \ @@ -17,8 +16,9 @@ raise ImportError("{0}. Your TensorFlow version is not supported.".format(e)) -def laplace(latent_vars=None, data=None, - auto_transform=True, scale=None, var_list=None, collections=None): +def laplace(model, variational, align_latent, align_data, + scale=lambda name: 1.0, auto_transform=True, + collections=None, *args, **kwargs): """Laplace approximation [@laplace1986memoir]. 
It approximates the posterior distribution using a multivariate @@ -75,32 +75,6 @@ def laplace(latent_vars=None, data=None, variable must be a `MultivariateNormalDiag`, `MultivariateNormalTriL`, or `Normal` random variable. """ - if isinstance(latent_vars, list): - with tf.variable_scope(None, default_name="posterior"): - latent_vars_dict = {} - for z in latent_vars: - # Define location to have constrained support and - # unconstrained free parameters. - batch_event_shape = z.batch_shape.concatenate(z.event_shape) - loc = tf.Variable(tf.random_normal(batch_event_shape)) - if hasattr(z, 'support'): - z_transform = transform(z) - if hasattr(z_transform, 'bijector'): - loc = z_transform.bijector.inverse(loc) - scale_tril = tf.Variable(tf.random_normal( - batch_event_shape.concatenate(batch_event_shape[-1]))) - qz = MultivariateNormalTriL(loc=loc, scale_tril=scale_tril) - latent_vars_dict[z] = qz - latent_vars = latent_vars_dict - del latent_vars_dict - elif isinstance(latent_vars, dict): - for qz in six.itervalues(latent_vars): - if not isinstance( - qz, (MultivariateNormalDiag, MultivariateNormalTriL, Normal)): - raise TypeError("Posterior approximation must consist of only " - "MultivariateNormalDiag, MultivariateTriL, or " - "Normal random variables.") - # Store latent variables in a temporary object; MAP will # optimize `PointMass` random variables, which subsequently # optimizes location parameters of the normal approximations. @@ -108,26 +82,48 @@ def laplace(latent_vars=None, data=None, latent_vars = {z: PointMass(params=qz.loc) for z, qz in six.iteritems(latent_vars_normal)} - loss, grads_and_vars = map( - latent_vars, data, - auto_transform, scale, var_list, collections) - def _finalize(loss, latent_vars, latent_vars_normal): - """Function to call after convergence. - - Computes the Hessian at the mode. 
- """ - hessians = tf.hessians(loss, list(six.itervalues(latent_vars))) - finalize_ops = [] - for z, hessian in zip(six.iterkeys(latent_vars), hessians): - qz = latent_vars_normal[z] - if isinstance(qz, (MultivariateNormalDiag, Normal)): - scale_var = get_variables(qz.variance())[0] - scale = 1.0 / tf.diag_part(hessian) - else: # qz is MultivariateNormalTriL - scale_var = get_variables(qz.covariance())[0] - scale = tf.matrix_inverse(tf.cholesky(hessian)) - - finalize_ops.append(scale_var.assign(scale)) - return tf.group(*finalize_ops) + variational_pointmass = _make_variational_pointmass( + variational, *args, **kwargs) + loss = map(model, variational, align_latent, align_data, + scale, auto_transform, collections, *args, **kwargs) finalize_op = _finalize(loss, latent_vars, latent_vars_normal) - return loss, grads_and_vars, finalize_op + return loss, finalize_op + + +def _finalize(loss, variational): + """Function to call after convergence. + + Computes the Hessian at the mode. + """ + with Trace() as trace: + call_function_up_to_args(variational, *args, **kwargs) + hessians = tf.hessians( + loss, [node.value.loc for node in six.itervalues(trace)]) + finalize_ops = [] + for qz, hessian in zip(six.itervalues(trace), hessians): + if isinstance(qz, (MultivariateNormalDiag, Normal)): + scale_var = get_variables(qz.variance())[0] + scale = 1.0 / tf.diag_part(hessian) + else: # qz is MultivariateNormalTriL + scale_var = get_variables(qz.covariance())[0] + scale = tf.matrix_inverse(tf.cholesky(hessian)) + + finalize_ops.append(scale_var.assign(scale)) + return tf.group(*finalize_ops) + + +def _make_variational_pointmass(variational, *args, **kwargs): + """Take a variational program and build a new one that replaces all + random variables with point masses. + + We assume all latent variables are traceable in one execution. 
+ """ + with Trace() as trace: + call_function_up_to_args(variational, *args, **kwargs) + def variational_pointmass(*args, **kwargs): + for name, node in six.iteritems(trace): + qz = node.value + qz_pointmass = PointMass(params=qz.loc, + name=qz.name + "_pointmass", + value=qz.loc) + return variational_pointmass diff --git a/edward/inferences/map.py b/edward/inferences/map.py index 593f3efbc..49d17f286 100644 --- a/edward/inferences/map.py +++ b/edward/inferences/map.py @@ -5,10 +5,9 @@ import six import tensorflow as tf -from edward.inferences.inference import (check_and_maybe_build_data, - check_and_maybe_build_latent_vars, transform, check_and_maybe_build_dict, check_and_maybe_build_var_list) -from edward.models import RandomVariable, PointMass -from edward.util import copy, transform +from edward.models import Trace +from edward.inferences.inference import (call_function_up_to_args, + make_intercept) try: from tensorflow.contrib.distributions import bijectors @@ -16,8 +15,9 @@ raise ImportError("{0}. Your TensorFlow version is not supported.".format(e)) -def map(latent_vars=None, data=None, - auto_transform=True, scale=None, var_list=None, collections=None): +def map(model, variational, align_latent, align_data, + scale=lambda name: 1.0, auto_transform=True, collections=None, + *args, **kwargs): """Maximum a posteriori. This class implements gradient-based optimization to solve the @@ -95,62 +95,25 @@ def map(latent_vars=None, data=None, $- \log p(x,z).$ """ - if isinstance(latent_vars, list): - with tf.variable_scope(None, default_name="posterior"): - latent_vars_dict = {} - for z in latent_vars: - # Define point masses to have constrained support and - # unconstrained free parameters. 
- batch_event_shape = z.batch_shape.concatenate(z.event_shape) - params = tf.Variable(tf.random_normal(batch_event_shape)) - if hasattr(z, 'support'): - z_transform = transform(z) - if hasattr(z_transform, 'bijector'): - params = z_transform.bijector.inverse(params) - latent_vars_dict[z] = PointMass(params=params) - latent_vars = latent_vars_dict - del latent_vars_dict - elif isinstance(latent_vars, dict): - for qz in six.itervalues(latent_vars): - if not isinstance(qz, PointMass): - raise TypeError("Posterior approximation must consist of only " - "PointMass random variables.") - latent_vars = check_and_maybe_build_latent_vars(latent_vars) - data = check_and_maybe_build_data(data) - latent_vars, _ = transform(latent_vars, auto_transform) - scale = check_and_maybe_build_dict(scale) - var_list = check_and_maybe_build_var_list(var_list, latent_vars, data) - - # Form dictionary in order to replace conditioning on prior or - # observed variable with conditioning on a specific value. - scope = tf.get_default_graph().unique_name("inference") - dict_swap = {z: qz.value - for z, qz in six.iteritems(latent_vars)} - for x, qx in six.iteritems(data): - if isinstance(x, RandomVariable): - if isinstance(qx, RandomVariable): - dict_swap[x] = qx.value - else: - dict_swap[x] = qx + with Trace() as posterior_trace: + call_function_up_to_args(variational, *args, **kwargs) + intercept = make_intercept( + posterior_trace, align_data, align_latent, args, kwargs) + with Trace(intercept=intercept) as model_trace: + call_function_up_to_args(model, *args, **kwargs) p_log_prob = 0.0 - for z in six.iterkeys(latent_vars): - z_copy = copy(z, dict_swap, scope=scope) - p_log_prob += tf.reduce_sum( - scale.get(z, 1.0) * z_copy.log_prob(dict_swap[z])) - - for x in six.iterkeys(data): - if isinstance(x, RandomVariable): - if dict_swap: - x_copy = copy(x, dict_swap, scope=scope) - else: - x_copy = x - p_log_prob += tf.reduce_sum( - scale.get(x, 1.0) * x_copy.log_prob(dict_swap[x])) + for name, node 
in six.iteritems(model_trace): + rv = node.value + scale_factor = scale(name) + p_log_prob += tf.reduce_sum(scale_factor * rv.log_prob(rv.value)) reg_penalty = tf.reduce_sum(tf.losses.get_regularization_losses()) - loss = -p_log_prob + reg_penalty + if collections is not None: + tf.summary.scalar("loss/p_log_prob", p_log_prob, + collections=collections) + tf.summary.scalar("loss/reg_penalty", reg_penalty, + collections=collections) - grads = tf.gradients(loss, var_list) - grads_and_vars = list(zip(grads, var_list)) - return loss, grads_and_vars + loss = -p_log_prob + reg_penalty + return loss diff --git a/edward/inferences/metropolis_hastings.py b/edward/inferences/metropolis_hastings.py index 52b6467a9..a78fe115c 100644 --- a/edward/inferences/metropolis_hastings.py +++ b/edward/inferences/metropolis_hastings.py @@ -8,7 +8,7 @@ from collections import OrderedDict from edward.inferences.monte_carlo import MonteCarlo from edward.models import RandomVariable -from edward.util import check_and_maybe_build_latent_vars, copy +from edward.util import check_and_maybe_build_latent_vars try: from edward.models import Uniform diff --git a/edward/inferences/sghmc.py b/edward/inferences/sghmc.py index a91917a51..2f30eeb08 100644 --- a/edward/inferences/sghmc.py +++ b/edward/inferences/sghmc.py @@ -7,7 +7,6 @@ from edward.inferences.monte_carlo import MonteCarlo from edward.models import RandomVariable, Empirical -from edward.util import copy try: from edward.models import Normal diff --git a/edward/inferences/sgld.py b/edward/inferences/sgld.py index 901e3240a..bd58fcfe1 100644 --- a/edward/inferences/sgld.py +++ b/edward/inferences/sgld.py @@ -7,7 +7,6 @@ from edward.inferences.monte_carlo import MonteCarlo from edward.models import RandomVariable -from edward.util import copy try: from edward.models import Normal diff --git a/edward/inferences/wake_sleep.py b/edward/inferences/wake_sleep.py index 4e98fa897..fd392244e 100644 --- a/edward/inferences/wake_sleep.py +++ 
b/edward/inferences/wake_sleep.py @@ -5,14 +5,14 @@ import six import tensorflow as tf -from edward.inferences.inference import (check_and_maybe_build_data, - check_and_maybe_build_latent_vars, transform, check_and_maybe_build_dict, check_and_maybe_build_var_list) -from edward.models import RandomVariable -from edward.util import copy, get_descendants +from edward.models import Trace +from edward.inferences.inference import (call_function_up_to_args, + make_intercept) -def wake_sleep(latent_vars=None, data=None, n_samples=1, phase_q='sleep', - auto_transform=True, scale=None, var_list=None, collections=None): +def wake_sleep(model, variational, align_latent, align_data, + scale=lambda name: 1.0, n_samples=1, phase_q='sleep', + auto_transform=True, collections=None, *args, **kwargs): """Wake-Sleep algorithm [@hinton1995wake]. Given a probability model $p(x, z; \\theta)$ and variational @@ -68,69 +68,55 @@ def wake_sleep(latent_vars=None, data=None, n_samples=1, phase_q='sleep', (Unlike reparameterization gradients, the sample is held fixed.) """ - latent_vars = check_and_maybe_build_latent_vars(latent_vars) - data = check_and_maybe_build_data(data) - latent_vars, _ = transform(latent_vars, auto_transform) - scale = check_and_maybe_build_dict(scale) - var_list = check_and_maybe_build_var_list(var_list, latent_vars, data) - p_log_prob = [0.0] * n_samples q_log_prob = [0.0] * n_samples - base_scope = tf.get_default_graph().unique_name("inference") + '/' for s in range(n_samples): - # Form dictionary in order to replace conditioning on prior or - # observed variable with conditioning on a specific value. - scope = base_scope + tf.get_default_graph().unique_name("q_sample") - dict_swap = {} - for x, qx in six.iteritems(data): - if isinstance(x, RandomVariable): - if isinstance(qx, RandomVariable): - qx_copy = copy(qx, scope=scope) - dict_swap[x] = qx_copy.value - else: - dict_swap[x] = qx - - # Sample z ~ q(z), then compute log p(x, z). 
- q_dict_swap = dict_swap.copy() - for z, qz in six.iteritems(latent_vars): - # Copy q(z) to obtain new set of posterior samples. - qz_copy = copy(qz, scope=scope) - q_dict_swap[z] = qz_copy.value - if phase_q != 'sleep': + with Trace() as posterior_trace: + call_function_up_to_args(variational, *args, **kwargs) + intercept = make_intercept( + posterior_trace, align_data, align_latent, args, kwargs) + with Trace(intercept=intercept) as model_trace: + call_function_up_to_args(model, *args, **kwargs) + + for name, node in six.iteritems(model_trace): + rv = node.value + scale_factor = scale(name) + p_log_prob[s] += tf.reduce_sum(scale_factor * rv.log_prob(rv.value)) + posterior_node = posterior_trace.get(align_latent(name), None) + if phase_q != 'sleep' and posterior_node is not None: # If not sleep phase, compute log q(z). + qz = posterior_node.value q_log_prob[s] += tf.reduce_sum( - scale.get(z, 1.0) * - qz_copy.log_prob(tf.stop_gradient(q_dict_swap[z]))) - - for z in six.iterkeys(latent_vars): - z_copy = copy(z, q_dict_swap, scope=scope) - p_log_prob[s] += tf.reduce_sum( - scale.get(z, 1.0) * z_copy.log_prob(q_dict_swap[z])) - - for x in six.iterkeys(data): - if isinstance(x, RandomVariable): - x_copy = copy(x, q_dict_swap, scope=scope) - p_log_prob[s] += tf.reduce_sum( - scale.get(x, 1.0) * x_copy.log_prob(q_dict_swap[x])) + scale_factor * qz.log_prob(tf.stop_gradient(qz.value))) if phase_q == 'sleep': - # Sample z ~ p(z), then compute log q(z). - scope = base_scope + tf.get_default_graph().unique_name("p_sample") - p_dict_swap = dict_swap.copy() - for z, qz in six.iteritems(latent_vars): - # Copy p(z) to obtain new set of prior samples. 
- z_copy = copy(z, scope=scope) - p_dict_swap[qz] = z_copy.value - for qz in six.itervalues(latent_vars): - qz_copy = copy(qz, p_dict_swap, scope=scope) + with Trace() as model_trace: + call_function_up_to_args(model, *args, **kwargs) + intercept = _make_sleep_intercept( + model_trace, align_data, align_latent, args, kwargs) + with Trace(intercept=intercept) as posterior_trace: + call_function_up_to_args(variational, *args, **kwargs) + + # Build dictionary to return scale factor for a posterior + # variable via its corresponding prior. The implementation is + # naive. + scale_posterior = {} + for name, node in six.iteritems(model_trace): + rv = node.value + posterior_node = posterior_trace.get(align_latent(name), None) + if posterior_node is not None: + qz = posterior_node.value + scale_posterior[qz] = rv + + for name, node in six.iteritems(posterior_trace): + rv = node.value + scale_factor = scale_posterior[qz] q_log_prob[s] += tf.reduce_sum( - scale.get(z, 1.0) * - qz_copy.log_prob(tf.stop_gradient(p_dict_swap[qz]))) + scale_factor * rv.log_prob(tf.stop_gradient(rv.value))) p_log_prob = tf.reduce_mean(p_log_prob) q_log_prob = tf.reduce_mean(q_log_prob) reg_penalty = tf.reduce_sum(tf.losses.get_regularization_losses()) - if collections is not None: tf.summary.scalar("loss/p_log_prob", p_log_prob, collections=collections) @@ -141,12 +127,14 @@ def wake_sleep(latent_vars=None, data=None, n_samples=1, phase_q='sleep', loss_p = -p_log_prob + reg_penalty loss_q = -q_log_prob + reg_penalty + return loss_p, loss_q + - q_rvs = list(six.itervalues(latent_vars)) - q_vars = [v for v in var_list - if len(get_descendants(tf.convert_to_tensor(v), q_rvs)) != 0] - q_grads = tf.gradients(loss_q, q_vars) - p_vars = [v for v in var_list if v not in q_vars] - p_grads = tf.gradients(loss_p, p_vars) - grads_and_vars = list(zip(q_grads, q_vars)) + list(zip(p_grads, p_vars)) - return loss_p, grads_and_vars +def _make_sleep_intercept(trace, align_data, align_latent, args, kwargs): + def 
_intercept(f, *fargs, **fkwargs): + """Set variational distribution's sample value to prior's.""" + name = fkwargs.get('name', None) + z = trace[align_latent(name)].value + fkwargs['value'] = z.value + return f(*fargs, **fkwargs) + return _intercept diff --git a/edward/inferences/wgan_inference.py b/edward/inferences/wgan_inference.py index 1b95f6a1e..9e015f795 100644 --- a/edward/inferences/wgan_inference.py +++ b/edward/inferences/wgan_inference.py @@ -5,8 +5,8 @@ import six import tensorflow as tf -from edward.inferences.inference import (check_and_maybe_build_data, - transform, check_and_maybe_build_dict, check_and_maybe_build_var_list) +from edward.models import Trace +from edward.inferences.inference import call_function_up_to_args try: from edward.models import Uniform @@ -14,9 +14,8 @@ raise ImportError("{0}. Your TensorFlow version is not supported.".format(e)) -def wgan_inference(data=None, discriminator=None, - penalty=10.0, - scale=None, var_list=None, collections=None): +def wgan_inference(model, discriminator, align_data, + penalty=10.0, collections=None, *args, **kwargs): """Parameter estimation with GAN-style training [@goodfellow2014generative], using the Wasserstein distance [@arjovsky2017wasserstein]. @@ -67,13 +66,11 @@ def wgan_inference(data=None, discriminator=None, None (or 0.0) if using no penalty. clip: float, optional. Value to clip weights by. Default is no clipping. - """ - data = check_and_maybe_build_data(data) - scale = check_and_maybe_build_dict(scale) - var_list = check_and_maybe_build_var_list(var_list, {}, data) - x_true = list(six.itervalues(data))[0] - x_fake = list(six.iterkeys(data))[0] + `model` must return the generated data. 
+ """ + x_fake = call_function_up_to_args(model, *args, **kwargs) + x_true = align_data(x_fake.name) with tf.variable_scope("Disc"): d_true = discriminator(x_true) @@ -103,14 +100,4 @@ def wgan_inference(data=None, discriminator=None, mean_fake = tf.reduce_mean(d_fake) loss_d = mean_fake - mean_true + penalty + tf.reduce_sum(reg_terms_d) loss = -mean_fake + tf.reduce_sum(reg_terms) - - var_list_d = tf.get_collection( - tf.GraphKeys.TRAINABLE_VARIABLES, scope="Disc") - if var_list is None: - var_list = [v for v in tf.trainable_variables() if v not in var_list_d] - - grads_d = tf.gradients(loss_d, var_list_d) - grads = tf.gradients(loss, var_list) - grads_and_vars_d = list(zip(grads_d, var_list_d)) - grads_and_vars = list(zip(grads, var_list)) - return loss, grads_and_vars, loss_d, grads_and_vars_d + return loss, loss_d diff --git a/edward/util/__init__.py b/edward/util/__init__.py index 7b0dccae9..f40069a29 100644 --- a/edward/util/__init__.py +++ b/edward/util/__init__.py @@ -14,7 +14,6 @@ _allowed_symbols = [ 'compute_multinomial_mode', - 'copy', 'dot', 'get_ancestors', 'get_blanket', diff --git a/edward/util/random_variables.py b/edward/util/random_variables.py index 186ff059c..85ef7669c 100644 --- a/edward/util/random_variables.py +++ b/edward/util/random_variables.py @@ -6,319 +6,11 @@ import six import tensorflow as tf -from copy import deepcopy from edward.models.random_variable import RandomVariable from edward.models.random_variables import TransformedDistribution from edward.models import PointMass from edward.util.graphs import random_variables from tensorflow.contrib.distributions import bijectors -from tensorflow.core.framework import attr_value_pb2 -from tensorflow.python.framework.ops import set_shapes_for_outputs -from tensorflow.python.util import compat - - -def _copy_default(x, *args, **kwargs): - if isinstance(x, (RandomVariable, tf.Operation, tf.Tensor, tf.Variable)): - x = copy(x, *args, **kwargs) - - return x - - -def copy(org_instance, 
dict_swap=None, scope="copied", - replace_itself=False, copy_q=False, copy_parent_rvs=True): - """Build a new node in the TensorFlow graph from `org_instance`, - where any of its ancestors existing in `dict_swap` are - replaced with `dict_swap`'s corresponding value. - - Copying is done recursively. Any `Operation` whose output is - required to copy `org_instance` is also copied (if it isn't already - copied within the new scope). - - `tf.Variable`s, `tf.placeholder`s, and nodes of type `Queue` are - always reused and not copied. In addition, `tf.Operation`s with - operation-level seeds are copied with a new operation-level seed. - - Args: - org_instance: RandomVariable, tf.Operation, tf.Tensor, or tf.Variable. - Node to add in graph with replaced ancestors. - dict_swap: dict, optional. - Random variables, variables, tensors, or operations to swap with. - Its keys are what `org_instance` may depend on, and its values are - the corresponding object (not necessarily of the same class - instance, but must have the same type, e.g., float32) that is used - in exchange. - scope: str, optional. - A scope for the new node(s). This is used to avoid name - conflicts with the original node(s). - replace_itself: bool, optional - Whether to replace `org_instance` itself if it exists in - `dict_swap`. (This is used for the recursion.) - copy_q: bool, optional. - Whether to copy the replaced tensors too (if not already - copied within the new scope). Otherwise will reuse them. - copy_parent_rvs: - Whether to copy parent random variables `org_instance` depends - on. Otherwise will copy only the sample tensors and not the - random variable class itself. - - Returns: - RandomVariable, tf.Variable, tf.Tensor, or tf.Operation. - The copied node. - - Raises: - TypeError. - If `org_instance` is not one of the above types. 
- - #### Examples - - ```python - x = tf.constant(2.0) - y = tf.constant(3.0) - z = x * y - - qx = tf.constant(4.0) - # The TensorFlow graph is currently - # `x` -> `z` <- y`, `qx` - - # This adds a subgraph with newly copied nodes, - # `qx` -> `copied/z` <- `copied/y` - z_new = ed.copy(z, {x: qx}) - - sess = tf.Session() - sess.run(z) - 6.0 - sess.run(z_new) - 12.0 - ``` - """ - if not isinstance(org_instance, - (RandomVariable, tf.Operation, tf.Tensor, tf.Variable)): - raise TypeError("Could not copy instance: " + str(org_instance)) - - if dict_swap is None: - dict_swap = {} - if scope[-1] != '/': - scope += '/' - - # Swap instance if in dictionary. - if org_instance in dict_swap and replace_itself: - org_instance = dict_swap[org_instance] - if not copy_q: - return org_instance - elif isinstance(org_instance, tf.Tensor) and replace_itself: - # Deal with case when `org_instance` is the associated tensor - # from the RandomVariable, e.g., `z.value`. If - # `dict_swap={z: qz}`, we aim to swap it with `qz.value`. - for key, value in six.iteritems(dict_swap): - if isinstance(key, RandomVariable): - if org_instance == key.value: - if isinstance(value, RandomVariable): - org_instance = value.value - else: - org_instance = value - if not copy_q: - return org_instance - break - - # If instance is a tf.Variable, return it; do not copy any. Note we - # check variables via their name. If we get variables through an - # op's inputs, it has type tf.Tensor and not tf.Variable. - if isinstance(org_instance, (tf.Tensor, tf.Variable)): - for variable in tf.global_variables(): - if org_instance.name == variable.name: - if variable in dict_swap and replace_itself: - # Deal with case when `org_instance` is the associated _ref - # tensor for a tf.Variable. 
- org_instance = dict_swap[variable] - if not copy_q or isinstance(org_instance, tf.Variable): - return org_instance - for variable in tf.global_variables(): - if org_instance.name == variable.name: - return variable - break - else: - return variable - - graph = tf.get_default_graph() - new_name = scope + org_instance.name - - # If an instance of the same name exists, return it. - if isinstance(org_instance, RandomVariable): - for rv in random_variables(): - if new_name == rv.name: - return rv - elif isinstance(org_instance, (tf.Tensor, tf.Operation)): - try: - return graph.as_graph_element(new_name, - allow_tensor=True, - allow_operation=True) - except: - pass - - # Preserve ordering of random variables. Random variables are always - # copied first (from parent -> child) before any deterministic - # operations that depend on them. - if copy_parent_rvs and \ - isinstance(org_instance, (RandomVariable, tf.Tensor, tf.Variable)): - for v in get_parents(org_instance): - copy(v, dict_swap, scope, True, copy_q, True) - - if isinstance(org_instance, RandomVariable): - rv = org_instance - - # If it has copiable arguments, copy them. - args = [_copy_default(arg, dict_swap, scope, True, copy_q, False) - for arg in rv._args] - - kwargs = {} - for key, value in six.iteritems(rv._kwargs): - if isinstance(value, list): - kwargs[key] = [_copy_default(v, dict_swap, scope, True, copy_q, False) - for v in value] - else: - kwargs[key] = _copy_default( - value, dict_swap, scope, True, copy_q, False) - - kwargs['name'] = new_name - # Create new random variable with copied arguments. - try: - new_rv = type(rv)(*args, **kwargs) - except ValueError: - # Handle case where parameters are copied under absolute name - # scope. This can cause an error when creating a new random - # variable as tf.identity name ops are called on parameters ("op - # with name already exists"). To avoid remove absolute name scope. 
- kwargs['name'] = new_name[:-1] - new_rv = type(rv)(*args, **kwargs) - return new_rv - elif isinstance(org_instance, tf.Tensor): - tensor = org_instance - - # Do not copy tf.placeholders. - if 'Placeholder' in tensor.op.type: - return tensor - - # A tensor is one of the outputs of its underlying - # op. Therefore copy the op itself. - op = tensor.op - new_op = copy(op, dict_swap, scope, True, copy_q, False) - - output_index = op.outputs.index(tensor) - new_tensor = new_op.outputs[output_index] - - # Add copied tensor to collections that the original one is in. - for name, collection in six.iteritems(tensor.graph._collections): - if tensor in collection: - graph.add_to_collection(name, new_tensor) - - return new_tensor - elif isinstance(org_instance, tf.Operation): - op = org_instance - - # Do not copy queue operations. - if 'Queue' in op.type: - return op - - # Copy the node def. - # It is unique to every Operation instance. Replace the name and - # its operation-level seed if it has one. - node_def = deepcopy(op.node_def) - node_def.name = new_name - if 'seed2' in node_def.attr and tf.get_seed(None)[1] is not None: - node_def.attr['seed2'].i = tf.get_seed(None)[1] - - # Copy other arguments needed for initialization. - output_types = op._output_types[:] - - # If it has an original op, copy it. - if op._original_op is not None: - original_op = copy(op._original_op, dict_swap, scope, True, copy_q, False) - else: - original_op = None - - # Copy the op def. - # It is unique to every Operation type. - op_def = deepcopy(op.op_def) - - new_op = tf.Operation(node_def, - graph, - [], # inputs; will add them afterwards - output_types, - [], # control inputs; will add them afterwards - [], # input types; will add them afterwards - original_op, - op_def) - - # advertise op early to break recursions - graph._add_op(new_op) - - # If it has control inputs, copy them. 
- control_inputs = [] - for x in op.control_inputs: - elem = copy(x, dict_swap, scope, True, copy_q, False) - if not isinstance(elem, tf.Operation): - elem = tf.convert_to_tensor(elem) - - control_inputs.append(elem) - - new_op._add_control_inputs(control_inputs) - - # If it has inputs, copy them. - for x in op.inputs: - elem = copy(x, dict_swap, scope, True, copy_q, False) - if not isinstance(elem, tf.Operation): - elem = tf.convert_to_tensor(elem) - - new_op._add_input(elem) - - # Use Graph's private methods to add the op, following - # implementation of `tf.Graph().create_op()`. - compute_shapes = True - compute_device = True - op_type = new_name - - if compute_shapes: - set_shapes_for_outputs(new_op) - graph._record_op_seen_by_control_dependencies(new_op) - - if compute_device: - graph._apply_device_functions(new_op) - - if graph._colocation_stack: - all_colocation_groups = [] - for colocation_op in graph._colocation_stack: - all_colocation_groups.extend(colocation_op.colocation_groups()) - if colocation_op.device: - # Make this device match the device of the colocated op, to - # provide consistency between the device and the colocation - # property. - if new_op.device and new_op.device != colocation_op.device: - logging.warning("Tried to colocate %s with an op %s that had " - "a different device: %s vs %s. 
" - "Ignoring colocation property.", - name, colocation_op.name, new_op.device, - colocation_op.device) - - all_colocation_groups = sorted(set(all_colocation_groups)) - new_op.node_def.attr["_class"].CopyFrom(attr_value_pb2.AttrValue( - list=attr_value_pb2.AttrValue.ListValue(s=all_colocation_groups))) - - # Sets "container" attribute if - # (1) graph._container is not None - # (2) "is_stateful" is set in OpDef - # (3) "container" attribute is in OpDef - # (4) "container" attribute is None - if (graph._container and - op_type in graph._registered_ops and - graph._registered_ops[op_type].is_stateful and - "container" in new_op.node_def.attr and - not new_op.node_def.attr["container"].s): - new_op.node_def.attr["container"].CopyFrom( - attr_value_pb2.AttrValue(s=compat.as_bytes(graph._container))) - - return new_op - else: - raise TypeError("Could not copy instance: " + str(org_instance)) def get_ancestors(x, collection=None): diff --git a/examples/beta_bernoulli.py b/examples/beta_bernoulli.py index 20fc979bf..48da6e64d 100644 --- a/examples/beta_bernoulli.py +++ b/examples/beta_bernoulli.py @@ -12,58 +12,6 @@ from edward.models import Bernoulli, Beta, Empirical - -def klqp(model, variational, align_latent, align_data, *args): - """Loss function equal to KL(q||p) up to a constant. - - Args: - model: function whose inputs are a subset of `args` (e.g., for - discriminative). Output is not used. - variational: function whose inputs are a subset of `args` (e.g., - for amortized). Output is not used. - align_latent: function of string, aligning `model` latent - variables with `variational`. It takes a model variable's name - as input and returns a string, indexing `variational`'s trace; - else identity. - align_data: function of string, aligning `model` observed - variables with data. It takes a model variable's name as input - and returns an integer, indexing `args`; else identity. - args: data inputs. 
It is passed at compile-time in Graph - mode or runtime in Eager mode. - """ - def _intercept(f, *args, **kwargs): - """Set model's sample values to variational distribution's and data.""" - name = kwargs.get('name', None) - if isinstance(align_data(name), int): - kwargs['value'] = arg[align_data(name)] - else: - kwargs['value'] = posterior_trace[align_latent(name)].value - return f(*args, **kwargs) - with Trace() as posterior_trace: - call_function_up_to_args(variational, args) - with Trace(intercept=_intercept) as model_trace: - call_function_up_to_args(model, args) - - log_p = tf.reduce_sum([tf.reduce_sum(x.log_prob(x.value)) - for x in model_trace.values() - if isinstance(x, tfd.Distribution)]) - log_q = tf.reduce_sum([tf.reduce_sum(qz.log_prob(qz.value)) - for qz in posterior_trace.values() - if isinstance(qz, tfd.Distribution)]) - loss = log_q - log_p - return loss - - -def call_function_up_to_args(f, args): - import inspect - if hasattr(f, "_func"): # make_template() - f = f._func - num_args = len(inspect.getargspec(f).args) - if num_args > 0: - return f(args[:num_args]) - return f() - - ed.set_seed(42) # DATA @@ -82,10 +30,6 @@ def proposal(x=None): proposal_p = Beta(3.0, 9.0, name="proposal_p") return proposal_p -# TODO update? maybe just transition of the posterior chain -# `update` is a function of the realized data (tfe.Tensor) and returns -# a train operation. 
-# ed.automate(train_op, x_data) update = ed.metropolis_hastings( model, posterior, @@ -99,7 +43,7 @@ def proposal(x=None): # posterior, # data) -sess = tf.Session() # ed.get_session() not needed -tf.global_variables_initializer().run() # presumably not needed +sess = tf.Session() +sess.run(tf.global_variables_initializer()) for _ in range(1000): info_dict = update(x_data) diff --git a/examples/normal_normal.py b/examples/normal_normal.py index abb148de6..c7f068670 100644 --- a/examples/normal_normal.py +++ b/examples/normal_normal.py @@ -5,52 +5,54 @@ from __future__ import print_function import edward as ed -import matplotlib.pyplot as plt import numpy as np import tensorflow as tf -from edward.models import Empirical, Normal -from edward.util import get_session, Progbar +from edward.models import Normal +from edward.util import Progbar -tf.set_random_seed(42) -# DATA -x_data = np.array([0.0] * 50) +def model(): + """Normal-Normal with known variance.""" + mu = Normal(loc=0.0, scale=1.0, name="mu") + x = Normal(loc=mu, scale=1.0, sample_shape=50, name="x") + return x -# MODEL: Normal-Normal with known variance -mu = Normal(loc=0.0, scale=1.0) -x = Normal(loc=tf.ones(50) * mu, scale=1.0) -# INFERENCE -qmu = Normal(loc=tf.Variable(0.0), scale=tf.nn.softplus(tf.Variable(1.0))+1e-3) +def variational(): + qmu = Normal(loc=tf.get_variable("loc", []), + scale=tf.nn.softplus(tf.get_variable("shape", [])), + name="qmu") + return qmu -# analytic solution: N(loc=0.0, scale=\sqrt{1/51}=0.140) -loss, grads_and_vars = ed.klqp({mu: qmu}, data={x: x_data}) -train_op = tf.train.AdamOptimizer().apply_gradients(grads_and_vars) -progbar = Progbar(1000) -sess = get_session() +variational = tf.make_template("variational", variational) + +tf.set_random_seed(42) +x_data = np.array([0.0] * 50) + +# analytic solution: N(loc=0.0, scale=\sqrt{1/51}=0.140) +loss, surrogate_loss = ed.klqp( + model, + variational, + align_latent=lambda name: 'qmu' if name == 'mu' else name, + align_data=lambda 
name: 'x_data' if name == 'x' else name, + x_data=x_data) + +var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) +grads = tf.gradients(surrogate_loss, var_list) +grads_and_vars = list(zip(grads, var_list)) +train_op = tf.train.AdamOptimizer(1e-2).apply_gradients(grads_and_vars) + +qmu = variational() +sess = tf.Session() + +progbar = Progbar(5000) sess.run(tf.global_variables_initializer()) -for t in range(1, 1001): +for t in range(1, 5001): loss_val, _ = sess.run([loss, train_op]) if t % 50 == 0: - progbar.update(t, {"Loss": loss_val}) - -# # CRITICISM -sess = get_session() -mean, stddev = sess.run([qmu.mean(), qmu.stddev()]) -print("Inferred posterior mean:") -print(mean) -print("Inferred posterior stddev:") -print(stddev) - -# Check convergence with visual diagnostics. -# samples = sess.run(qmu.params) - -# # Plot histogram. -# plt.hist(samples, bins='auto') -# plt.show() - -# # Trace plot. -# plt.plot(samples) -# plt.show() + mean, stddev = sess.run([qmu.mean(), qmu.stddev()]) + progbar.update(t, {"Loss": loss_val, + "Posterior mean": mean, + "Posterior stddev": stddev})