
Use xarray throughout #97

Closed · ColCarroll opened this issue May 23, 2018 · 42 comments

@ColCarroll (Member)

There have been proposals to use xarray as a common language for pymc3, pystan, and pymc4. This library might be a good place to start that by implementing utility functions for translating pymc3's Multitrace and pystan's OrderedDict into xarray objects, and then having all plotting functions work with xarrays.
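The kind of converter proposed here can be sketched in a few lines. This is a hedged sketch, assuming draws arrive as a mapping from parameter name to array whose first axis flattens chains and samples chain-major (the actual PyStan and PyMC3 layouts differ in detail); `dict_to_xarray` is a hypothetical name, not an existing API:

```python
from collections import OrderedDict

import numpy as np
import xarray as xr


def dict_to_xarray(draws, chains):
    """Hypothetical helper: {name: array} of posterior draws -> xr.Dataset.

    Assumes each value has shape (chains * samples, *param_shape),
    flattened chain-major.
    """
    data_vars = {}
    for name, values in draws.items():
        values = np.asarray(values)
        samples = values.shape[0] // chains
        # Reshape the flat draw axis into (chain, sample).
        reshaped = values.reshape((chains, samples) + values.shape[1:])
        # Auto-name any trailing parameter axes (placeholder names).
        extra_dims = ['%s_dim_%d' % (name, i) for i in range(values.ndim - 1)]
        data_vars[name] = (['chain', 'sample'] + extra_dims, reshaped)
    return xr.Dataset(data_vars)


posterior = OrderedDict([('mu', np.random.randn(2000)),
                         ('theta', np.random.randn(2000, 8))])
ds = dict_to_xarray(posterior, chains=4)
```

With 4 chains, `ds` gets dimensions `chain: 4`, `sample: 500`, and an auto-named `theta_dim_0: 8`.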

@ahartikainen (Contributor) commented May 23, 2018

The PyStan fit object is a StanFit4Model Cython class. The ordered dict comes from fit.extract, but there are other ways to interact with it.


@twiecki (Contributor) commented May 23, 2018

@ahartikainen The idea would be for PyStan to write a converter from StanFit4Model (or the ordered dict) to an xarray that can then be passed to arviz.

@junpenglao (Contributor)

Should the xarray object also include diagnostics such as rhat and effective sample size? The PyStan object has a Cython function that computes rhat and effective sample size quite efficiently. When I was last updating the effective sample size implementation in pymc3, I found that the Python implementation is much slower for large input arrays. @aseyboldt has a numba implementation which is much faster, but PyMC3 does not have numba as a dependency.

My proposal is that we make a faster implementation of rhat and effective sample size in arviz (the numba version), plus a general converter that turns a pymc3 trace or a StanFit4Model into xarray (these could be two independent converters). It would have the trace at the first level, and diagnostics and stats at the second level. The second level is optional but will be computed by default if not provided.

Thoughts?
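A minimal sketch of the rhat computation under discussion, in plain NumPy: the body is the classic (non-split) Gelman-Rubin statistic, and this loop-free kernel is exactly the kind of function numba.njit could compile if numba became a dependency. This is not arviz's actual implementation:

```python
import numpy as np


def rhat(samples):
    """Gelman-Rubin potential scale reduction factor (non-split sketch).

    samples: array of shape (chain, draw).
    """
    m, n = samples.shape
    chain_means = samples.mean(axis=1)
    # Between-chain variance, scaled by draws per chain.
    between = n * chain_means.var(ddof=1)
    # Mean within-chain variance.
    within = samples.var(axis=1, ddof=1).mean()
    # Pooled posterior variance estimate.
    var_hat = (n - 1) / n * within + between / n
    return np.sqrt(var_hat / within)


rng = np.random.RandomState(0)
good = rng.randn(4, 1000)                               # four well-mixed chains
bad = good + np.array([0.0, 0.0, 0.0, 10.0])[:, None]   # one shifted chain
```

For well-mixed chains `rhat(good)` sits near 1, while the shifted-chain case blows up well above it.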

@twiecki (Contributor) commented May 23, 2018

Good idea. I think it makes sense to allow inclusion of these stats in the xarray and, if they are not present, to compute them using a fast (numba) implementation in arviz.

@ahartikainen (Contributor)

@twiecki yes, that is a good idea. I'm currently writing a function to transform the fit object into a format suitable for the current plotting code.

But if ArviZ will implement code to read a specific type of xarray object, we can wait for that before we release the arviz wrapper code.

@junpenglao It would be ideal to just give a fit object to arviz and let the arviz do the magic inside it.

Do you think numba is going to be a required dependency or optional?

@aseyboldt (Contributor)

I'm not sure computing rhat and other stats in arviz is a good idea. Some of that depends on the sampler (e.g. treedepth), and if we want to print warnings or a report of some kind after sampling, then we need to compute them there anyway. I think it would be good if those stats could be in the xarray as well, but there is the problem of name collisions. If we only have rhat and n_effective_samples or so in the array, that would be fine, but if we also want to put all sampler-specific stats in it, then that list would suddenly be much longer and would also change over time.
An alternative would be to have two xarrays, one for the samples and one for the stats.
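The two-xarray split is easy to sketch: sampler stats get their own Dataset, so a model variable that happens to be named, say, `treedepth` can never collide. Shapes and names below are illustrative only:

```python
import numpy as np
import xarray as xr

# Hypothetical run: 4 chains, 500 draws.
chains, draws = 4, 500

# Free variables of the model live in one Dataset...
posterior = xr.Dataset({
    'mu': (('chain', 'sample'), np.random.randn(chains, draws)),
})

# ...while sampler-specific stats live in a second one, sharing
# the (chain, sample) dimensions but a separate namespace.
sample_stats = xr.Dataset({
    'treedepth': (('chain', 'sample'),
                  np.random.randint(1, 10, (chains, draws))),
    'diverging': (('chain', 'sample'),
                  np.zeros((chains, draws), dtype=bool)),
})
```

Both Datasets index the same draws, so selections like `posterior.isel(sample=100)` and `sample_stats.isel(sample=100)` stay aligned.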

@ahartikainen (Contributor)

Yes, summary stats and samples basically live in different spaces.

Either two xarrays or some tricks with the indexing could work.

@twiecki (Contributor) commented May 23, 2018

I don't see a problem making numba a dependency; it has matured a lot and is well packaged. Eventually there is a good chance scipy will use it too.

@junpenglao (Contributor) commented May 23, 2018

Personally, I prefer to make it a dependency; the alternative implementation is too slow.
I am hoping the numba installation is no longer an issue. Last time I checked, there were some complications if you are not under conda.

I also came across this discussion on the mc-stan discourse: Proposal for consolidated output and the related Stan wiki by @martinmodrak and @sakrejda. I think this is a great discussion to have: specifically, whether there could be a universal representation of different inference methods, with their related diagnostics.

My idea would be:

  • level 0, meta information. This determines the lower structure
    Inference used (sampling, approximation, estimator)
    Available diagnostics
    Parameterization of the approximation (VI or Laplace approximation, etc.)

  • level 1, summary
    This would be a point from the parameter space of the posterior / likelihood function, with a related error estimate for an estimator and a variance/covariance matrix for MCMC samples. This includes:
    MLE or MAP and their associated error
    VI parameters
    mean and cov of MCMC samples
    The complication is that, if you are doing some kind of VI approximation that is not parameterized by only a mean and std/cov, that information needs to be saved separately.

  • level 2, samples
    MCMC samples. For VI we can sample from the approximation model (we have that functionality quite handy in PyMC3). If an estimator is used, then it is just 1 sample.

  • level 3, diagnostics and statistics
    including divergences, tree depth, etc. for HMC and NUTS
    effective sample size, rhat for other MCMC samplers
    ELBO history for VI
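The four levels above could be prototyped as plain Python structures before any schema is settled. Everything below, key names included, is illustrative rather than a concrete proposal:

```python
# A hedged sketch of the proposed four-level layout. `None` marks slots
# that would hold real objects (e.g. xarray Datasets) in practice.
result = {
    # level 0: meta information that determines the lower structure
    'meta': {
        'inference': 'sampling',          # or 'approximation', 'estimator'
        'diagnostics': ['rhat', 'ess'],   # what level 3 will contain
    },
    # level 1: point estimates with error / (co)variance information
    'summary': {'mean': None, 'cov': None},
    # level 2: MCMC samples (or draws from a VI approximation)
    'samples': None,
    # level 3: diagnostics and statistics
    'diagnostics': {'rhat': None, 'ess': None},
}
```

The `meta` entry is what lets a consumer decide which of the optional lower levels to expect.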

@junpenglao (Contributor)

Maybe we should start a google doc and also invite others working on PPL to edit? For example the folks from tensorflow/probability, Pyro etc.

@twiecki (Contributor) commented May 23, 2018

Good idea. CC @fritzo @dustinvtran ArviZ is a package that separates out PyMC3's plotting and some analysis functionality to create many commonly used plots like a traceplot. With PyStan we're currently discussing potential standardized storage objects (xarray). Is there any interest from Edward/Edward2/Pyro to collaborate on this?

@ahartikainen (Contributor)

cc for active (Py)Stan folks

@ariddell @seantalts @braaannigan

@ariddell

I have no objections to using xarrays. Sounds like a reasonable idea.

@braaannigan

No objections. It would probably be handy to agree on a naming convention for parameters before implementing in the various code bases.

@fritzo commented May 23, 2018

@twiecki Yes the Pyro team is interested in collaborating on standardized storage formats that can facilitate comparison and encourage inference algorithm research. cc @neerajprad @jpchen @rohitsingh who are working on HMC and Pyro-Stan compatibility.

@eb8680 commented May 23, 2018

cc @yebai @xukai92 you might be interested in this for Turing

@sakrejda

Re: @junpenglao, we've explicitly punted on this question for the moment while we re-organize the intermediate layer to make it possible. I agree that it's important; it has been a topic of ongoing discussion, so if there is a broader discussion please let us know. Since Stan is multi-interface, it might be more complicated at the file-format level (we've been talking about something streaming-friendly in ProtoBuf), but at the level of deciding how outputs should be grouped and organized, it would be fantastic to have compatibility with other projects.

@shoyer commented May 23, 2018

As an xarray developer and probabilistic programming enthusiast, I'd really love to see this happen. Please feel free to ping me if you come across any issues.

@dustinvtran commented May 23, 2018

@matthewdhoffman, @davmre, @jvdillon, @csuter, @derifatives, @axch, @srvasude

Edward2 and TFP's abstraction level doesn't really require named data structures beyond dicts and namedtuples, which are more for collecting heterogeneous data. For PyMC*, Stan, and others, xarray instead of custom classes or pd.DataFrame sounds like a great idea.

@springcoil

Wow awesome proposal guys!

@aseyboldt (Contributor)

If any of the pymc folks want to try it out in some projects, there is some code for getting pymc traces into xarray here: https://discourse.pymc.io/t/use-xarray-for-traces/73
For some time now, I've been wrapping pymc traces in a fit object, that serialises to a netcdf file. The format for that file could be something like this:

  • /trace: Stores the actual trace of an MCMC run, e.g. to_xarray(pymc_trace).to_netcdf('file', group='/trace')
  • /trace_stats: This is where stats of the sampler (treedepth etc.) can go. Also probably info about where divergences happened, as well as rhat, effective_n, etc.
  • /data: The observed variables and their values in the model (optional)
  • /advi: Some format for the advi result (not sure about that; it could just be a trace as well, or some other representation of the approximation)
  • /advi_stats: Stats for the advi, e.g. the history of the elbo.

And we can put some meta info in the attr tags as well. If we had a serialization format for the model itself, we could also add a /pymc_model group and store it there. (@stan-folks /stan_model :-) )

This should avoid name collisions between stats and variables, but it would probably mean that we have to duplicate some dimension labels if we need them in more than one group (I'm not sure how to get xarray to read dimensions that are in different groups, but if that works we could just add /dims).
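The group layout above can be sketched with xarray Datasets. The `to_netcdf(group=...)` calls are shown as comments because they need the netCDF4 (or h5netcdf) backend; the version string in the attr is a hypothetical example of the provenance tag discussed below:

```python
import numpy as np
import xarray as xr

# /trace: the actual MCMC draws.
trace = xr.Dataset({
    'mu': (('chain', 'sample'), np.random.randn(4, 500)),
})
# /trace_stats: sampler stats, in their own namespace.
trace_stats = xr.Dataset({
    'depth': (('chain', 'sample'), np.random.randint(1, 10, (4, 500))),
})
# Meta info can ride along in attrs, e.g. which program wrote the stats.
trace.attrs['created_by'] = 'pymc3 3.4.1'  # hypothetical version tag

# Each Dataset would go to its own netCDF group (requires the netCDF4
# or h5netcdf engine, so not executed in this sketch):
# trace.to_netcdf('run.nc', group='/trace', mode='w')
# trace_stats.to_netcdf('run.nc', group='/trace_stats', mode='a')
```

Writing groups with `mode='a'` appends them to the same file, which is what makes the single-file /trace, /trace_stats, /data layout possible.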

@aseyboldt (Contributor) commented May 24, 2018

On the topic of interoperability:
I think it would be great if we could get this to a point where different tools use the same format. But I think we also need to be careful not to promise too much interoperability. I can't see a reason why /trace couldn't be the same for e.g. stan and pymc, but for /trace_stats this isn't as clear anymore. They store basically the same thing, but keeping that interoperable might hinder development. An attr that stores which program and version created those stats might be really helpful, and a visualisation lib could then just have specific code to read those if necessary.

@ahartikainen (Contributor)

Hey, @aseyboldt @ColCarroll

I have basically put the contents of the PyStan fit object into xarray. This should work with PyStan 2.16 onwards (I updated our .extract method); it can still be updated for earlier versions if needed.
I split the data between warmup and sampling.

def pars_to_xarray(fit, pars=None, infer_dtypes=True):
    # regex magic to infer ints from the model code automatically
    ...
    return data_set, data_set_warmup

def sampler_params_to_xarray(fit, params=None):
    ...
    return sampler_params_dataset, sampler_params_dataset_warmup

def inits_to_xarray(fit, pars=None, infer_dtypes=True):
    ...
    return inits_dataset

def summary_to_xarray(fit, pars=None):
    # transform summary data to dataframe --> xarray
    ...
    return summary_dataset, c_summary_dataset

All the sampled parameters are in their "original" shape (same goes for the inits function)

shape = (draw, chain, *parameter_shape)
val, vec, mat -->
val, (draw, chain)
vec, (draw, chain, vec_axis1)
mat, (draw, chain, mat_axis1, mat_axis2)

All the sampler parameters are

accept_stat__, (draw, chain) float64
stepsize__, (draw, chain) float64
treedepth__, (draw, chain) int64
n_leapfrog__, (draw, chain) int64
divergent__, (draw, chain) bool
energy__, (draw, chain) float64
lp__, (draw, chain) float64

Summary xarray Dataset are in their flatname format

val, (index)
vec[1], (index)
vec[2], (index)
mat[1,1] (index)
mat[1,2] (index)
mat[2,1] (index)
mat[2,2] (index)

Summary for chain is

val, (chain, index)
vec[1], (chain, index)
vec[2], (chain, index)
mat[1,1] (chain, index)
mat[1,2] (chain, index)
mat[2,1] (chain, index)
mat[2,2] (chain, index)

So how should we piece this together? And what about the naming?

@ColCarroll (Member, Author)

I have been playing around with this today, too, using the non-centered eight schools model. See here for @aloctavodia's model code for both pystan and pymc3. I am just calling trace = pm.sample() for pymc3, and

sm = pystan.StanModel(model_code=schools_code)
fit = sm.sampling(data=schools_dat, iter=1000, chains=4)

for pystan.

My biggest difficulty right now is automatically detecting that both theta and theta_tilde have the same first dimension (i.e., that each is a vector referring to the same 8 schools). Building off @aseyboldt's example notebooks, my API currently looks like this:

data = to_xarray(
    non_centered_eight_trace, 
    coords = {
        'school': np.arange(J)
    }, 
    dims={
        'theta_tilde': ['school'], 
        'theta': ['school'], 
    }
)

The output from that looks like this:

<xarray.Dataset>
Dimensions:                 (chain: 4, sample: 500, school: 8)
Coordinates:
  * school                  (school) int64 0 1 2 3 4 5 6 7
  * sample                  (sample) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 ...
  * chain                   (chain) int64 0 1 2 3
Data variables:
    mu                      (chain, sample) float64 2.431 3.441 0.8163 2.117 ...
    theta_tilde             (chain, sample, school) float64 -0.03801 -0.2972 ...
    tau                     (chain, sample) float64 10.98 2.865 5.888 0.3688 ...
    theta                   (chain, sample, school) float64 2.014 -0.8323 ...
    stat__max_energy_error  (chain, sample) float64 1.172 0.9255 -0.1971 ...
    stat__mean_tree_accept  (chain, sample) float64 0.8326 0.5738 0.9696 ...
    stat__step_size         (chain, sample) float64 0.6574 0.6574 0.6574 ...
    stat__tree_size         (chain, sample) float64 7.0 7.0 7.0 7.0 7.0 7.0 ...
    stat__energy            (chain, sample) float64 48.17 49.51 49.38 55.18 ...
    stat__tune              (chain, sample) bool False False False False ...
    stat__diverging         (chain, sample) bool False False False False ...
    stat__energy_error      (chain, sample) float64 -0.04698 0.4614 -0.1971 ...
    stat__depth             (chain, sample) int64 3 3 3 3 3 3 3 3 3 3 3 3 3 ...
    stat__step_size_bar     (chain, sample) float64 0.5544 0.5544 0.5544 ...

@ColCarroll (Member, Author)

Hrm, I am reading your and Adrian's posts more carefully, and I agree that it would be a better model to have multiple xarrays for the different parts of the summary.

Let me keep looking at this, but can you give an example of how your output looks on the non-centered 8 schools model?

@aseyboldt (Contributor)

That sounds good.
@ahartikainen If you would like to share that regex code, I could put both extraction methods into a notebook and play a bit with them.

@ColCarroll I don't think we can autodetect whether things share a dimension. Just because the shapes match doesn't mean the dimension is the same. That is why I like the explicit syntax in the model:

coords = {
    'school': ['name1', 'name2', 'name3']
}

with pm.Model(coords=coords):
    theta = pm.Whatever(dims='school')  # or dims=('school', 'whatever')

One additional issue is that of the index ('chain', 'sample'):

  • Which order do we want?
  • Should we call the second one sample or draw?
  • If we select one sample (or draw) using trace.isel(sample=100), we get one value per chain. I think that is a bit counterintuitive. We could use draw (or sample) as a hierarchical index, trace.stack(draw=('chain', 'sample')), so that trace.isel(draw=100) gives us a single sample. I kind of like this, but I'm not sure the additional complexity of using a hierarchical index in all traces is worth it.
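The trade-off between the two indexing behaviours can be seen on a toy trace (shapes chosen only for illustration):

```python
import numpy as np
import xarray as xr

# 2 chains x 4 samples, with values 0..7 so selections are easy to follow.
trace = xr.Dataset({
    'mu': (('chain', 'sample'), np.arange(8.0).reshape(2, 4)),
})

# Plain isel on `sample` returns one value *per chain*:
per_chain = trace.isel(sample=3)

# Stacking turns (chain, sample) into a single hierarchical `draw` index,
# so one isel gives one scalar draw:
flat = trace.stack(draw=('chain', 'sample'))
single = flat.isel(draw=5)
```

Here `per_chain['mu']` holds the two chain values `[3., 7.]`, while the stacked `flat` has a single `draw` dimension of length 8 and `single['mu']` is the scalar `5.0` (chain 1, sample 1).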

@aseyboldt (Contributor)

Just a quick notebook about how I think we could use the xarrays: https://gist.github.com/aseyboldt/99b8b3ba71d0d58a92264c3bf99bbbf9

@aseyboldt (Contributor)

It looks like altair might add support for xarray: vega/altair#891

@aseyboldt (Contributor)

I created a (sketch of a) design document for a netcdf file format that stores traces. Ideally, I think both stan and pymc (and of course other tools) could write their traces in this format, and arviz could use an xarray representation of it for visualisation. The file format should contain all the info needed to reproduce the run, and also to debug sampling trouble.

@ahartikainen Is that similar to what you have in mind? I don't have any problems with major changes to this, it is only meant as a starting point. Feel free to edit this as you like.
https://yourpart.eu/p/SXfBlllfnl

@ahartikainen (Contributor)

@ColCarroll here is an example of the metacode above
@aseyboldt the regex is in the first function

https://gist.github.com/ahartikainen/b16704eec3a912ccd3bb39d62ca04279

Samples / draws: not sure which is the correct term. I think that one draw equals one value for each parameter in the model.

@aseyboldt that looks like a good starting point.

@ColCarroll (Member, Author)

@aseyboldt I worked off your example, and also added a function that gives an informative error message since it took me a little while to understand the syntax.

https://gist.github.com/ColCarroll/c607842947b08bc44d4e1588e6bef98d

@ahartikainen Is there a way to tell pars_to_xarray that theta_axis1 and theta_tilde_axis1 both refer to the same 8 schools? That is why I am using the slightly more cumbersome notation

data = to_xarray(non_centered_eight_trace, 
                 coords={'school': np.arange(8)}, 
                 dims={'theta_tilde': ['school'], 'theta': ['school']} 
)

Which means the resulting dimensions are just school, sample, and chain.

@ahartikainen (Contributor)

Doing that automatically is probably not an easy task. Your way of defining them looks good.

What array order does xarray use? Which parts are "contiguous"? Should that be reflected in the order of the axes?

@aseyboldt (Contributor)

@ahartikainen I think that depends a lot on the backend. If you read data from netcdf4 (hdf5 internally), it autoselects some chunking (which we can override if we want). With numpy I think it uses numpy conventions, so usually (modulo transposition) C-contiguous storage. From that point of view, an order like (school, chain, sample) should be the fastest if you regularly look at all draws for one variable. I didn't test this, though.

@ColCarroll Errors are great :-) (or rather error messages, I don't like errors)

@avehtari (Contributor) commented Jun 1, 2018

Samples / Draw, not sure what is the correct term. I think that one draw equals one value for each parameter in the model.

Wikipedia https://en.wikipedia.org/wiki/Sample_(statistics) says: "In statistics and quantitative research methodology, a data sample is a set of data collected and/or selected from a statistical population by a defined procedure."

So in this case it would be natural that posterior sample is a set of posterior draws. This is what Stan team recommends, although it has not been strictly enforced and variation exists.

@twiecki (Contributor) commented Jun 4, 2018

We try to match our terminology to Stan so draws is fine with me.

@ColCarroll (Member, Author)

Wanted to post a quick update on this issue, since there was a lot of good discussion earlier:

-- There are utilities now for converting posteriors from PyStan and PyMC3 to xarray Datasets
-- Two out of twelve plots use xarray, and will work mostly transparently with posterior draws from either library (#111 has an example comparing posterior draws from the eight schools model using both PyStan and PyMC3)

My view of next steps, which can go mostly in parallel:
-- Finish porting plots to use xarray
-- Port statistical tests to use xarray
-- Add a converter for OrderedDict/dict, which would work for Edward, Edward2, and the nascent PyMC4

After that (which isn't that much!), I think it would be reasonable to cut a release on pypi/conda-forge, and start working on using, for example, sampler statistics or observed data in some of these visualizations/analyses.

I appreciate any input/suggestions/help, as always!

@ahartikainen (Contributor)

Should we also add a kwarg for diagnostics (sampler parameters)?

It could be "all" or nothing.

Also, common names are easier if they are fixed (?)

@SemanticBeeng commented Jul 25, 2018

How would this relate to developments around Apache Arrow, which brings platform/language-independent support for big data (across C++, Python, and the JVM)?

See also xtensor-stack/xtensor#394 (comment)

It would be great to have a "common language for pymc3, pystan, and pymc4." that brings them closer to JVM.

@twiecki (Contributor) commented Jul 25, 2018

@SemanticBeeng That's definitely where this is headed. I wouldn't compare it to the JVM, but rather to a common format for storing model results that's standardized across different PPLs.

@SemanticBeeng commented Jul 30, 2018

Indeed, not to the JVM specifically but to a platform-independent format; so is the JVM included?

If relevant to you, I am curious to know how the intent in this thread:

  1. compares to Apache Arrow
  2. relates to JVM interop
  3. plans to address the need to manage data schemas across languages (asking because types are very important for JVM people)
    See for context:

I understand that if the JVM is not in the picture, then the above are not applicable to this context.

@canyon289 (Member)

Is this now "use InferenceData throughout"?

@ColCarroll (Member, Author)

The library is now using xarray and netcdf throughout.
