diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index d36a0a3208..ff40234f1e 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -806,17 +806,19 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None, dask_kwargs= kwargs = {"input_core_dims": [["__sample__"]]} logsumexp_dask = dask_kwargs.copy() logsumexp_dask["output_dtypes"] = [float] + logsumexp_out = xr.zeros_like(log_weights.isel(__sample__=0, drop=True)).data loo_lppd_i = scale_value * _wrap_xarray_ufunc( - _logsumexp, log_weights, ufunc_kwargs=ufunc_kwargs, dask_kwargs=logsumexp_dask, **kwargs + _logsumexp, log_weights, ufunc_kwargs=ufunc_kwargs, func_kwargs={"out": logsumexp_out}, dask_kwargs=logsumexp_dask, **kwargs ) loo_lppd = loo_lppd_i.sum().compute().item() loo_lppd_se = (n_data_points * loo_lppd_i.var().compute().item()) ** 0.5 + lppd_out = xr.zeros_like(log_weights.isel(__sample__=0, drop=True)).data lppd = ( _wrap_xarray_ufunc( _logsumexp, log_likelihood, - func_kwargs={"b_inv": n_samples}, + func_kwargs={"b_inv": n_samples, "out": lppd_out}, ufunc_kwargs=ufunc_kwargs, dask_kwargs=logsumexp_dask, **kwargs, @@ -930,12 +932,13 @@ def psislw(log_weights, reff=1.0, dask_kwargs=None): # create output array with proper dimensions if dask_kwargs.get("dask", "forbidden") in {"allowed", "parallelized"}: - out = xr.zeros_like(log_weights).data, np.empty(shape) + out = xr.zeros_like(log_weights).data, xr.zeros_like(log_weights.isel(__sample__=0, drop=True)).data + dask_kwargs["output_dtypes"] = (float, float) else: out = np.empty_like(log_weights), np.empty(shape) # define kwargs - func_kwargs = {"cutoff_ind": cutoff_ind, "cutoffmin": cutoffmin} + func_kwargs = {"cutoff_ind": cutoff_ind, "cutoffmin": cutoffmin, "out": out} ufunc_kwargs = {"n_dims": 1, "n_output": 2, "ravel": False, "check_shape": False} kwargs = {"input_core_dims": [["__sample__"]], "output_core_dims": [["__sample__"], []]} log_weights, pareto_shape = _wrap_xarray_ufunc( diff --git 
a/arviz/tests/base_tests/test_stats_dask.py b/arviz/tests/base_tests/test_stats_dask.py index 8dbe969d4b..cc2ccc7f82 100644 --- a/arviz/tests/base_tests/test_stats_dask.py +++ b/arviz/tests/base_tests/test_stats_dask.py @@ -4,25 +4,35 @@ import numpy as np from ...data import load_arviz_data -from ...stats import loo +from ...stats import loo, psislw from ..helpers import multidim_models, importorskip # pylint: disable=unused-import dask = importorskip("dask", reason="Dask specific tests") @pytest.fixture() -def centered_eight(): +def chunked_centered_eight(): centered_eight = load_arviz_data("centered_eight") + centered_eight.log_likelihood = centered_eight.log_likelihood.chunk({"school": 4}) return centered_eight @pytest.mark.parametrize("multidim", (True, False)) -def test_loo(centered_eight, multidim_models, multidim): +def test_psislw(chunked_centered_eight, multidim_models, multidim): + if multidim: + log_like = multidim_models.model_1.log_likelihood["y"].stack(__sample__=["chain", "draw"]).chunk({"dim2": 3}) + else: + log_like = chunked_centered_eight.log_likelihood["obs"].stack(__sample__=["chain", "draw"]) + log_weights, khat = psislw(-log_like, dask_kwargs={"dask": "parallelized"}) + assert log_weights.shape[:-1] == khat.shape + + +@pytest.mark.parametrize("multidim", (True, False)) +def test_loo(chunked_centered_eight, multidim_models, multidim): """Test approximate leave one out criterion calculation""" if multidim: idata = multidim_models.model_1 idata.log_likelihood = idata.log_likelihood.chunk({"dim2": 3}) else: - idata = centered_eight - idata.log_likelihood = idata.log_likelihood.chunk({"school": 4}) + idata = chunked_centered_eight assert loo(idata, dask_kwargs={"dask": "parallelized"}) is not None loo_pointwise = loo(idata, pointwise=True, dask_kwargs={"dask": "parallelized"}) assert loo_pointwise is not None