diff --git a/CHANGELOG.md b/CHANGELOG.md index 7c41e75585..67e20c91e6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ## v0.x.x Unreleased ### New features +- Use revised Pareto k threshold ([2349](https://github.com/arviz-devs/arviz/pull/2349)) ### Maintenance and fixes - Ensure support with numpy 2.0 ([2321](https://github.com/arviz-devs/arviz/pull/2321)) @@ -12,6 +13,8 @@ - Fix legend overwriting issue in `plot_trace` ([2334](https://github.com/arviz-devs/arviz/pull/2334)) ### Deprecation +- Support for arrays and DataArrays in plot_khat has been deprecated. Only ELPDData will be supported in the future ([2349](https://github.com/arviz-devs/arviz/pull/2349)) + ### Documentation diff --git a/arviz/plots/backends/bokeh/khatplot.py b/arviz/plots/backends/bokeh/khatplot.py index 61fc3e9992..a424680a16 100644 --- a/arviz/plots/backends/bokeh/khatplot.py +++ b/arviz/plots/backends/bokeh/khatplot.py @@ -21,6 +21,7 @@ def plot_khat( figsize, xdata, khats, + good_k, kwargs, threshold, coord_labels, @@ -53,7 +54,11 @@ def plot_khat( if hlines_kwargs is None: hlines_kwargs = {} - hlines_kwargs.setdefault("hlines", [0, 0.5, 0.7, 1]) + + if good_k is None: + good_k = 0.7 + + hlines_kwargs.setdefault("hlines", [0, good_k, 1]) cmap = None if isinstance(color, str): @@ -75,7 +80,7 @@ def plot_khat( rgba_c = cmap(color) khats = khats if isinstance(khats, np.ndarray) else khats.values.flatten() - alphas = 0.5 + 0.2 * (khats > 0.5) + 0.3 * (khats > 1) + alphas = 0.5 + 0.2 * (khats > good_k) + 0.3 * (khats > 1) rgba_c = vectorized_to_hex(rgba_c) @@ -130,7 +135,7 @@ def plot_khat( xmax = len(khats) if show_bins: - bin_edges = np.array([ymin, 0.5, 0.7, 1, ymax]) + bin_edges = np.array([ymin, good_k, 1, ymax]) bin_edges = bin_edges[(bin_edges >= ymin) & (bin_edges <= ymax)] hist, _, _ = histogram(khats, bin_edges) for idx, count in enumerate(hist): diff --git a/arviz/plots/backends/matplotlib/khatplot.py b/arviz/plots/backends/matplotlib/khatplot.py index 
2e52b3dd2a..af30bc832d 100644 --- a/arviz/plots/backends/matplotlib/khatplot.py +++ b/arviz/plots/backends/matplotlib/khatplot.py @@ -20,6 +20,7 @@ def plot_khat( figsize, xdata, khats, + good_k, kwargs, threshold, coord_labels, @@ -61,8 +62,11 @@ def plot_khat( backend_kwargs.setdefault("figsize", figsize) backend_kwargs["squeeze"] = True + if good_k is None: + good_k = 0.7 + hlines_kwargs = matplotlib_kwarg_dealiaser(hlines_kwargs, "hlines") - hlines_kwargs.setdefault("hlines", [0, 0.5, 0.7, 1]) + hlines_kwargs.setdefault("hlines", [0, good_k, 1]) hlines_kwargs.setdefault("linestyle", [":", "-.", "--", "-"]) hlines_kwargs.setdefault("alpha", 0.7) hlines_kwargs.setdefault("zorder", -1) @@ -102,7 +106,7 @@ def plot_khat( rgba_c = cmap(norm_fun(color)) khats = khats if isinstance(khats, np.ndarray) else khats.values.flatten() - alphas = 0.5 + 0.2 * (khats > 0.5) + 0.3 * (khats > 1) + alphas = 0.5 + 0.2 * (khats > good_k) + 0.3 * (khats > 1) rgba_c[:, 3] = alphas rgba_c = vectorized_to_hex(rgba_c) kwargs["c"] = rgba_c @@ -151,7 +155,7 @@ def plot_khat( ) if show_bins: - bin_edges = np.array([ymin, 0.5, 0.7, 1, ymax]) + bin_edges = np.array([ymin, good_k, 1, ymax]) bin_edges = bin_edges[(bin_edges >= ymin) & (bin_edges <= ymax)] hist, _, _ = histogram(khats, bin_edges) for idx, count in enumerate(hist): diff --git a/arviz/plots/khatplot.py b/arviz/plots/khatplot.py index bc1847f1c1..3493f69eb8 100644 --- a/arviz/plots/khatplot.py +++ b/arviz/plots/khatplot.py @@ -1,6 +1,7 @@ """Pareto tail indices plot.""" import logging +import warnings import numpy as np from xarray import DataArray @@ -40,10 +41,8 @@ def plot_khat( Parameters ---------- - khats : ELPDData or array-like - The input Pareto tail indices to be plotted. It can be an ``ELPDData`` object containing - Pareto shapes or an array. In this second case, all the values in the array are interpreted - as Pareto tail indices. + khats : ELPDData + The input Pareto tail indices to be plotted. 
color : str or array_like, default "C0" Colors of the scatter plot, if color is a str all dots will have the same color, if it is the size of the observations, each dot will have the specified color, @@ -165,13 +164,29 @@ def plot_khat( color = "C0" if isinstance(khats, np.ndarray): + warnings.warn( + "support for arrays will be deprecated, please use ELPDData." + "The reason for this, is that we need to know the numbers of draws" + "sampled from the posterior", + FutureWarning, + ) khats = khats.flatten() xlabels = False legend = False dims = [] + good_k = None else: if isinstance(khats, ELPDData): + good_k = khats.good_k khats = khats.pareto_k + else: + good_k = None + warnings.warn( + "support for DataArrays will be deprecated, please use ELPDData." + "The reason for this, is that we need to know the numbers of draws" + "sampled from the posterior", + FutureWarning, + ) if not isinstance(khats, DataArray): raise ValueError("Incorrect khat data input. Check the documentation") @@ -192,6 +207,7 @@ def plot_khat( figsize=figsize, xdata=xdata, khats=khats, + good_k=good_k, kwargs=kwargs, threshold=threshold, coord_labels=coord_labels, diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index 9747bc5f48..7f0cb9548b 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -715,8 +715,9 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None): se: standard error of the elpd p_loo: effective number of parameters shape_warn: bool - True if the estimated shape parameter of - Pareto distribution is greater than 0.7 for one or more samples + True if the estimated shape parameter of Pareto distribution is greater than a threshold + value for one or more samples. 
For a sample size S, the threshold is computed as + min(1 - 1/log10(S), 0.7) loo_i: array of pointwise predictive accuracy, only if pointwise True pareto_k: array of Pareto shape values, only if pointwise True scale: scale of the elpd @@ -785,13 +786,15 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None): log_weights += log_likelihood warn_mg = False - if np.any(pareto_shape > 0.7): + good_k = min(1 - 1 / np.log10(n_samples), 0.7) + + if np.any(pareto_shape > good_k): warnings.warn( - "Estimated shape parameter of Pareto distribution is greater than 0.7 for " - "one or more samples. You should consider using a more robust model, this is because " - "importance sampling is less likely to work well if the marginal posterior and " - "LOO posterior are very different. This is more likely to happen with a non-robust " - "model and highly influential observations." + f"Estimated shape parameter of Pareto distribution is greater than {good_k:.2f} " + "for one or more samples. You should consider using a more robust model, this is " + "because importance sampling is less likely to work well if the marginal posterior " + "and LOO posterior are very different. This is more likely to happen with a " + "non-robust model and highly influential observations." 
) warn_mg = True @@ -816,8 +819,17 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None): if not pointwise: return ELPDData( - data=[loo_lppd, loo_lppd_se, p_loo, n_samples, n_data_points, warn_mg, scale], - index=["elpd_loo", "se", "p_loo", "n_samples", "n_data_points", "warning", "scale"], + data=[loo_lppd, loo_lppd_se, p_loo, n_samples, n_data_points, warn_mg, scale, good_k], + index=[ + "elpd_loo", + "se", + "p_loo", + "n_samples", + "n_data_points", + "warning", + "scale", + "good_k", + ], ) if np.equal(loo_lppd, loo_lppd_i).all(): # pylint: disable=no-member warnings.warn( @@ -835,6 +847,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None): loo_lppd_i.rename("loo_i"), pareto_shape, scale, + good_k, ], index=[ "elpd_loo", @@ -846,6 +859,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None): "loo_i", "pareto_k", "scale", + "good_k", ], ) diff --git a/arviz/stats/stats_utils.py b/arviz/stats/stats_utils.py index 7a5772f920..c3e636207d 100644 --- a/arviz/stats/stats_utils.py +++ b/arviz/stats/stats_utils.py @@ -454,10 +454,9 @@ def get_log_likelihood(idata, var_name=None, single_var=True): Pareto k diagnostic values: {{0:>{0}}} {{1:>6}} -(-Inf, 0.5] (good) {{2:{0}d}} {{6:6.1f}}% - (0.5, 0.7] (ok) {{3:{0}d}} {{7:6.1f}}% - (0.7, 1] (bad) {{4:{0}d}} {{8:6.1f}}% - (1, Inf) (very bad) {{5:{0}d}} {{9:6.1f}}% +(-Inf, {{8:.2f}}] (good) {{2:{0}d}} {{5:6.1f}}% + ({{8:.2f}}, 1] (bad) {{3:{0}d}} {{6:6.1f}}% + (1, Inf) (very bad) {{4:{0}d}} {{7:6.1f}}% """ SCALE_DICT = {"deviance": "deviance", "log": "elpd", "negative_log": "-elpd"} @@ -488,11 +487,14 @@ def __str__(self): base += "\n\nThere has been a warning during the calculation. Please check the results." 
if kind == "loo" and "pareto_k" in self: - bins = np.asarray([-np.inf, 0.5, 0.7, 1, np.inf]) + bins = np.asarray([-np.inf, self.good_k, 1, np.inf]) counts, *_ = _histogram(self.pareto_k.values, bins) extended = POINTWISE_LOO_FMT.format(max(4, len(str(np.max(counts))))) extended = extended.format( - "Count", "Pct.", *[*counts, *(counts / np.sum(counts) * 100)] + "Count", + "Pct.", + *[*counts, *(counts / np.sum(counts) * 100)], + self.good_k, ) base = "\n".join([base, extended]) return base