fix out and kwargs in psislw
OriolAbril committed Apr 5, 2024
1 parent b12b4ce commit 636a890
Showing 2 changed files with 23 additions and 9 deletions.
11 changes: 7 additions & 4 deletions arviz/stats/stats.py
@@ -806,17 +806,19 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None, dask_kwargs=
     kwargs = {"input_core_dims": [["__sample__"]]}
     logsumexp_dask = dask_kwargs.copy()
     logsumexp_dask["output_dtypes"] = [float]
+    logsumexp_out = xr.zeros_like(log_weights.isel(__sample__=0, drop=True)).data
     loo_lppd_i = scale_value * _wrap_xarray_ufunc(
-        _logsumexp, log_weights, ufunc_kwargs=ufunc_kwargs, dask_kwargs=logsumexp_dask, **kwargs
+        _logsumexp, log_weights, ufunc_kwargs=ufunc_kwargs, func_kwargs={"out": logsumexp_out}, dask_kwargs=logsumexp_dask, **kwargs
     )
     loo_lppd = loo_lppd_i.sum().compute().item()
     loo_lppd_se = (n_data_points * loo_lppd_i.var().compute().item()) ** 0.5
 
+    lppd_out = xr.zeros_like(log_weights.isel(__sample__=0, drop=True)).data
     lppd = (
         _wrap_xarray_ufunc(
             _logsumexp,
             log_likelihood,
-            func_kwargs={"b_inv": n_samples},
+            func_kwargs={"b_inv": n_samples, "out": lppd_out},
             ufunc_kwargs=ufunc_kwargs,
             dask_kwargs=logsumexp_dask,
             **kwargs,
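
Why the preallocated out arrays have this shape (a minimal sketch, not ArviZ internals; the toy dims and data below are made up): _logsumexp reduces over the __sample__ dimension, so the array handed to it through func_kwargs={"out": ...} must have the input shape with that dimension dropped, which is exactly what xr.zeros_like(log_weights.isel(__sample__=0, drop=True)).data produces.

    import numpy as np
    import xarray as xr

    # toy log-weights with the dims ArviZ uses after stacking chain/draw
    log_weights = xr.DataArray(
        np.random.default_rng(0).normal(size=(8, 2000)),
        dims=("school", "__sample__"),
    )

    # isel(..., drop=True) keeps every dimension except the reduced one,
    # so .data has the shape the reduction's result is written into
    out = xr.zeros_like(log_weights.isel(__sample__=0, drop=True)).data
    print(out.shape)  # (8,)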
@@ -930,12 +932,13 @@ def psislw(log_weights, reff=1.0, dask_kwargs=None):
     # create output array with proper dimensions
     if dask_kwargs.get("dask", "forbidden") in {"allowed", "parallelized"}:
-        out = xr.zeros_like(log_weights).data, np.empty(shape)
+        out = xr.zeros_like(log_weights).data, xr.zeros_like(log_weights.isel(__sample__=0, drop=True)).data
         dask_kwargs["output_dtypes"] = (float, float)
     else:
         out = np.empty_like(log_weights), np.empty(shape)
 
     # define kwargs
-    func_kwargs = {"cutoff_ind": cutoff_ind, "cutoffmin": cutoffmin}
+    func_kwargs = {"cutoff_ind": cutoff_ind, "cutoffmin": cutoffmin, "out": out}
     ufunc_kwargs = {"n_dims": 1, "n_output": 2, "ravel": False, "check_shape": False}
     kwargs = {"input_core_dims": [["__sample__"]], "output_core_dims": [["__sample__"], []]}
     log_weights, pareto_shape = _wrap_xarray_ufunc(
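psislw produces two outputs with different shapes — the smoothed log-weights keep __sample__, while the Pareto k-hat values drop it — so out is now a tuple whose second element also drops the reduced dimension. A hedged sketch of that shape contract using public xarray: apply_ufunc stands in for the private _wrap_xarray_ufunc, and fake_psislw_1d is a made-up placeholder, not the real smoothing.

    import numpy as np
    import xarray as xr

    def fake_psislw_1d(row):
        # placeholder: returns (per-sample values, one scalar per row)
        return row - row.max(), np.float64(row.std())

    log_weights = xr.DataArray(
        np.random.default_rng(1).normal(size=(8, 2000)),
        dims=("school", "__sample__"),
    )

    smoothed, khat = xr.apply_ufunc(
        fake_psislw_1d,
        log_weights,
        input_core_dims=[["__sample__"]],
        output_core_dims=[["__sample__"], []],  # first keeps the dim, second drops it
        vectorize=True,
    )
    print(smoothed.shape, khat.shape)  # (8, 2000) (8,)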
21 changes: 16 additions & 5 deletions arviz/tests/base_tests/test_stats_dask.py
@@ -4,25 +4,36 @@
 import numpy as np
 
 from ...data import load_arviz_data
-from ...stats import loo
+from ...stats import loo, psislw
 from ..helpers import multidim_models, importorskip  # pylint: disable=unused-import
 
 dask = importorskip("dask", reason="Dask specific tests")
 
 
 @pytest.fixture()
-def centered_eight():
+def chunked_centered_eight():
     centered_eight = load_arviz_data("centered_eight")
     centered_eight.log_likelihood = centered_eight.log_likelihood.chunk({"school": 4})
     return centered_eight
 
 
 @pytest.mark.parametrize("multidim", (True, False))
-def test_loo(centered_eight, multidim_models, multidim):
+def test_psislw(chunked_centered_eight, multidim_models, multidim):
+    if multidim:
+        log_like = multidim_models.model_1.log_likelihood["y"].stack(__sample__=["chain", "draw"]).chunk({"dim2": 3})
+    else:
+        log_like = chunked_centered_eight.log_likelihood["obs"].stack(__sample__=["chain", "draw"])
+    log_weights, khat = psislw(-log_like, dask_kwargs={"dask": "parallelized"})
+    assert log_weights.shape[:-1] == khat.shape
+
+
+@pytest.mark.parametrize("multidim", (True, False))
+def test_loo(chunked_centered_eight, multidim_models, multidim):
     """Test approximate leave one out criterion calculation"""
     if multidim:
         idata = multidim_models.model_1
         idata.log_likelihood = idata.log_likelihood.chunk({"dim2": 3})
     else:
-        idata = centered_eight
-        idata.log_likelihood = idata.log_likelihood.chunk({"school": 4})
+        idata = chunked_centered_eight
+        idata.log_likelihood = idata.log_likelihood
     assert loo(idata, dask_kwargs={"dask": "parallelized"}) is not None
     loo_pointwise = loo(idata, pointwise=True, dask_kwargs={"dask": "parallelized"})
     assert loo_pointwise is not None
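
For reference, a hedged sketch of the public calls these tests exercise (assumes dask is installed; mirrors the test bodies above, using only the loo, psislw, and load_arviz_data entry points shown in this commit):

    import arviz as az

    idata = az.load_arviz_data("centered_eight")
    idata.log_likelihood = idata.log_likelihood.chunk({"school": 4})

    # psislw on the chunked, stacked log-likelihood
    log_like = idata.log_likelihood["obs"].stack(__sample__=["chain", "draw"])
    log_weights, khat = az.psislw(-log_like, dask_kwargs={"dask": "parallelized"})

    # loo computed through dask, with pointwise values
    loo_pointwise = az.loo(idata, pointwise=True, dask_kwargs={"dask": "parallelized"})
    print(loo_pointwise)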