Skip to content

Commit

Permalink
feature issue #1674
Browse files Browse the repository at this point in the history
  • Loading branch information
paul-buerkner committed Sep 13, 2024
1 parent f6632d7 commit 496a6d8
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 29 deletions.
13 changes: 9 additions & 4 deletions R/loo_predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ loo_predict.brmsfit <- function(object, type = c("mean", "var", "quantile"),
type <- match.arg(type)
if (is.null(psis_object)) {
message("Running PSIS to compute weights")
psis_object <- compute_loo(object, criterion = "psis", resp = resp, ...)
# run loo instead of psis to allow for moment matching
loo_object <- loo(object, resp = resp, save_psis = TRUE, ...)
psis_object <- loo_object$psis_object
}
preds <- posterior_predict(object, resp = resp, ...)
E_loo_value(preds, psis_object, type = type, probs = probs)
Expand All @@ -79,10 +81,11 @@ loo_epred.brmsfit <- function(object, type = c("mean", "var", "quantile"),
probs = 0.5, psis_object = NULL, resp = NULL,
...) {
type <- match.arg(type)
# stopifnot_resp(object, resp)
if (is.null(psis_object)) {
message("Running PSIS to compute weights")
psis_object <- compute_loo(object, criterion = "psis", resp = resp, ...)
# run loo instead of psis to allow for moment matching
loo_object <- loo(object, resp = resp, save_psis = TRUE, ...)
psis_object <- loo_object$psis_object
}
preds <- posterior_epred(object, resp = resp, ...)
E_loo_value(preds, psis_object, type = type, probs = probs)
Expand All @@ -106,7 +109,9 @@ loo_linpred.brmsfit <- function(object, type = c("mean", "var", "quantile"),
type <- match.arg(type)
if (is.null(psis_object)) {
message("Running PSIS to compute weights")
psis_object <- compute_loo(object, criterion = "psis", resp = resp, ...)
# run loo instead of psis to allow for moment matching
loo_object <- loo(object, resp = resp, save_psis = TRUE, ...)
psis_object <- loo_object$psis_object
}
preds <- posterior_linpred(object, resp = resp, ...)
E_loo_value(preds, psis_object, type = type, probs = probs)
Expand Down
44 changes: 20 additions & 24 deletions R/pp_check.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#' If \code{NULL} all draws are used. If not specified,
#' the number of posterior draws is chosen automatically.
#' Ignored if \code{draw_ids} is not \code{NULL}.
#' @param prefix The prefix of the \pkg{bayesplot} function to be applied.
#' @param prefix The prefix of the \pkg{bayesplot} function to be applied.
#' Either `"ppc"` (posterior predictive check; the default)
#' or `"ppd"` (posterior predictive distribution), the latter being the same
#' as the former except that the observed data is not shown for `"ppd"`.
Expand Down Expand Up @@ -53,7 +53,7 @@
#'
#' ## get an overview of all valid types
#' pp_check(fit, type = "xyz")
#'
#'
#' ## get a plot without the observed data
#' pp_check(fit, prefix = "ppd")
#' }
Expand All @@ -62,7 +62,7 @@
#' @export pp_check
#' @export
pp_check.brmsfit <- function(object, type, ndraws = NULL, prefix = c("ppc", "ppd"),
group = NULL, x = NULL, newdata = NULL, resp = NULL,
group = NULL, x = NULL, newdata = NULL, resp = NULL,
draw_ids = NULL, nsamples = NULL, subset = NULL, ...) {
dots <- list(...)
if (missing(type)) {
Expand Down Expand Up @@ -124,7 +124,7 @@ pp_check.brmsfit <- function(object, type, ndraws = NULL, prefix = c("ppc", "ppd
"error_scatter_avg", "error_scatter_avg_vs_x",
"intervals", "intervals_grouped",
"loo_intervals", "loo_pit", "loo_pit_overlay",
"loo_pit_qq", "loo_ribbon",
"loo_pit_qq", "loo_ribbon",
'pit_ecdf', 'pit_ecdf_grouped',
"ribbon", "ribbon_grouped",
"rootogram", "scatter_avg", "scatter_avg_grouped",
Expand All @@ -147,7 +147,7 @@ pp_check.brmsfit <- function(object, type, ndraws = NULL, prefix = c("ppc", "ppd
y <- NULL
if (prefix == "ppc") {
# y is ignored in prefix 'ppd' plots
y <- get_y(object, resp = resp, newdata = newdata, ...)
y <- get_y(object, resp = resp, newdata = newdata, ...)
}
draw_ids <- validate_draw_ids(object, draw_ids, ndraws)
pred_args <- list(
Expand All @@ -167,7 +167,7 @@ pp_check.brmsfit <- function(object, type, ndraws = NULL, prefix = c("ppc", "ppd
object, newdata = newdata, resp = resp,
re_formula = NA, check_response = TRUE, ...
)

# prepare plotting arguments
ppc_args <- list()
if (prefix == "ppc") {
Expand All @@ -185,17 +185,16 @@ pp_check.brmsfit <- function(object, type, ndraws = NULL, prefix = c("ppc", "ppd
ppc_args$x <- as.numeric(ppc_args$x)
}
}
if ("psis_object" %in% setdiff(names(formals(ppc_fun)), names(ppc_args))) {
ppc_args$psis_object <- do_call(
compute_loo, c(pred_args, criterion = "psis")
)
}
if ("lw" %in% setdiff(names(formals(ppc_fun)), names(ppc_args))) {
ppc_args$lw <- weights(
do_call(compute_loo, c(pred_args, criterion = "psis"))
)
# run loo instead of psis to allow for moment matching
loo_object <- do_call(loo, c(pred_args, save_psis = TRUE))
ppc_args$lw <- weights(loo_object$psis_object, log = TRUE)
} else if ("psis_object" %in% setdiff(names(formals(ppc_fun)), names(ppc_args))) {
# some PPCs may only support 'psis_object' but not 'lw' for whatever reason
loo_object <- do_call(loo, c(pred_args, save_psis = TRUE))
ppc_args$psis_object <- loo_object$psis_object
}

# censored responses are misleading when displayed in pp_check
bterms <- brmsterms(object$formula)
cens <- get_cens(bterms, data, resp = resp)
Expand All @@ -213,20 +212,17 @@ pp_check.brmsfit <- function(object, type, ndraws = NULL, prefix = c("ppc", "ppd
if (!is.null(ppc_args$x)) {
ppc_args$x <- ppc_args$x[take]
}
if (!is.null(ppc_args$psis_object)) {
# tidier to re-compute with subset
psis_args <- c(pred_args, criterion = "psis")
psis_args$newdata <- data[take, ]
ppc_args$psis_object <- do_call(compute_loo, psis_args)
}
if (!is.null(ppc_args$lw)) {
ppc_args$lw <- ppc_args$lw[,take]
ppc_args$lw <- ppc_args$lw[, take]
} else if (!is.null(ppc_args$psis_object)) {
# we only need the log weights so the rest can remain unchanged
ppc_args$psis_object$log_weights <- ppc_args$psis_object$log_weights[, take]
}
}

# most ... arguments are meant for the prediction function
for_pred <- names(dots) %in% names(formals(prepare_predictions.brmsfit))
ppc_args <- c(ppc_args, dots[!for_pred])

do_call(ppc_fun, ppc_args)
}
2 changes: 1 addition & 1 deletion man/pp_check.brmsfit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 496a6d8

Please sign in to comment.