Skip to content

Commit

Permalink
Merge pull request #1354 from n-kall/priorsense
Browse files Browse the repository at this point in the history
Add priorsense compatibility
  • Loading branch information
paul-buerkner authored Jun 24, 2024
2 parents 11ecd1b + fcf9084 commit 34ac048
Show file tree
Hide file tree
Showing 6 changed files with 176 additions and 2 deletions.
5 changes: 3 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ Package: brms
Encoding: UTF-8
Type: Package
Title: Bayesian Regression Models using 'Stan'
Version: 2.21.5
Date: 2024-05-27
Version: 2.21.6
Date: 2024-06-06
Authors@R:
c(person("Paul-Christian", "Bürkner", email = "[email protected]",
role = c("aut", "cre")),
Expand Down Expand Up @@ -52,6 +52,7 @@ Suggests:
emmeans (>= 1.4.2),
cmdstanr (>= 0.5.0),
projpred (>= 2.0.0),
priorsense (>= 1.0.0),
shinystan (>= 2.4.0),
splines2 (>= 0.5.0),
RWiener,
Expand Down
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,9 @@ S3method(prior_predictor,brmsframe)
S3method(prior_predictor,default)
S3method(prior_predictor,mvbrmsframe)
S3method(prior_summary,brmsfit)
S3method(priorsense::create_priorsense_data,brmsfit)
S3method(priorsense::log_lik_draws,brmsfit)
S3method(priorsense::log_prior_draws,brmsfit)
S3method(projpred::get_refmodel,brmsfit)
S3method(psis,brmsfit)
S3method(r_eff_log_lik,"function")
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
### New Features

* Add method `loo_epred` thanks to Aki Vehtari. (#1641)
* Add priorsense support via `create_priorsense_data.brmsfit` thanks to Noa Kallioinen. (#1354)

### Bug Fixes

Expand Down
95 changes: 95 additions & 0 deletions R/priorsense.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
#' Prior sensitivity: Create priorsense data
#'
#' The \code{create_priorsense_data.brmsfit} method can be used to
#' create the data structure needed by the \pkg{priorsense} package
#' for performing power-scaling sensitivity analysis. This method is
#' called automatically when performing powerscaling via
#' \code{\link[priorsense:powerscale]{powerscale}} or other related
#' functions, so you will rarely need to call it manually yourself.
#'
#' @param x A \code{brmsfit} object.
#' @param ... Currently unused.
#'
#' @return A \code{priorsense_data} object to be used in conjunction
#' with the \pkg{priorsense} package.
#'
#' @examples
#' \dontrun{
#' # fit a model with non-uniform priors
#' fit <- brm(rating ~ treat + period + carry,
#' data = inhaler, family = sratio(),
#' prior = set_prior("normal(0, 0.5)"))
#' summary(fit)
#'
#' # The following code requires the 'priorsense' package to be installed:
#' library(priorsense)
#'
#' # perform power-scaling of the prior
#' powerscale(fit, alpha = 1.5, component = "prior")
#'
#' # perform power-scaling sensitivity checks
#' powerscale_sensitivity(fit)
#'
#' # create power-scaling sensitivity plots (for one variable)
#' powerscale_plot_dens(fit, variable = "b_treat")
#' }
#'
#' @exportS3Method priorsense::create_priorsense_data brmsfit
create_priorsense_data.brmsfit <- function(x, ...) {
priorsense::create_priorsense_data(
x = get_draws_ps(x),
fit = x,
log_prior = log_prior_draws.brmsfit(x),
log_lik = log_lik_draws.brmsfit(x),
log_prior_fn = log_prior_draws.brmsfit,
log_lik_fn = log_lik_draws.brmsfit,
log_ratio_fn = powerscale_log_ratio,
...
)
}

#' @exportS3Method priorsense::log_lik_draws
log_lik_draws.brmsfit <- function(x) {
log_lik <- log_lik(x)
log_lik <- posterior::as_draws_array(log_lik)
nvars <- nvariables(log_lik)
posterior::variables(log_lik) <- paste0("log_lik[", seq_len(nvars), "]")
log_lik
}

#' @exportS3Method priorsense::log_prior_draws
log_prior_draws.brmsfit <- function(x, log_prior_name = "lprior") {
posterior::subset_draws(
posterior::as_draws_array(x),
variable = log_prior_name
)
}

get_draws_ps <- function(x, variable = NULL, regex = FALSE,
log_prior_name = "lprior") {
excluded_variables <- c(log_prior_name, "lp__")
draws <- posterior::as_draws_df(x, regex = regex)
if (is.null(variable)) {
# remove unnecessary variables
variable <- posterior::variables(x)
variable <- variable[!(variable %in% excluded_variables)]
draws <- posterior::subset_draws(draws, variable = variable)
}
draws
}

powerscale_log_ratio <- function(draws, fit, alpha, component_fn) {
component_draws <- component_fn(fit)
component_draws <- rowsums_draws(component_draws)
component_draws * (alpha - 1)
}

rowsums_draws <- function(x) {
posterior::draws_array(
sum = rowSums(
posterior::as_draws_array(x),
dims = 2
),
.nchains = posterior::nchains(x)
)
}
47 changes: 47 additions & 0 deletions man/create_priorsense_data.brmsfit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 27 additions & 0 deletions tests/testthat/tests.priorsense.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
context("Tests for priorsense support")

skip_on_cran()

require(priorsense)

fit1 <- rename_pars(brms:::brmsfit_example1)

test_that("create_priorsense_data returns expected output structure", {
psd1 <- create_priorsense_data(fit1)
expect_s3_class(psd$draws, "draws")
expect_s3_class(psd$fit, "brmsfit")
expect_s3_class(psd$log_lik, "draws")
expect_s3_class(psd$log_prior, "draws")
expect_true(is.function(psd$log_lik_fn))
expect_true(is.function(psd$log_prior_fn))
expect_true(is.function(psd$log_ratio_fn))
})

test_that("powerscale returns without error", {
expect_no_error(powerscale(fit1, component = "prior", alpha = 0.8))
expect_no_error(powerscale(fit1, component = "likelihood", alpha = 1.1))
})

test_that("powerscale_sensitivity returns without error", {
expect_no_error(powerscale_sensitivity(fit1))
})

0 comments on commit 34ac048

Please sign in to comment.