Skip to content

Commit

Permalink
continue work and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
OriolAbril committed Apr 5, 2024
1 parent 9ba36ab commit b12b4ce
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 5 deletions.
22 changes: 17 additions & 5 deletions arviz/stats/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,9 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None, dask_kwargs=
log_likelihood = _get_log_likelihood(inference_data, var_name=var_name)
pointwise = rcParams["stats.ic_pointwise"] if pointwise is None else pointwise

if dask_kwargs is None:
dask_kwargs = {}

log_likelihood = log_likelihood.stack(__sample__=("chain", "draw"))
shape = log_likelihood.shape
n_samples = shape[-1]
Expand Down Expand Up @@ -783,7 +786,9 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None, dask_kwargs=
np.hstack([ess_p[v].values.flatten() for v in ess_p.data_vars]).mean() / n_samples
)

log_weights, pareto_shape = psislw(-log_likelihood, reff, dask_kwargs=dask_kwargs)
psis_dask = dask_kwargs.copy()
psis_dask["output_dtypes"] = (float, float)
log_weights, pareto_shape = psislw(-log_likelihood, reff, dask_kwargs=psis_dask)
log_weights += log_likelihood

warn_mg = False
Expand All @@ -799,8 +804,10 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None, dask_kwargs=

ufunc_kwargs = {"n_dims": 1, "ravel": False}
kwargs = {"input_core_dims": [["__sample__"]]}
logsumexp_dask = dask_kwargs.copy()
logsumexp_dask["output_dtypes"] = [float]
loo_lppd_i = scale_value * _wrap_xarray_ufunc(
_logsumexp, log_weights, ufunc_kwargs=ufunc_kwargs, dask_kwargs=dask_kwargs, **kwargs
_logsumexp, log_weights, ufunc_kwargs=ufunc_kwargs, dask_kwargs=logsumexp_dask, **kwargs
)
loo_lppd = loo_lppd_i.sum().compute().item()
loo_lppd_se = (n_data_points * loo_lppd_i.var().compute().item()) ** 0.5
Expand All @@ -811,7 +818,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None, dask_kwargs=
log_likelihood,
func_kwargs={"b_inv": n_samples},
ufunc_kwargs=ufunc_kwargs,
dask_kwargs=dask_kwargs,
dask_kwargs=logsumexp_dask,
**kwargs,
)
.sum()
Expand Down Expand Up @@ -907,6 +914,8 @@ def psislw(log_weights, reff=1.0, dask_kwargs=None):
...: az.psislw(-log_likelihood, reff=0.8)
"""
if dask_kwargs is None:
dask_kwargs = {}
if hasattr(log_weights, "__sample__"):
n_samples = len(log_weights.__sample__)
shape = [
Expand All @@ -920,10 +929,13 @@ def psislw(log_weights, reff=1.0, dask_kwargs=None):
cutoffmin = np.log(np.finfo(float).tiny) # pylint: disable=no-member, assignment-from-no-return

# create output array with proper dimensions
out = np.empty_like(log_weights), np.empty(shape)
if dask_kwargs.get("dask", "forbidden") in {"allowed", "parallelized"}:
out = xr.zeros_like(log_weights).data, np.empty(shape)
else:
out = np.empty_like(log_weights), np.empty(shape)

# define kwargs
func_kwargs = {"cutoff_ind": cutoff_ind, "cutoffmin": cutoffmin, "out": out}
func_kwargs = {"cutoff_ind": cutoff_ind, "cutoffmin": cutoffmin}
ufunc_kwargs = {"n_dims": 1, "n_output": 2, "ravel": False, "check_shape": False}
kwargs = {"input_core_dims": [["__sample__"]], "output_core_dims": [["__sample__"], []]}
log_weights, pareto_shape = _wrap_xarray_ufunc(
Expand Down
37 changes: 37 additions & 0 deletions arviz/tests/base_tests/test_stats_dask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# pylint: disable=redefined-outer-name, no-member
import pytest

import numpy as np

from ...data import load_arviz_data
from ...stats import loo
from ..helpers import multidim_models, importorskip # pylint: disable=unused-import

# Gate for the dask-specific tests in this module; presumably the tests are
# skipped when dask cannot be imported -- confirm against ``helpers.importorskip``.
dask = importorskip("dask", reason="Dask specific tests")

@pytest.fixture()
def centered_eight():
    """Provide the ``centered_eight`` example dataset loaded with ``load_arviz_data``."""
    return load_arviz_data("centered_eight")

@pytest.mark.parametrize("multidim", (True, False))
def test_loo(centered_eight, multidim_models, multidim):
    """Check that ``loo`` runs on a chunked log likelihood with dask parallelization."""
    # Pick the dataset and the chunking of its log likelihood for this case.
    if not multidim:
        idata, chunk_spec = centered_eight, {"school": 4}
    else:
        idata, chunk_spec = multidim_models.model_1, {"dim2": 3}
    idata.log_likelihood = idata.log_likelihood.chunk(chunk_spec)

    # A fresh kwargs dict is passed to every call, as separate literals.
    assert loo(idata, dask_kwargs={"dask": "parallelized"}) is not None

    loo_pointwise = loo(idata, pointwise=True, dask_kwargs={"dask": "parallelized"})
    assert loo_pointwise is not None
    for expected_key in ("loo_i", "pareto_k", "scale"):
        assert expected_key in loo_pointwise

def test_compare_loo(centered_eight):
    """Check that the dask-parallelized ``loo`` agrees with the in-memory result.

    ``loo`` is first evaluated on the dataset as loaded, then again after the
    log likelihood group has been chunked along ``school`` and with
    ``dask_kwargs={"dask": "parallelized"}``. The two ELPD estimates must be
    numerically close.
    """
    loo_ram = loo(centered_eight)
    centered_eight.log_likelihood = centered_eight.log_likelihood.chunk({"school": 2})
    loo_dask = loo(centered_eight, dask_kwargs={"dask": "parallelized"})
    # The ELPD estimate returned by ``loo`` is stored under "elpd_loo"; the
    # previous lookups ("elpd" / the misspelled "eldp") would raise KeyError.
    assert np.isclose(loo_ram["elpd_loo"], loo_dask["elpd_loo"])

0 comments on commit b12b4ce

Please sign in to comment.