From 496a6d8c4b24ddce8301b7d253689b459291ee77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Fri, 13 Sep 2024 20:38:37 +0200 Subject: [PATCH] feature issue #1674 --- R/loo_predict.R | 13 ++++++++---- R/pp_check.R | 44 +++++++++++++++++++---------------------- man/pp_check.brmsfit.Rd | 2 +- 3 files changed, 30 insertions(+), 29 deletions(-) diff --git a/R/loo_predict.R b/R/loo_predict.R index ac267b723..03dbbbd07 100644 --- a/R/loo_predict.R +++ b/R/loo_predict.R @@ -64,7 +64,9 @@ loo_predict.brmsfit <- function(object, type = c("mean", "var", "quantile"), type <- match.arg(type) if (is.null(psis_object)) { message("Running PSIS to compute weights") - psis_object <- compute_loo(object, criterion = "psis", resp = resp, ...) + # run loo instead of psis to allow for moment matching + loo_object <- loo(object, resp = resp, save_psis = TRUE, ...) + psis_object <- loo_object$psis_object } preds <- posterior_predict(object, resp = resp, ...) E_loo_value(preds, psis_object, type = type, probs = probs) @@ -79,10 +81,11 @@ loo_epred.brmsfit <- function(object, type = c("mean", "var", "quantile"), probs = 0.5, psis_object = NULL, resp = NULL, ...) { type <- match.arg(type) - # stopifnot_resp(object, resp) if (is.null(psis_object)) { message("Running PSIS to compute weights") - psis_object <- compute_loo(object, criterion = "psis", resp = resp, ...) + # run loo instead of psis to allow for moment matching + loo_object <- loo(object, resp = resp, save_psis = TRUE, ...) + psis_object <- loo_object$psis_object } preds <- posterior_epred(object, resp = resp, ...) E_loo_value(preds, psis_object, type = type, probs = probs) @@ -106,7 +109,9 @@ loo_linpred.brmsfit <- function(object, type = c("mean", "var", "quantile"), type <- match.arg(type) if (is.null(psis_object)) { message("Running PSIS to compute weights") - psis_object <- compute_loo(object, criterion = "psis", resp = resp, ...) + # run loo instead of psis to allow for moment matching + loo_object <- loo(object, resp = resp, save_psis = TRUE, ...) + psis_object <- loo_object$psis_object } preds <- posterior_linpred(object, resp = resp, ...) E_loo_value(preds, psis_object, type = type, probs = probs) diff --git a/R/pp_check.R b/R/pp_check.R index 096257dce..13e9d7751 100644 --- a/R/pp_check.R +++ b/R/pp_check.R @@ -16,7 +16,7 @@ #' If \code{NULL} all draws are used. If not specified, #' the number of posterior draws is chosen automatically. #' Ignored if \code{draw_ids} is not \code{NULL}. -#' @param prefix The prefix of the \pkg{bayesplot} function to be applied. +#' @param prefix The prefix of the \pkg{bayesplot} function to be applied. #' Either `"ppc"` (posterior predictive check; the default) #' or `"ppd"` (posterior predictive distribution), the latter being the same #' as the former except that the observed data is not shown for `"ppd"`. @@ -53,7 +53,7 @@ #' #' ## get an overview of all valid types #' pp_check(fit, type = "xyz") -#' +#' #' ## get a plot without the observed data #' pp_check(fit, prefix = "ppd") #' } @@ -62,7 +62,7 @@ #' @export pp_check #' @export pp_check.brmsfit <- function(object, type, ndraws = NULL, prefix = c("ppc", "ppd"), - group = NULL, x = NULL, newdata = NULL, resp = NULL, + group = NULL, x = NULL, newdata = NULL, resp = NULL, draw_ids = NULL, nsamples = NULL, subset = NULL, ...) { dots <- list(...) if (missing(type)) { @@ -124,7 +124,7 @@ pp_check.brmsfit <- function(object, type, ndraws = NULL, prefix = c("ppc", "ppd "error_scatter_avg", "error_scatter_avg_vs_x", "intervals", "intervals_grouped", "loo_intervals", "loo_pit", "loo_pit_overlay", - "loo_pit_qq", "loo_ribbon", + "loo_pit_qq", "loo_ribbon", 'pit_ecdf', 'pit_ecdf_grouped', "ribbon", "ribbon_grouped", "rootogram", "scatter_avg", "scatter_avg_grouped", @@ -147,7 +147,7 @@ pp_check.brmsfit <- function(object, type, ndraws = NULL, prefix = c("ppc", "ppd y <- NULL if (prefix == "ppc") { # y is ignored in prefix 'ppd' plots - y <- get_y(object, resp = resp, newdata = newdata, ...) + y <- get_y(object, resp = resp, newdata = newdata, ...) } draw_ids <- validate_draw_ids(object, draw_ids, ndraws) pred_args <- list( @@ -167,7 +167,7 @@ pp_check.brmsfit <- function(object, type, ndraws = NULL, prefix = c("ppc", "ppd object, newdata = newdata, resp = resp, re_formula = NA, check_response = TRUE, ... ) - + # prepare plotting arguments ppc_args <- list() if (prefix == "ppc") { @@ -185,17 +185,16 @@ pp_check.brmsfit <- function(object, type, ndraws = NULL, prefix = c("ppc", "ppd ppc_args$x <- as.numeric(ppc_args$x) } } - if ("psis_object" %in% setdiff(names(formals(ppc_fun)), names(ppc_args))) { - ppc_args$psis_object <- do_call( - compute_loo, c(pred_args, criterion = "psis") - ) - } if ("lw" %in% setdiff(names(formals(ppc_fun)), names(ppc_args))) { - ppc_args$lw <- weights( - do_call(compute_loo, c(pred_args, criterion = "psis")) - ) + # run loo instead of psis to allow for moment matching + loo_object <- do_call(loo, c(pred_args, save_psis = TRUE)) + ppc_args$lw <- weights(loo_object$psis_object, log = TRUE) + } else if ("psis_object" %in% setdiff(names(formals(ppc_fun)), names(ppc_args))) { + # some PPCs may only support 'psis_object' but not 'lw' for whatever reason + loo_object <- do_call(loo, c(pred_args, save_psis = TRUE)) + ppc_args$psis_object <- loo_object$psis_object } - + # censored responses are misleading when displayed in pp_check bterms <- brmsterms(object$formula) cens <- get_cens(bterms, data, resp = resp) @@ -213,20 +212,17 @@ pp_check.brmsfit <- function(object, type, ndraws = NULL, prefix = c("ppc", "ppd if (!is.null(ppc_args$x)) { ppc_args$x <- ppc_args$x[take] } - if (!is.null(ppc_args$psis_object)) { - # tidier to re-compute with subset - psis_args <- c(pred_args, criterion = "psis") - psis_args$newdata <- data[take, ] - ppc_args$psis_object <- do_call(compute_loo, psis_args) - } if (!is.null(ppc_args$lw)) { - ppc_args$lw <- ppc_args$lw[,take] + ppc_args$lw <- ppc_args$lw[, take] + } else if (!is.null(ppc_args$psis_object)) { + # we only need the log weights so the rest can remain unchanged + ppc_args$psis_object$log_weights <- ppc_args$psis_object$log_weights[, take] } } - + # most ... arguments are meant for the prediction function for_pred <- names(dots) %in% names(formals(prepare_predictions.brmsfit)) ppc_args <- c(ppc_args, dots[!for_pred]) - + do_call(ppc_fun, ppc_args) } diff --git a/man/pp_check.brmsfit.Rd b/man/pp_check.brmsfit.Rd index 737333f86..f81881fb8 100644 --- a/man/pp_check.brmsfit.Rd +++ b/man/pp_check.brmsfit.Rd @@ -35,7 +35,7 @@ If \code{NULL} all draws are used. If not specified, the number of posterior draws is chosen automatically. Ignored if \code{draw_ids} is not \code{NULL}.} -\item{prefix}{The prefix of the \pkg{bayesplot} function to be applied. +\item{prefix}{The prefix of the \pkg{bayesplot} function to be applied. Either `"ppc"` (posterior predictive check; the default) or `"ppd"` (posterior predictive distribution), the latter being the same as the former except that the observed data is not shown for `"ppd"`.}