Use xarray throughout #97
The PyStan fit object is a StanFit4Model cython class. The ordered dict comes from fit.extract, but there are other ways to interact with it. (edit: typos)
@ahartikainen The idea would be for PyStan to write a converter from the fit object to xarray.
My proposal is that we make a faster implementation of rhat and effective sample size in arviz (the numba version), and a general converter that changes the pymc3 trace and similar objects into a common format. Thoughts?
Good idea. I think it makes sense to allow inclusion of these stats in the xarray and, if not present, compute them using a fast (numba) implementation in arviz.
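To make the fast-diagnostics idea concrete, here is a minimal plain-numpy sketch of split-rhat. The function name `split_rhat` is illustrative only (not arviz's actual implementation); a production version could be decorated with numba's `@njit` for the speedup discussed above.

```python
import numpy as np

def split_rhat(samples):
    """Split-Rhat (Gelman-Rubin) diagnostic for samples of shape
    (n_chains, n_draws). Illustrative sketch, not the arviz API."""
    chains, draws = samples.shape
    half = draws // 2
    # Split each chain in half so within-chain drift also inflates Rhat.
    split = samples[:, : half * 2].reshape(chains * 2, half)
    chain_means = split.mean(axis=1)
    w = split.var(axis=1, ddof=1).mean()      # mean within-chain variance
    b = half * chain_means.var(ddof=1)        # between-chain variance
    var_hat = (half - 1) / half * w + b / half  # pooled variance estimate
    return float(np.sqrt(var_hat / w))
```

Well-mixed chains give values near 1.0, while chains stuck in different regions give values well above 1.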
@twiecki Yes, that is a good idea. I'm currently writing a function to transform the fit object into a suitable format for the current implementation (of the plotting code). But if ArviZ will implement code to read a specific type of xarray object, we can wait for that before we release the arviz-wrapped code. @junpenglao It would be ideal to just give a fit object to arviz and let arviz do the magic inside. Do you think numba is going to be a dependency or optional?
I'm not sure computing rhat and other stats in arviz is a good idea. Some of that depends on the sampler (e.g. treedepth), and if we want to print warnings or a report of some kind after sampling then we need to compute it there anyway. I think it would be good if those stats could be in the xarray as well, but there is the problem of name collisions. If we only have rhat and n_effective_samples or so in the array, that would be fine, but if we also want to put all sampler-specific stats in it, then that list would suddenly be much longer and would also change over time.
Yes, summary stats and samples basically live in different spaces. Either two xarrays or some tricks with the indexing could work.
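The "two xarrays" option can be sketched directly: keep sampled variables and sampler statistics in separate `xarray.Dataset` objects that share coordinates, so stat names can never collide with parameter names. Variable names below are illustrative.

```python
import numpy as np
import xarray as xr

chains, draws = 4, 500
coords = {"chain": np.arange(chains), "draw": np.arange(draws)}

# Sampled variables live in one Dataset...
posterior = xr.Dataset(
    {"mu": (("chain", "draw"), np.zeros((chains, draws)))},
    coords=coords,
)

# ...and sampler statistics in another, so a stat could even be
# named "mu" without clashing with the model parameter.
sample_stats = xr.Dataset(
    {"depth": (("chain", "draw"), np.ones((chains, draws), dtype=int))},
    coords=coords,
)
```

Because both Datasets share the `chain`/`draw` coordinates, they can still be aligned and indexed together.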
I don't see a problem making numba a dependency; it has matured a lot and is well packaged. Eventually there is a high chance scipy will use it too.
Personally, I prefer to make it a dependency; the alternative implementation is too slow. Also, I came across this discussion on the mc-stan discourse, Proposal for consolidated output, and the related Stan wiki by @martinmodrak and @sakrejda. I think this is a great discussion to have: specifically, whether there could be a universal representation of different inferences, with their related diagnostics. My idea would be:
Maybe we should start a google doc and also invite others working on PPLs to edit? For example the folks from tensorflow/probability, Pyro, etc.
Good idea. CC @fritzo @dustinvtran. ArviZ is a package that separates out PyMC3's plotting and some analysis functionality to create many commonly used plots, like a traceplot. With PyStan we're currently discussing potential standardized storage objects (xarray). Is there any interest from Edward/Edward2/Pyro in collaborating on this?
cc for active (Py)Stan folks
I have no objections to using xarrays. Sounds like a reasonable idea.
No objections; probably handy to agree on a naming convention for parameters before implementing it in the various code bases.
@twiecki Yes, the Pyro team is interested in collaborating on standardized storage formats that can facilitate comparison and encourage inference algorithm research. cc @neerajprad @jpchen @rohitsingh who are working on HMC and Pyro-Stan compatibility.
Re: @junpenglao, we've explicitly punted on this question for the moment while we re-organize the intermediate layer to make it possible. I agree that it's important and has been a topic of ongoing discussion, so if there is a broader discussion please let us know. Since Stan is multi-interface it might be more complicated at the file-format level (we've been talking about something streaming-friendly in ProtoBuf), but at the level of deciding how outputs should be grouped and organized it would be fantastic to have compatibility with other projects.
As an xarray developer and probabilistic programming enthusiast, I'd really love to see this happen. Please feel free to ping me if you come across any issues.
@matthewdhoffman, @davmre, @jvdillon, @csuter, @derifatives, @axch, @srvasude Edward2 and TFP's abstraction level doesn't really require named data structures, except for dicts and namedtuples, which are more for collecting heterogeneous data. For PyMC*, Stan, and others, xarray instead of custom classes or pd.DataFrame sounds like a great idea.
Wow, awesome proposal guys!
If any of the pymc folks want to try it out in some projects, there is some code for getting pymc traces into xarray here: https://discourse.pymc.io/t/use-xarray-for-traces/73
And we can put some meta info in the attr tags as well. If we had a serialization format for the model itself, we could also include that. This should avoid name collisions between stats and variables, but it would probably mean that we have to duplicate some dimension labels if we need them in more than one group (not sure how to get xarray to read dimensions if they are in different groups, but if that works we could just add them once).
On the topic of interoperability:
Hey @aseyboldt @ColCarroll, I have basically put stuff into xarray from the PyStan fit object. This should work with PyStan 2.16 onwards (I updated our

- All the sampled parameters are in their "original" shape (same goes for the inits function)
- All the sampler parameters are
- The summary xarray Dataset is in flatname format
- The summary for chain is

So how should we parse this together? How about the naming?
I have been playing around with this today, too, using the non-centered eight schools model. See here for @aloctavodia's model code for both. For PyStan:

```python
sm = pystan.StanModel(model_code=schools_code)
fit = sm.sampling(data=schools_dat, iter=1000, chains=4)
```

My biggest difficulty right now is automatically detecting that both `theta` and `theta_tilde` share the school dimension; for now they have to be declared explicitly:

```python
data = to_xarray(
    non_centered_eight_trace,
    coords={
        'school': np.arange(J)
    },
    dims={
        'theta_tilde': ['school'],
        'theta': ['school'],
    }
)
```

The output from that looks like this:

```
<xarray.Dataset>
Dimensions:                 (chain: 4, sample: 500, school: 8)
Coordinates:
  * school                  (school) int64 0 1 2 3 4 5 6 7
  * sample                  (sample) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 ...
  * chain                   (chain) int64 0 1 2 3
Data variables:
    mu                      (chain, sample) float64 2.431 3.441 0.8163 2.117 ...
    theta_tilde             (chain, sample, school) float64 -0.03801 -0.2972 ...
    tau                     (chain, sample) float64 10.98 2.865 5.888 0.3688 ...
    theta                   (chain, sample, school) float64 2.014 -0.8323 ...
    stat__max_energy_error  (chain, sample) float64 1.172 0.9255 -0.1971 ...
    stat__mean_tree_accept  (chain, sample) float64 0.8326 0.5738 0.9696 ...
    stat__step_size         (chain, sample) float64 0.6574 0.6574 0.6574 ...
    stat__tree_size         (chain, sample) float64 7.0 7.0 7.0 7.0 7.0 7.0 ...
    stat__energy            (chain, sample) float64 48.17 49.51 49.38 55.18 ...
    stat__tune              (chain, sample) bool False False False False ...
    stat__diverging         (chain, sample) bool False False False False ...
    stat__energy_error      (chain, sample) float64 -0.04698 0.4614 -0.1971 ...
    stat__depth             (chain, sample) int64 3 3 3 3 3 3 3 3 3 3 3 3 3 ...
    stat__step_size_bar     (chain, sample) float64 0.5544 0.5544 0.5544 ...
```
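For illustration, once a trace lives in a Dataset like that, the named dimensions pay off immediately in post-processing. This sketch builds a small Dataset of the same shape by hand (random values, not the real trace) and indexes it by name:

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(1)
data = xr.Dataset(
    {
        "mu": (("chain", "sample"), rng.normal(size=(4, 500))),
        "theta": (("chain", "sample", "school"),
                  rng.normal(size=(4, 500, 8))),
    },
    coords={"chain": np.arange(4), "sample": np.arange(500),
            "school": np.arange(8)},
)

# Reductions and selections read like the model description:
school_means = data["theta"].mean(dim=["chain", "sample"])  # per school
first_chain_mu = data["mu"].sel(chain=0)                    # by label
```

No axis-number bookkeeping is needed; `school_means` has a single `school` dimension regardless of how the axes were ordered in storage.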
Hrm, I am reading your and Adrian's posts more carefully, and agree that it would be a better model to have multiple xarrays for the different parts of the summary. Let me keep looking at this, but can you give an example of how your output looks on the non-centered 8 schools model?
That sounds good. @ColCarroll I don't think we can autodetect if things are the same dimension. Just because the shape is the same doesn't mean it has the same dimension. That is why I like the explicit syntax in the model. One additional issue is that of the index.
Just a quick notebook about how I think we could use the xarrays: https://gist.github.com/aseyboldt/99b8b3ba71d0d58a92264c3bf99bbbf9
It looks like altair might add support for xarray: vega/altair#891
I created a (sketch of a) design document for a netcdf file format that stores traces. Ideally, both stan and pymc (and of course other tools) could write their traces in this format, and arviz could use an xarray representation of it for visualisation. The file format should contain all the info needed to reproduce the run, and also to debug sampling trouble. @ahartikainen Is that similar to what you have in mind? I don't have any problems with major changes to this; it is only meant as a starting point. Feel free to edit it as you like.
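As a rough illustration of how netcdf groups could separate the concerns discussed in this thread (this is a guessed layout, not the content of the linked design document):

```
trace.nc
├── posterior/       # sampled variables: mu, theta, ...
├── sample_stats/    # diverging, tree depth, step size, ...
├── observed_data/   # data the model conditioned on
└── global attrs     # sampler name/version, seed, model code
```

Grouping like this keeps sampler statistics out of the parameter namespace while still storing everything in one file.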
@ColCarroll here is an example of the metacode above: https://gist.github.com/ahartikainen/b16704eec3a912ccd3bb39d62ca04279 Samples / draws: I'm not sure which is the correct term. I think that one draw equals one value for each parameter in the model. @aseyboldt that looks like a good starting point.
@aseyboldt I worked off your example, and also added a function that gives an informative error message, since it took me a little while to understand the syntax. https://gist.github.com/ColCarroll/c607842947b08bc44d4e1588e6bef98d @ahartikainen Is there a way to detect the dimensions automatically? Right now I use:

```python
data = to_xarray(non_centered_eight_trace,
                 coords={'school': np.arange(8)},
                 dims={'theta_tilde': ['school'], 'theta': ['school']})
```

Which means the resulting dimensions are just `(chain, sample, school)`.
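In the spirit of the informative-error helper mentioned above, here is a hypothetical `validate_dims` sketch (names and signature are made up, not the gist's actual code) that checks the declared dims and coords against variable shapes before the Dataset is built:

```python
def validate_dims(var_shapes, dims, coords):
    """Validate user-declared dims/coords against variable shapes.

    var_shapes maps each variable name to the shape of its trailing
    (non chain/draw) axes, e.g. {"theta": (8,)}. All names here are
    illustrative, not the actual arviz API.
    """
    for var, declared in dims.items():
        if var not in var_shapes:
            raise KeyError(f"dims refers to unknown variable {var!r}")
        shape = var_shapes[var]
        if len(declared) != len(shape):
            raise ValueError(
                f"{var!r} has {len(shape)} extra axes but "
                f"{len(declared)} dims were declared: {declared}"
            )
        for name, size in zip(declared, shape):
            if name in coords and len(coords[name]) != size:
                raise ValueError(
                    f"coord {name!r} has length {len(coords[name])}, "
                    f"but the matching axis of {var!r} has length {size}"
                )
```

Calling it before constructing the Dataset turns xarray's generic shape errors into messages that point at the offending variable or coordinate.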
Doing that automatically: probably not an easy task. Your way to define them looks good. What is the array order that xarray uses? Which parts are "contiguous"? Should that be reflected in the order of axes?
@ahartikainen I think that depends a lot on the backend. If you read data from netcdf4 (hdf5 internally), it autoselects some chunking (which we can override if we want). On numpy I think it uses numpy conventions, so usually (this changes under transposition) C-contiguous storage. From that point of view I think an order like (chain, sample, ...) makes sense. @ColCarroll Errors are great :-) (or rather error messages, I don't like errors)
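A quick numpy check of the contiguity point: with numpy's default C order and a `(chain, sample, *shape)` layout, the last axis varies fastest, so all values of a single draw sit next to each other in memory. The shapes here are just an example (4 chains, 500 draws, 8 schools).

```python
import numpy as np

samples = np.zeros((4, 500, 8))        # (chain, sample, school)
print(samples.flags["C_CONTIGUOUS"])   # True: numpy defaults to C order
one_draw = samples[0, 0]               # contiguous length-8 view
print(samples.strides[-1] == samples.itemsize)  # last axis is unit-stride
```

With a `(school, chain, sample)` order instead, each draw's vector would be scattered across memory, which is why the axis order matters for per-draw access patterns.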
Wikipedia (https://en.wikipedia.org/wiki/Sample_(statistics)) says: "In statistics and quantitative research methodology, a data sample is a set of data collected and/or selected from a statistical population by a defined procedure." So in this case it would be natural that a posterior sample is a set of posterior draws. This is what the Stan team recommends, although it has not been strictly enforced and variation exists.
We try to match our terminology to Stan, so draws is fine with me.
Wanted to post a quick update on this issue, since there was a lot of good discussion earlier:

- There are utilities now for converting posteriors from PyStan and PyMC3 to xarray Datasets

My view of next steps, which can go mostly in parallel: after that (which isn't that much!), I think it would be reasonable to cut a release on pypi/conda-forge, and start working on using, for example, sampler statistics or observed data in some of these visualizations/analyses. I appreciate any input/suggestions/help, as always!
Should we also add these? It could be "all" or nothing. Also, common names are easier if they are fixed (?)
How would this relate to developments around xtensor? See also xtensor-stack/xtensor#394 (comment). It would be great to have a "common language for pymc3, pystan, and pymc4" that brings them closer to the JVM.
@SemanticBeeng That's definitely where this is headed. I wouldn't compare it to the JVM, but rather to a common format to store model results that's standardized across different PPLs.
Indeed, not the JVM specifically but a platform-independent format; so is the JVM included? If relevant to you, I'm curious to know how the intent in this thread relates. I understand that if the JVM is not in the picture then the above are not applicable to this context.
Does this now use InferenceData throughout?
The library is now using xarray and netcdf throughout.
There have been proposals to use xarray as a common language for pymc3, pystan, and pymc4. This library might be a good place to start that by implementing utility functions for translating pymc3's Multitrace and pystan's OrderedDict into xarray objects, and then having all plotting functions work with xarrays.