Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use revised Pareto k threshold #2349

Merged
merged 5 commits into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions arviz/plots/backends/bokeh/khatplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def plot_khat(
figsize,
xdata,
khats,
good_k,
kwargs,
threshold,
coord_labels,
Expand Down Expand Up @@ -53,7 +54,11 @@ def plot_khat(

if hlines_kwargs is None:
hlines_kwargs = {}
hlines_kwargs.setdefault("hlines", [0, 0.5, 0.7, 1])

if good_k is None:
good_k = 0.7

hlines_kwargs.setdefault("hlines", [0, good_k, 1])

cmap = None
if isinstance(color, str):
Expand All @@ -75,7 +80,7 @@ def plot_khat(
rgba_c = cmap(color)

khats = khats if isinstance(khats, np.ndarray) else khats.values.flatten()
alphas = 0.5 + 0.2 * (khats > 0.5) + 0.3 * (khats > 1)
alphas = 0.5 + 0.2 * (khats > good_k) + 0.3 * (khats > 1)

rgba_c = vectorized_to_hex(rgba_c)

Expand Down Expand Up @@ -130,7 +135,7 @@ def plot_khat(
xmax = len(khats)

if show_bins:
bin_edges = np.array([ymin, 0.5, 0.7, 1, ymax])
bin_edges = np.array([ymin, good_k, 1, ymax])
bin_edges = bin_edges[(bin_edges >= ymin) & (bin_edges <= ymax)]
hist, _, _ = histogram(khats, bin_edges)
for idx, count in enumerate(hist):
Expand Down
10 changes: 7 additions & 3 deletions arviz/plots/backends/matplotlib/khatplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def plot_khat(
figsize,
xdata,
khats,
good_k,
kwargs,
threshold,
coord_labels,
Expand Down Expand Up @@ -61,8 +62,11 @@ def plot_khat(
backend_kwargs.setdefault("figsize", figsize)
backend_kwargs["squeeze"] = True

if good_k is None:
good_k = 0.7

hlines_kwargs = matplotlib_kwarg_dealiaser(hlines_kwargs, "hlines")
hlines_kwargs.setdefault("hlines", [0, 0.5, 0.7, 1])
hlines_kwargs.setdefault("hlines", [0, good_k, 1])
hlines_kwargs.setdefault("linestyle", [":", "-.", "--", "-"])
hlines_kwargs.setdefault("alpha", 0.7)
hlines_kwargs.setdefault("zorder", -1)
Expand Down Expand Up @@ -102,7 +106,7 @@ def plot_khat(
rgba_c = cmap(norm_fun(color))

khats = khats if isinstance(khats, np.ndarray) else khats.values.flatten()
alphas = 0.5 + 0.2 * (khats > 0.5) + 0.3 * (khats > 1)
alphas = 0.5 + 0.2 * (khats > good_k) + 0.3 * (khats > 1)
rgba_c[:, 3] = alphas
rgba_c = vectorized_to_hex(rgba_c)
kwargs["c"] = rgba_c
Expand Down Expand Up @@ -151,7 +155,7 @@ def plot_khat(
)

if show_bins:
bin_edges = np.array([ymin, 0.5, 0.7, 1, ymax])
bin_edges = np.array([ymin, good_k, 1, ymax])
bin_edges = bin_edges[(bin_edges >= ymin) & (bin_edges <= ymax)]
hist, _, _ = histogram(khats, bin_edges)
for idx, count in enumerate(hist):
Expand Down
24 changes: 19 additions & 5 deletions arviz/plots/khatplot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Pareto tail indices plot."""

import logging
import warnings

import numpy as np
from xarray import DataArray
Expand Down Expand Up @@ -40,10 +41,8 @@ def plot_khat(

Parameters
----------
khats : ELPDData or array-like
The input Pareto tail indices to be plotted. It can be an ``ELPDData`` object containing
Pareto shapes or an array. In this second case, all the values in the array are interpreted
as Pareto tail indices.
khats : ELPDData
The input Pareto tail indices to be plotted.
color : str or array_like, default "C0"
Colors of the scatter plot, if color is a str all dots will have the same color,
if it is the size of the observations, each dot will have the specified color,
Expand Down Expand Up @@ -165,15 +164,29 @@ def plot_khat(
color = "C0"

if isinstance(khats, np.ndarray):
warnings.warn(
"support for arrays will be deprecated, please use ELPDData."
"The reason for this, is that we need to know the numbers of draws"
"sampled from the posterior",
FutureWarning,
)
khats = khats.flatten()
xlabels = False
legend = False
dims = []
good_k = None
else:
if isinstance(khats, ELPDData):
good_k = khats.good_k
khats = khats.pareto_k
Comment on lines 179 to 181
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

        if isinstance(khats, ELPDData):
            good_k = khats.good_k
            khats = khats.pareto_k
        else:
            good_k = None
            warnings.warn()

This should be something like this instead. Right now, dataarrays are also valid input as they are array-like, but we have more info than we do for numpy arrays, so they are treated more similarly to elpddata input. I think this is a reason for some of the test failures.

Note: also rebase on main to avoid unrelated failures that have been fixed already.

if not isinstance(khats, DataArray):
raise ValueError("Incorrect khat data input. Check the documentation")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The valueerror is a different if altogether, not the else branch of the one above. It is reached if the input isn't one of numpy array, dataarray or elpddata, or if the elpddata for some reason doesn't store the khat data as a dataarray

good_k = None
warnings.warn(
"support for DataArrays will be deprecated, please use ELPDData."
"The reason for this, is that we need to know the numbers of draws"
"sampled from the posterior",
FutureWarning,
)

khats = get_coords(khats, coords)
dims = khats.dims
Expand All @@ -192,6 +205,7 @@ def plot_khat(
figsize=figsize,
xdata=xdata,
khats=khats,
good_k=good_k,
kwargs=kwargs,
threshold=threshold,
coord_labels=coord_labels,
Expand Down
34 changes: 24 additions & 10 deletions arviz/stats/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,8 +715,9 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
se: standard error of the elpd
p_loo: effective number of parameters
shape_warn: bool
True if the estimated shape parameter of
Pareto distribution is greater than 0.7 for one or more samples
True if the estimated shape parameter of Pareto distribution is greater than a thresold
value for one or more samples. For a sample size S, the thresold is compute as
min(1 - 1/log10(S), 0.7)
loo_i: array of pointwise predictive accuracy, only if pointwise True
pareto_k: array of Pareto shape values, only if pointwise True
scale: scale of the elpd
Expand Down Expand Up @@ -785,13 +786,15 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
log_weights += log_likelihood

warn_mg = False
if np.any(pareto_shape > 0.7):
good_k = min(1 - 1 / np.log10(n_samples), 0.7)

if np.any(pareto_shape > good_k):
warnings.warn(
"Estimated shape parameter of Pareto distribution is greater than 0.7 for "
"one or more samples. You should consider using a more robust model, this is because "
"importance sampling is less likely to work well if the marginal posterior and "
"LOO posterior are very different. This is more likely to happen with a non-robust "
"model and highly influential observations."
f"Estimated shape parameter of Pareto distribution is greater than {good_k:.2f} "
"for one or more samples. You should consider using a more robust model, this is "
"because importance sampling is less likely to work well if the marginal posterior "
"and LOO posterior are very different. This is more likely to happen with a "
"non-robust model and highly influential observations."
)
warn_mg = True

Expand All @@ -816,8 +819,17 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):

if not pointwise:
return ELPDData(
data=[loo_lppd, loo_lppd_se, p_loo, n_samples, n_data_points, warn_mg, scale],
index=["elpd_loo", "se", "p_loo", "n_samples", "n_data_points", "warning", "scale"],
data=[loo_lppd, loo_lppd_se, p_loo, n_samples, n_data_points, warn_mg, scale, good_k],
index=[
"elpd_loo",
"se",
"p_loo",
"n_samples",
"n_data_points",
"warning",
"scale",
"good_k",
],
)
if np.equal(loo_lppd, loo_lppd_i).all(): # pylint: disable=no-member
warnings.warn(
Expand All @@ -835,6 +847,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
loo_lppd_i.rename("loo_i"),
pareto_shape,
scale,
good_k,
],
index=[
"elpd_loo",
Expand All @@ -846,6 +859,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
"loo_i",
"pareto_k",
"scale",
"good_k",
],
)

Expand Down
14 changes: 8 additions & 6 deletions arviz/stats/stats_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,10 +454,9 @@ def get_log_likelihood(idata, var_name=None, single_var=True):

Pareto k diagnostic values:
{{0:>{0}}} {{1:>6}}
(-Inf, 0.5] (good) {{2:{0}d}} {{6:6.1f}}%
(0.5, 0.7] (ok) {{3:{0}d}} {{7:6.1f}}%
(0.7, 1] (bad) {{4:{0}d}} {{8:6.1f}}%
(1, Inf) (very bad) {{5:{0}d}} {{9:6.1f}}%
(-Inf, {{8:.2f}}] (good) {{2:{0}d}} {{5:6.1f}}%
({{8:.2f}}, 1] (bad) {{3:{0}d}} {{6:6.1f}}%
(1, Inf) (very bad) {{4:{0}d}} {{7:6.1f}}%
"""
SCALE_DICT = {"deviance": "deviance", "log": "elpd", "negative_log": "-elpd"}

Expand Down Expand Up @@ -488,11 +487,14 @@ def __str__(self):
base += "\n\nThere has been a warning during the calculation. Please check the results."

if kind == "loo" and "pareto_k" in self:
bins = np.asarray([-np.inf, 0.5, 0.7, 1, np.inf])
bins = np.asarray([-np.inf, self.good_k, 1, np.inf])
counts, *_ = _histogram(self.pareto_k.values, bins)
extended = POINTWISE_LOO_FMT.format(max(4, len(str(np.max(counts)))))
extended = extended.format(
"Count", "Pct.", *[*counts, *(counts / np.sum(counts) * 100)]
"Count",
"Pct.",
*[*counts, *(counts / np.sum(counts) * 100)],
self.good_k,
)
base = "\n".join([base, extended])
return base
Expand Down
Loading