From aee1c0d6f4a118666cb6de941c0ea1c475835db4 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Mon, 27 May 2024 11:01:10 -0300 Subject: [PATCH 1/5] use revised Pareto k threshold --- arviz/plots/backends/bokeh/khatplot.py | 13 ++++++++++--- arviz/plots/backends/matplotlib/khatplot.py | 10 ++++++++-- arviz/plots/khatplot.py | 14 ++++++++++---- arviz/stats/stats.py | 19 +++++++++++-------- 4 files changed, 39 insertions(+), 17 deletions(-) diff --git a/arviz/plots/backends/bokeh/khatplot.py b/arviz/plots/backends/bokeh/khatplot.py index 61fc3e9992..c09ff99f36 100644 --- a/arviz/plots/backends/bokeh/khatplot.py +++ b/arviz/plots/backends/bokeh/khatplot.py @@ -21,6 +21,7 @@ def plot_khat( figsize, xdata, khats, + sample_size, kwargs, threshold, coord_labels, @@ -53,7 +54,13 @@ def plot_khat( if hlines_kwargs is None: hlines_kwargs = {} - hlines_kwargs.setdefault("hlines", [0, 0.5, 0.7, 1]) + + if sample_size is None: + good_k = 0.7 + else: + good_k = min(1 - 1 / np.log10(sample_size), 0.7) + + hlines_kwargs.setdefault("hlines", [0, good_k, 1]) cmap = None if isinstance(color, str): @@ -75,7 +82,7 @@ def plot_khat( rgba_c = cmap(color) khats = khats if isinstance(khats, np.ndarray) else khats.values.flatten() - alphas = 0.5 + 0.2 * (khats > 0.5) + 0.3 * (khats > 1) + alphas = 0.5 + 0.2 * (khats > good_k) + 0.3 * (khats > 1) rgba_c = vectorized_to_hex(rgba_c) @@ -130,7 +137,7 @@ def plot_khat( xmax = len(khats) if show_bins: - bin_edges = np.array([ymin, 0.5, 0.7, 1, ymax]) + bin_edges = np.array([ymin, threshold, 1, ymax]) bin_edges = bin_edges[(bin_edges >= ymin) & (bin_edges <= ymax)] hist, _, _ = histogram(khats, bin_edges) for idx, count in enumerate(hist): diff --git a/arviz/plots/backends/matplotlib/khatplot.py b/arviz/plots/backends/matplotlib/khatplot.py index 2e52b3dd2a..f6628b3ae6 100644 --- a/arviz/plots/backends/matplotlib/khatplot.py +++ b/arviz/plots/backends/matplotlib/khatplot.py @@ -20,6 +20,7 @@ def plot_khat( figsize, xdata, khats, + sample_size, 
kwargs, threshold, coord_labels, @@ -61,8 +62,13 @@ def plot_khat( backend_kwargs.setdefault("figsize", figsize) backend_kwargs["squeeze"] = True + if sample_size is None: + good_k = 0.7 + else: + good_k = min(1 - 1 / np.log10(sample_size), 0.7) + hlines_kwargs = matplotlib_kwarg_dealiaser(hlines_kwargs, "hlines") - hlines_kwargs.setdefault("hlines", [0, 0.5, 0.7, 1]) + hlines_kwargs.setdefault("hlines", [0, good_k, 1]) hlines_kwargs.setdefault("linestyle", [":", "-.", "--", "-"]) hlines_kwargs.setdefault("alpha", 0.7) hlines_kwargs.setdefault("zorder", -1) @@ -102,7 +108,7 @@ def plot_khat( rgba_c = cmap(norm_fun(color)) khats = khats if isinstance(khats, np.ndarray) else khats.values.flatten() - alphas = 0.5 + 0.2 * (khats > 0.5) + 0.3 * (khats > 1) + alphas = 0.5 + 0.2 * (khats > good_k) + 0.3 * (khats > 1) rgba_c[:, 3] = alphas rgba_c = vectorized_to_hex(rgba_c) kwargs["c"] = rgba_c diff --git a/arviz/plots/khatplot.py b/arviz/plots/khatplot.py index bc1847f1c1..e2bc0bbdad 100644 --- a/arviz/plots/khatplot.py +++ b/arviz/plots/khatplot.py @@ -40,10 +40,8 @@ def plot_khat( Parameters ---------- - khats : ELPDData or array-like - The input Pareto tail indices to be plotted. It can be an ``ELPDData`` object containing - Pareto shapes or an array. In this second case, all the values in the array are interpreted - as Pareto tail indices. + khats : ELPDData + The input Pareto tail indices to be plotted. color : str or array_like, default "C0" Colors of the scatter plot, if color is a str all dots will have the same color, if it is the size of the observations, each dot will have the specified color, @@ -165,12 +163,19 @@ def plot_khat( color = "C0" if isinstance(khats, np.ndarray): + _log.warning( + "support for arrays will be deprecated, please use ELPDData." 
+ "The reason for this, is that we need to know the numbers of draws" + "sampled from the posterior" + ) khats = khats.flatten() xlabels = False legend = False dims = [] + sample_size = None else: if isinstance(khats, ELPDData): + sample_size = khats.n_samples khats = khats.pareto_k if not isinstance(khats, DataArray): raise ValueError("Incorrect khat data input. Check the documentation") @@ -192,6 +197,7 @@ def plot_khat( figsize=figsize, xdata=xdata, khats=khats, + sample_size=sample_size, kwargs=kwargs, threshold=threshold, coord_labels=coord_labels, diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index 9747bc5f48..3427993781 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -715,8 +715,9 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None): se: standard error of the elpd p_loo: effective number of parameters shape_warn: bool - True if the estimated shape parameter of - Pareto distribution is greater than 0.7 for one or more samples + True if the estimated shape parameter of Pareto distribution is greater than a threshold + value for one or more samples. For a sample size S, the threshold is computed as + min(1 - 1/log10(S), 0.7) loo_i: array of pointwise predictive accuracy, only if pointwise True pareto_k: array of Pareto shape values, only if pointwise True scale: scale of the elpd @@ -785,13 +786,15 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None): log_weights += log_likelihood warn_mg = False - if np.any(pareto_shape > 0.7): + good_k = min(1 - 1 / np.log10(n_samples), 0.7) + + if np.any(pareto_shape > good_k): warnings.warn( - "Estimated shape parameter of Pareto distribution is greater than 0.7 for " - "one or more samples. You should consider using a more robust model, this is because " - "importance sampling is less likely to work well if the marginal posterior and " - "LOO posterior are very different.
This is more likely to happen with a non-robust " - "model and highly influential observations." + f"Estimated shape parameter of Pareto distribution is greater than {good_k:.1f} " + "for one or more samples. You should consider using a more robust model, this is " + "because importance sampling is less likely to work well if the marginal posterior " + "and LOO posterior are very different. This is more likely to happen with a " + "non-robust model and highly influential observations." ) warn_mg = True From 5cb460ee3cc1a104cb686d719f028c9f60e7c801 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Tue, 28 May 2024 13:58:21 -0300 Subject: [PATCH 2/5] avoid duplicated computations --- arviz/plots/backends/bokeh/khatplot.py | 8 +++----- arviz/plots/backends/matplotlib/khatplot.py | 8 +++----- arviz/plots/khatplot.py | 12 +++++++----- arviz/stats/stats.py | 15 +++++++++++++-- arviz/stats/stats_utils.py | 14 ++++++++------ 5 files changed, 34 insertions(+), 23 deletions(-) diff --git a/arviz/plots/backends/bokeh/khatplot.py b/arviz/plots/backends/bokeh/khatplot.py index c09ff99f36..a424680a16 100644 --- a/arviz/plots/backends/bokeh/khatplot.py +++ b/arviz/plots/backends/bokeh/khatplot.py @@ -21,7 +21,7 @@ def plot_khat( figsize, xdata, khats, - sample_size, + good_k, kwargs, threshold, coord_labels, @@ -55,10 +55,8 @@ def plot_khat( if hlines_kwargs is None: hlines_kwargs = {} - if sample_size is None: + if good_k is None: good_k = 0.7 - else: - good_k = min(1 - 1 / np.log10(sample_size), 0.7) hlines_kwargs.setdefault("hlines", [0, good_k, 1]) @@ -137,7 +135,7 @@ def plot_khat( xmax = len(khats) if show_bins: - bin_edges = np.array([ymin, threshold, 1, ymax]) + bin_edges = np.array([ymin, good_k, 1, ymax]) bin_edges = bin_edges[(bin_edges >= ymin) & (bin_edges <= ymax)] hist, _, _ = histogram(khats, bin_edges) for idx, count in enumerate(hist): diff --git a/arviz/plots/backends/matplotlib/khatplot.py b/arviz/plots/backends/matplotlib/khatplot.py index 
f6628b3ae6..af30bc832d 100644 --- a/arviz/plots/backends/matplotlib/khatplot.py +++ b/arviz/plots/backends/matplotlib/khatplot.py @@ -20,7 +20,7 @@ def plot_khat( figsize, xdata, khats, - sample_size, + good_k, kwargs, threshold, coord_labels, @@ -62,10 +62,8 @@ def plot_khat( backend_kwargs.setdefault("figsize", figsize) backend_kwargs["squeeze"] = True - if sample_size is None: + if good_k is None: good_k = 0.7 - else: - good_k = min(1 - 1 / np.log10(sample_size), 0.7) hlines_kwargs = matplotlib_kwarg_dealiaser(hlines_kwargs, "hlines") hlines_kwargs.setdefault("hlines", [0, good_k, 1]) @@ -157,7 +155,7 @@ def plot_khat( ) if show_bins: - bin_edges = np.array([ymin, 0.5, 0.7, 1, ymax]) + bin_edges = np.array([ymin, good_k, 1, ymax]) bin_edges = bin_edges[(bin_edges >= ymin) & (bin_edges <= ymax)] hist, _, _ = histogram(khats, bin_edges) for idx, count in enumerate(hist): diff --git a/arviz/plots/khatplot.py b/arviz/plots/khatplot.py index e2bc0bbdad..6e4383ae8f 100644 --- a/arviz/plots/khatplot.py +++ b/arviz/plots/khatplot.py @@ -1,6 +1,7 @@ """Pareto tail indices plot.""" import logging +import warnings import numpy as np from xarray import DataArray @@ -163,19 +164,20 @@ def plot_khat( color = "C0" if isinstance(khats, np.ndarray): - _log.warning( + warnings.warn( "support for arrays will be deprecated, please use ELPDData." "The reason for this, is that we need to know the numbers of draws" - "sampled from the posterior" + "sampled from the posterior", + FutureWarning, ) khats = khats.flatten() xlabels = False legend = False dims = [] - sample_size = None + good_k = None else: if isinstance(khats, ELPDData): - sample_size = khats.n_samples + good_k = khats.good_k khats = khats.pareto_k if not isinstance(khats, DataArray): raise ValueError("Incorrect khat data input. 
Check the documentation") @@ -197,7 +199,7 @@ def plot_khat( figsize=figsize, xdata=xdata, khats=khats, - sample_size=sample_size, + good_k=good_k, kwargs=kwargs, threshold=threshold, coord_labels=coord_labels, diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index 3427993781..c842073dea 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -819,8 +819,17 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None): if not pointwise: return ELPDData( - data=[loo_lppd, loo_lppd_se, p_loo, n_samples, n_data_points, warn_mg, scale], - index=["elpd_loo", "se", "p_loo", "n_samples", "n_data_points", "warning", "scale"], + data=[loo_lppd, loo_lppd_se, p_loo, n_samples, n_data_points, warn_mg, scale, good_k], + index=[ + "elpd_loo", + "se", + "p_loo", + "n_samples", + "n_data_points", + "warning", + "scale", + "good_k", + ], ) if np.equal(loo_lppd, loo_lppd_i).all(): # pylint: disable=no-member warnings.warn( @@ -838,6 +847,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None): loo_lppd_i.rename("loo_i"), pareto_shape, scale, + good_k, ], index=[ "elpd_loo", @@ -849,6 +859,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None): "loo_i", "pareto_k", "scale", + "good_k", ], ) diff --git a/arviz/stats/stats_utils.py b/arviz/stats/stats_utils.py index 7a5772f920..1f0f8c9892 100644 --- a/arviz/stats/stats_utils.py +++ b/arviz/stats/stats_utils.py @@ -454,10 +454,9 @@ def get_log_likelihood(idata, var_name=None, single_var=True): Pareto k diagnostic values: {{0:>{0}}} {{1:>6}} -(-Inf, 0.5] (good) {{2:{0}d}} {{6:6.1f}}% - (0.5, 0.7] (ok) {{3:{0}d}} {{7:6.1f}}% - (0.7, 1] (bad) {{4:{0}d}} {{8:6.1f}}% - (1, Inf) (very bad) {{5:{0}d}} {{9:6.1f}}% +(-Inf, {{8:.1f}}] (good) {{2:{0}d}} {{5:6.1f}}% + ({{8:.1f}}, 1] (bad) {{3:{0}d}} {{6:6.1f}}% + (1, Inf) (very bad) {{4:{0}d}} {{7:6.1f}}% """ SCALE_DICT = {"deviance": "deviance", "log": "elpd", "negative_log": "-elpd"} @@ -488,11 +487,14 @@ def __str__(self): base += 
"\n\nThere has been a warning during the calculation. Please check the results." if kind == "loo" and "pareto_k" in self: - bins = np.asarray([-np.inf, 0.5, 0.7, 1, np.inf]) + bins = np.asarray([-np.inf, self.good_k, 1, np.inf]) counts, *_ = _histogram(self.pareto_k.values, bins) extended = POINTWISE_LOO_FMT.format(max(4, len(str(np.max(counts))))) extended = extended.format( - "Count", "Pct.", *[*counts, *(counts / np.sum(counts) * 100)] + "Count", + "Pct.", + *[*counts, *(counts / np.sum(counts) * 100)], + self.good_k, ) base = "\n".join([base, extended]) return base From ada4736de27d6f05c335a1d333da8176c038592f Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Wed, 5 Jun 2024 10:39:25 -0300 Subject: [PATCH 3/5] fix per comments --- arviz/plots/khatplot.py | 8 +++++++- arviz/stats/stats.py | 2 +- arviz/stats/stats_utils.py | 4 ++-- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/arviz/plots/khatplot.py b/arviz/plots/khatplot.py index 6e4383ae8f..15ca723733 100644 --- a/arviz/plots/khatplot.py +++ b/arviz/plots/khatplot.py @@ -180,7 +180,13 @@ def plot_khat( good_k = khats.good_k khats = khats.pareto_k if not isinstance(khats, DataArray): - raise ValueError("Incorrect khat data input. Check the documentation") + good_k = None + warnings.warn( + "support for DataArrays will be deprecated, please use ELPDData." + "The reason for this, is that we need to know the numbers of draws" + "sampled from the posterior", + FutureWarning, + ) khats = get_coords(khats, coords) dims = khats.dims diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index c842073dea..7f0cb9548b 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -790,7 +790,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None): if np.any(pareto_shape > good_k): warnings.warn( - f"Estimated shape parameter of Pareto distribution is greater than {good_k:.1f} " + f"Estimated shape parameter of Pareto distribution is greater than {good_k:.2f} " "for one or more samples. 
You should consider using a more robust model, this is " "because importance sampling is less likely to work well if the marginal posterior " "and LOO posterior are very different. This is more likely to happen with a " diff --git a/arviz/stats/stats_utils.py b/arviz/stats/stats_utils.py index 1f0f8c9892..c3e636207d 100644 --- a/arviz/stats/stats_utils.py +++ b/arviz/stats/stats_utils.py @@ -454,8 +454,8 @@ def get_log_likelihood(idata, var_name=None, single_var=True): Pareto k diagnostic values: {{0:>{0}}} {{1:>6}} -(-Inf, {{8:.1f}}] (good) {{2:{0}d}} {{5:6.1f}}% - ({{8:.1f}}, 1] (bad) {{3:{0}d}} {{6:6.1f}}% +(-Inf, {{8:.2f}}] (good) {{2:{0}d}} {{5:6.1f}}% + ({{8:.2f}}, 1] (bad) {{3:{0}d}} {{6:6.1f}}% (1, Inf) (very bad) {{4:{0}d}} {{7:6.1f}}% """ SCALE_DICT = {"deviance": "deviance", "log": "elpd", "negative_log": "-elpd"} From 42a69ec4475adb32941e40e7698f8f1e66997192 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Wed, 5 Jun 2024 14:03:07 -0300 Subject: [PATCH 4/5] fix ValueError and warning --- arviz/plots/khatplot.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/arviz/plots/khatplot.py b/arviz/plots/khatplot.py index 15ca723733..3493f69eb8 100644 --- a/arviz/plots/khatplot.py +++ b/arviz/plots/khatplot.py @@ -179,7 +179,7 @@ def plot_khat( if isinstance(khats, ELPDData): good_k = khats.good_k khats = khats.pareto_k - if not isinstance(khats, DataArray): + else: good_k = None warnings.warn( "support for DataArrays will be deprecated, please use ELPDData." @@ -187,6 +187,8 @@ def plot_khat( "sampled from the posterior", FutureWarning, ) + if not isinstance(khats, DataArray): + raise ValueError("Incorrect khat data input. 
Check the documentation") khats = get_coords(khats, coords) dims = khats.dims From 3d11015ee012ed060cc64f305a99a09258a0a154 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Wed, 5 Jun 2024 14:36:45 -0300 Subject: [PATCH 5/5] update changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7c41e75585..67e20c91e6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ## v0.x.x Unreleased ### New features +- Use revised Pareto k threshold ([2349](https://github.com/arviz-devs/arviz/pull/2349)) ### Maintenance and fixes - Ensure support with numpy 2.0 ([2321](https://github.com/arviz-devs/arviz/pull/2321)) @@ -12,6 +13,8 @@ - Fix legend overwriting issue in `plot_trace` ([2334](https://github.com/arviz-devs/arviz/pull/2334)) ### Deprecation +- Support for arrays and DataArrays in `plot_khat` has been deprecated. Only `ELPDData` will be supported in the future ([2349](https://github.com/arviz-devs/arviz/pull/2349)) + ### Documentation