
[WIP] Add marginal likelihood estimation via bridge sampling #2040

Open
wants to merge 2 commits into base: main
Conversation

karink520
Contributor

@karink520 karink520 commented Jun 2, 2022

Co-authored-by: @junpenglao

Description

Provides an estimate of the (log) marginal likelihood via bridge sampling (as described in Gronau, Quentin F., et al., 2017), building on an implementation from @junpenglao. This could be expanded to add Bayes factor functionality if desired.
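For context, here is a minimal, self-contained sketch of the iterative bridge sampling estimator (Meng & Wong 1996; Gronau et al. 2017) in one dimension. The function name, the normal-proposal choice, and the fixed-point loop are illustrative only, not this PR's API:

```python
import numpy as np

def bridge_sampling_logZ(log_q, posterior_samples, rng, n_iter=100, tol=1e-10):
    """Estimate the log normalizing constant of an unnormalized density log_q,
    given draws from the corresponding normalized (posterior) distribution."""
    n1 = len(posterior_samples)
    # Fit a normal proposal g to the posterior samples and draw from it.
    mu, sigma = posterior_samples.mean(), posterior_samples.std(ddof=1)
    proposal_samples = rng.normal(mu, sigma, size=n1)
    n2 = len(proposal_samples)
    s1, s2 = n1 / (n1 + n2), n2 / (n1 + n2)

    def log_g(x):
        return -0.5 * np.log(2 * np.pi * sigma**2) - (x - mu) ** 2 / (2 * sigma**2)

    # l = q(theta) / g(theta), evaluated at both sample sets (log scale).
    log_l1 = log_q(posterior_samples) - log_g(posterior_samples)
    log_l2 = log_q(proposal_samples) - log_g(proposal_samples)

    # Fixed-point iteration for r, which converges to the normalizing constant.
    log_r = 0.0
    for _ in range(n_iter):
        num = np.mean(np.exp(log_l2) / (s1 * np.exp(log_l2) + s2 * np.exp(log_r)))
        den = np.mean(1.0 / (s1 * np.exp(log_l1) + s2 * np.exp(log_r)))
        log_r_new = np.log(num) - np.log(den)
        if abs(log_r_new - log_r) < tol:
            return log_r_new
        log_r = log_r_new
    return log_r
```

As a sanity check, running this with `log_q = lambda x: -0.5 * x**2` and standard normal posterior draws should recover `0.5 * log(2 * pi)`, the log of the Gaussian normalizing constant. (A production implementation would work entirely in log space with logsumexp to avoid overflow; this sketch skips that for clarity.)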

The bridge sampler uses samples from the posterior, so the log_marginal_likelihood_bridgesampling function takes as a parameter an InferenceData object that has a posterior group, as well as an unnormalized log probability function (e.g. model.logp_array from a PyMC model).

Because we fit a multivariate normal proposal distribution to the posterior samples, it is helpful to have samples that are transformed, e.g. to have support on the real line instead of on a bounded interval. Although these transformed samples are created as part of the NUTS sampling, I believe they're not currently included in InferenceData (see issue #230). So, log_marginal_likelihood_bridgesampling currently takes a dict whose keys are variable names and whose values are the associated transformations (or the identity). You could get this from a PyMC model with something like the following, although maybe there's a better way:

```python
def get_transformation_dict_from_model(model):
    """
    Return a dict giving the transformation for each variable.

    Parameters
    ----------
    model : a PyMC model

    Returns
    -------
    transformation_dict : dict
        Keys are (str) names of model variables (their pre-transformation
        names); values are the associated transformation as a function that
        (elementwise) transforms an array. If a variable has no associated
        transformation, the identity function is used.
    """
    transformation_dict = {}
    for var_name in model.named_vars:
        if not var_name.endswith("__"):
            var = getattr(model, var_name)
            transformation = getattr(var, "transformation", None)
            if transformation is not None:
                transformation_dict[var_name] = transformation.forward_val
            else:  # no transformation, use identity
                transformation_dict[var_name] = lambda x: x
    return transformation_dict
```
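To illustrate how such a dict would be used: each variable's draws are mapped to the real line and then stacked into a samples-by-variables array, which is what the multivariate normal proposal is fit to. A minimal sketch with a hypothetical dict and fake draws (no PyMC dependency; variable names are made up):

```python
import numpy as np

# Hypothetical output of get_transformation_dict_from_model: 'mu' is
# unbounded, 'sigma' is positive, so its forward transform is log.
transformation_dict = {
    "mu": lambda x: x,   # unbounded: identity
    "sigma": np.log,     # positive support: map to the real line
}

# Fake posterior draws (chains and draws already flattened), keyed by name.
posterior = {
    "mu": np.array([0.1, -0.2, 0.05]),
    "sigma": np.array([1.2, 0.9, 1.1]),
}

# Transform each variable, then stack into an (n_samples, n_vars) array,
# ready for fitting the multivariate normal proposal.
transformed = np.column_stack(
    [transformation_dict[name](draws) for name, draws in posterior.items()]
)
```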

Curious to hear any thoughts or feedback! I'm happy to write tests for this as well, but wanted to wait to get initial feedback before doing so.

Checklist

  • Follows official PR format
  • New features are properly documented (with an example if appropriate)
  • Includes new or updated tests to cover the new feature
  • Code style correct (follows pylint and black guidelines)
  • Changes are listed in changelog

@ahartikainen
Contributor

Should we use a class approach like we do with reloo? This way different backends only need to create specific methods with uniform parameters and output.

cc @OriolAbril

@OriolAbril
Member

I think we should define how to include unconstrained variables in InferenceData and solve #230 with either option 1 or option 2. Going with option 3 is turning out to be much more work and more inconvenient.

I have skimmed the code and have many ideas, mostly related to using xarray more. But I am not sure it is worth starting to change things until we have decided on that issue.

@karink520
Contributor Author

Just checking in to see if there's anything I can be helpful with here!

@OriolAbril
Member

Very sorry about the other PR taking so long, but it has finally been merged. I can take care of rebasing if it helps; then I'll try to add some high-level comments.

3 participants