Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use random effects as predictors in formula via 're' terms #1687

Open
wants to merge 18 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,7 @@ export(ranef)
export(rasym_laplace)
export(rbeta_binomial)
export(rdirichlet)
export(re)
export(read_csv_as_stanfit)
export(recompile_model)
export(reloo)
Expand Down
9 changes: 6 additions & 3 deletions R/brmsframe.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ brmsframe.brmsterms <- function(x, data, frame = NULL, basis = NULL, ...) {
# this must be a multivariate model
stopifnot(is.list(frame))
x$frame <- frame
x$frame$re <- subset2(x$frame$re, resp = x$resp)
}
data <- subset_data(data, x)
x$frame$resp <- frame_resp(x, data = data)
Expand All @@ -51,6 +50,10 @@ brmsframe.brmsterms <- function(x, data, frame = NULL, basis = NULL, ...) {
basis = basis$nlpars[[nlp]], ...
)
}
# If this is a multivariate model, retain only the subset of random effects
# belonging to the current response variable. Subsetting is performed here
# rather than earlier to allow for correct validation of 're' terms in stan_sp.
x$frame$re <- subset2(x$frame$re, resp = x$resp)
class(x) <- c("brmsframe", class(x))
x
}
Expand Down Expand Up @@ -79,8 +82,8 @@ brmsframe.btl <- function(x, data, frame = list(), basis = NULL, ...) {
x$frame$sp <- frame_sp(x, data = data)
x$frame$gp <- frame_gp(x, data = data)
x$frame$ac <- frame_ac(x, data = data)
# only store the ranefs of this specific linear formula
x$frame$re <- subset2(frame$re, ls = check_prefix(x))
# only keep the ranefs of this specific linear formula
x$frame$re <- subset2(x$frame$re, ls = check_prefix(x))
class(x) <- c("bframel", class(x))
# these data_ functions may require the outputs of the corresponding
# frame_ functions (but not vice versa) and are thus evaluated last
Expand Down
7 changes: 3 additions & 4 deletions R/brmsterms.R
Original file line number Diff line number Diff line change
Expand Up @@ -434,23 +434,22 @@ terms_cs <- function(formula) {

# extract special effects terms
# @param formula a formula that is searched for special effects terms
#   of the types returned by all_sp_types()
# @return NULL if no special effects terms are found; otherwise a formula
#   of the found terms without intercept, carrying the attributes
#   'int', 'uni_mo', 'uni_me', 'uni_mi', 'uni_re', and 'allvars'
terms_sp <- function(formula) {
  out <- find_terms(formula, all_sp_types(), complete = FALSE)
  if (!length(out)) {
    return(NULL)
  }
  # collect the matching calls per type before 'out' is turned into a formula
  uni_mo <- get_matches_expr(regex_sp("mo"), out)
  uni_me <- get_matches_expr(regex_sp("me"), out)
  uni_mi <- get_matches_expr(regex_sp("mi"), out)
  uni_re <- get_matches_expr(regex_sp("re"), out)
  # remove the intercept as it is handled separately
  out <- str2formula(c("0", out))
  attr(out, "int") <- FALSE
  attr(out, "uni_mo") <- uni_mo
  attr(out, "uni_me") <- uni_me
  attr(out, "uni_mi") <- uni_mi
  attr(out, "uni_re") <- uni_re
  # formula built from all_vars() applied to the special effects formula
  attr(out, "allvars") <- str2formula(all_vars(out))
  out
}

Expand Down
14 changes: 13 additions & 1 deletion R/formula-re.R
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,7 @@ get_re.btl <- function(x, ...) {
# id: ID of the group-level effect
# group: name of the grouping factor
# gn: number of the grouping term within the respective formula
# gtype: type of the grouping term: 'gr' or 'mm'
# coef: name of the group-level effect
# cn: number of the effect within the ID
# resp: name of the response variable
Expand Down Expand Up @@ -710,7 +711,7 @@ frame_re <- function(bterms, data, old_levels = NULL) {

empty_reframe <- function() {
out <- data.frame(
id = numeric(0), group = character(0), gn = numeric(0),
id = numeric(0), group = character(0), gn = numeric(0), gtype = character(0),
coef = character(0), cn = numeric(0), resp = character(0),
dpar = character(0), nlpar = character(0), ggn = numeric(0),
cor = logical(0), type = character(0), form = character(0),
Expand All @@ -736,6 +737,17 @@ is.reframe <- function(x) {
inherits(x, "reframe")
}

# locate the rows of one reframe that match the rows of another reframe
# @param x the reframe whose rows are searched
# @param y the reframe providing the reference values to match against
# @return an integer vector with the positions of the matching rows of 'x'
which_rows_reframe <- function(x, y) {
  stopifnot(is.reframe(x), is.reframe(y))
  # matching is based on the columns that identify a row in reframes
  key_cols <- c("group", "coef", "resp", "dpar", "nlpar")
  reference <- y[key_cols]
  which_rows(x, ls = reference)
}

# extract names of all grouping variables
get_group_vars <- function(x, ...) {
UseMethod("get_group_vars")
Expand Down
123 changes: 105 additions & 18 deletions R/formula-sp.R
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,67 @@ mo <- function(x, id = NA) {
out
}

#' Group-level effects as predictors in \pkg{brms} Models
#'
#' Set up a group-level predictor term in \pkg{brms}: group-level effects
#' that are defined in one part of the model can be used as predictors in
#' another part of the model. The arguments of this function are not
#' evaluated -- its sole purpose is to help with setting up a model.
#'
#' @param gr Name of the grouping factor of the group-level effect
#'   that should be used as predictor.
#' @param coef Optional name of the coefficient of the group-level effect.
#'   Defaults to \code{"Intercept"}.
#' @param resp Optional name of the response variable of the
#'   group-level effect.
#' @param dpar Optional name of the distributional parameter of the
#'   group-level effect.
#' @param nlpar Optional name of the non-linear parameter of the
#'   group-level effect.
#'
#' @return An object of class \code{re_term} (also inheriting from
#'   \code{sp_term}) storing the specification of the term.
#'
#' @seealso \code{\link{brmsformula}}
#'
#' @examples
#' \dontrun{
#' # the group-level intercept of 'AY' for parameter 'ult' serves
#' # as predictor for the residual standard deviation 'sigma';
#' # multiplying by 1000 reduces the scale of 'ult' to roughly unity
#' bform <- bf(
#'   cum ~ 1000 * ult * (1 - exp(-(dev/theta)^omega)),
#'   ult ~ 1 + (1|AY), omega ~ 1, theta ~ 1,
#'   sigma ~ re(AY, nlpar = "ult"),
#'   nl = TRUE
#' )
#' bprior <- c(
#'   prior(normal(5, 1), nlpar = "ult"),
#'   prior(normal(1, 2), nlpar = "omega"),
#'   prior(normal(45, 10), nlpar = "theta"),
#'   prior(normal(0, 0.5), dpar = "sigma")
#' )
#'
#' fit <- brm(
#'   bform, data = loss,
#'   family = gaussian(),
#'   prior = bprior,
#'   control = list(adapt_delta = 0.9),
#'   chains = 2
#' )
#' summary(fit)
#'
#' # display how sigma varies as a function of the AY levels
#' conditional_effects(fit, "AY", dpar = "sigma", re_formula = NULL)
#' }
#'
#' @export
re <- function(gr, coef = "Intercept", resp = "", dpar = "", nlpar = "") {
  # 'gr' is captured unevaluated and stored in deparsed form
  term <- as_one_character(deparse_no_string(substitute(gr)))
  coef <- as_one_character(coef)
  resp <- as_one_character(resp)
  dpar <- as_one_character(dpar)
  nlpar <- as_one_character(nlpar)
  # the deparsed matched call is stored as the label of the term
  label <- deparse0(match.call())
  out <- nlist(term, coef, resp, dpar, nlpar, label)
  structure(out, class = c("re_term", "sp_term"))
}

# find variable names for which to keep NAs
vars_keep_na <- function(x, ...) {
UseMethod("vars_keep_na")
Expand Down Expand Up @@ -352,8 +413,7 @@ get_sp_vars <- function(x, type) {
}

# gather information of special effects terms
# @param x either a formula or a list containing an element "sp"
# @param data data frame containing the monotonic variables
# @param x a formula, brmsterms, or brmsframe object
# @return a data.frame with one row per special term
# TODO: refactor to store in long format to avoid several list columns?
frame_sp <- function(x, data) {
Expand All @@ -368,7 +428,7 @@ frame_sp <- function(x, data) {
out <- data.frame(term = colnames(mm), stringsAsFactors = FALSE)
out$coef <- rename(out$term)
calls_cols <- c(paste0("calls_", all_sp_types()), "joint_call")
list_cols <- c("vars_mi", "idx_mi", "idx2_mi", "ids_mo", "Imo")
list_cols <- c("vars_mi", "idx_mi", "idx2_mi", "ids_mo", "Imo", "reframe")
for (col in c(calls_cols, list_cols)) {
out[[col]] <- vector("list", nrow(out))
}
Expand All @@ -377,7 +437,7 @@ frame_sp <- function(x, data) {
for (i in seq_rows(out)) {
# prepare mo terms
take_mo <- grepl_expr(regex_sp("mo"), terms_split[[i]])
if (sum(take_mo)) {
if (any(take_mo)) {
out$calls_mo[[i]] <- terms_split[[i]][take_mo]
nmo <- length(out$calls_mo[[i]])
out$Imo[[i]] <- (kmo + 1):(kmo + nmo)
Expand All @@ -394,15 +454,15 @@ frame_sp <- function(x, data) {
}
# prepare me terms
take_me <- grepl_expr(regex_sp("me"), terms_split[[i]])
if (sum(take_me)) {
if (any(take_me)) {
out$calls_me[[i]] <- terms_split[[i]][take_me]
# remove 'I' (identity) function calls that
# were used solely to separate formula terms
out$calls_me[[i]] <- gsub("^I\\(", "(", out$calls_me[[i]])
}
# prepare mi terms
take_mi <- grepl_expr(regex_sp("mi"), terms_split[[i]])
if (sum(take_mi)) {
if (any(take_mi)) {
mi_parts <- terms_split[[i]][take_mi]
out$calls_mi[[i]] <- get_matches_expr(regex_sp("mi"), mi_parts)
out$vars_mi[[i]] <- out$idx_mi[[i]] <- rep(NA, length(out$calls_mi[[i]]))
Expand All @@ -416,6 +476,40 @@ frame_sp <- function(x, data) {
# do it like terms_resp to ensure correct matching
out$vars_mi[[i]] <- gsub("\\.|_", "", make.names(out$vars_mi[[i]]))
}
take_re <- grepl_expr(regex_sp("re"), terms_split[[i]])
if (any(take_re)) {
re_parts <- terms_split[[i]][take_re]
out$calls_re[[i]] <- get_matches_expr(regex_sp("re"), re_parts)
out$reframe[[i]] <- vector("list", length(out$calls_re[[i]]))
for (j in seq_along(out$calls_re[[i]])) {
re_call <- out$calls_re[[i]][[j]]
re_term <- eval2(re_call)
if (!is.null(x$frame$re)) {
stopifnot(is.reframe(x$frame$re))
cols <- c("coef", "resp", "dpar", "nlpar")
rf <- subset2(x$frame$re, group = re_term$term, ls = re_term[cols])
# Ideally we should check here if the required re term can be found.
# However this will lead to errors in post-processing even if the
# re terms are not actually evaluated. See prepare_predictions_sp
# for more details. The necessary pre-processing validity check
# is instead done in stan_sp.
# if (!NROW(rf)) {
# stop2("Cannot find varying coefficients belonging to ", re_call, ".")
# }
# there should theoretically never be more than one matching row
stopifnot(NROW(rf) <= 1L)
if (isTRUE(rf$gtype == "mm")) {
stop2("Multimembership terms are not yet supported by 're'.")
}
out$reframe[[i]][[j]] <- rf
}
}
if (!isNULL(out$reframe[[i]])) {
out$reframe[[i]] <- Reduce(rbind, out$reframe[[i]])
} else {
out$reframe[[i]] <- empty_reframe()
}
}
has_sp_calls <- grepl_expr(regex_sp(all_sp_types()), terms_split[[i]])
sp_calls <- sub("^I\\(", "(", terms_split[[i]][has_sp_calls])
out$joint_call[[i]] <- paste0(sp_calls, collapse = " * ")
Expand Down Expand Up @@ -540,17 +634,6 @@ sp_model_matrix <- function(formula, data, types = all_sp_types(), ...) {
out
}

# build a formula of the variables used in special effects terms
# @param ... special effects term calls; they are combined via c()
#   and evaluated one by one via eval2
# @return the output of str2formula applied to the unique variable
#   names found in the 'term', 'sdx', and 'gr' elements of the terms
sp_fake_formula <- function(...) {
  calls <- c(...)
  vars_per_call <- lapply(seq_along(calls), function(k) {
    sp_term <- eval2(calls[[k]])
    all_vars(c(sp_term$term, sp_term$sdx, sp_term$gr))
  })
  str2formula(unique(unlist(vars_per_call)))
}

# extract an me variable
get_me_values <- function(term, data) {
term <- get_sp_term(term)
Expand Down Expand Up @@ -625,7 +708,7 @@ get_sp_term <- function(term) {

# all effects which fall under the 'sp' category of brms
# @return a character vector with the names of all special effects types
all_sp_types <- function() {
  c("mo", "me", "mi", "re")
}

# classes used to set up special effects terms
Expand All @@ -644,3 +727,7 @@ is.me_term <- function(x) {
# check whether an object inherits from class 'mi_term'
# @param x any R object
# @return TRUE if 'x' inherits from 'mi_term' and FALSE otherwise
is.mi_term <- function(x) {
  inherits(x, what = "mi_term", which = FALSE)
}

# check whether an object inherits from class 're_term'
# @param x any R object
# @return TRUE if 'x' inherits from 're_term' and FALSE otherwise
is.re_term <- function(x) {
  inherits(x, what = "re_term", which = FALSE)
}
6 changes: 6 additions & 0 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ find_elements <- function(x, ..., ls = list(), fun = '%in%') {

# find rows of 'x' matching columns passed via 'ls' and '...'
# similar to 'find_elements' but for matrix like objects
# TODO: rename ls and fun to .ls and .fun to prevent name clashing
find_rows <- function(x, ..., ls = list(), fun = '%in%') {
x <- as.data.frame(x)
if (!nrow(x)) {
Expand All @@ -162,6 +163,11 @@ find_rows <- function(x, ..., ls = list(), fun = '%in%') {
out
}

# positions of the rows of 'x' that are matched by find_rows
# all arguments are forwarded to find_rows unchanged
# @return the output of which() applied to the result of find_rows
which_rows <- function(x, ..., ls = list(), fun = '%in%') {
  is_match <- find_rows(x, ..., ls = ls, fun = fun)
  which(is_match)
}

# subset 'x' using arguments passed via 'ls' and '...'
subset2 <- function(x, ..., ls = list(), fun = '%in%') {
x[find_rows(x, ..., ls = ls, fun = fun), , drop = FALSE]
Expand Down
7 changes: 7 additions & 0 deletions R/predictor.R
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,13 @@ predictor_sp <- function(prep, i) {
for (j in seq_along(sp[["idxl"]])) {
eval_list[[names(sp[["idxl"]])[j]]] <- p(sp[["idxl"]][[j]], i, row = FALSE)
}
for (j in seq_along(sp[["r"]])) {
# r is not subsetted here since subsetting is handled via Jr
# the advantage of this approach is a reduced memory requirement
# as only the draws per level instead of per observation need to be stored
eval_list[[paste0("r_", j)]] <- sp[["r"]][[j]]
eval_list[[paste0("Jr_", j)]] <- p(sp[["Jr"]][[j]], i)
}
for (j in seq_along(sp[["Csp"]])) {
eval_list[[paste0("Csp_", j)]] <- p(sp[["Csp"]][[j]], i, row = FALSE)
}
Expand Down
Loading
Loading