Skip to content

Commit

Permalink
fix failures
Browse files Browse the repository at this point in the history
  • Loading branch information
jgabry committed Feb 7, 2024
1 parent df94bee commit 68700d3
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 21 deletions.
2 changes: 1 addition & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ S3method(.compute_point_estimate,default)
S3method(.compute_point_estimate,matrix)
S3method(.ndraws,default)
S3method(.ndraws,matrix)
S3method(.thin_draws,default)
S3method(.thin_draws,matrix)
S3method(.thin_draws,numeric)
S3method(E_loo,default)
Expand Down Expand Up @@ -151,7 +152,6 @@ export(relative_eff)
export(scrps)
export(sis)
export(stacking_weights)
export(thin_draws.default)
export(tis)
export(waic)
export(waic.array)
Expand Down
36 changes: 18 additions & 18 deletions R/loo_subsample.R
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ loo_subsample.function <-
cores <- loo_cores(cores)

checkmate::assert_choice(loo_approximation, choices = loo_approximation_choices(), null.ok = FALSE)
checkmate::assert_int(loo_approximation_draws, lower = 1, upper = n_draws(draws), null.ok = TRUE)
checkmate::assert_int(loo_approximation_draws, lower = 1, upper = .ndraws(draws), null.ok = TRUE)
checkmate::assert_choice(estimator, choices = estimator_choices())

.llgrad <- .llhess <- NULL
Expand Down Expand Up @@ -234,7 +234,7 @@ loo_subsample.function <-
.llgrad = .llgrad,
.llhess = .llhess,
data_dim = dim(data),
ndraws = n_draws(draws))
ndraws = .ndraws(draws))
loo_ss
}

Expand Down Expand Up @@ -537,20 +537,20 @@ elpd_loo_approximation <- function(.llfun, data, draws, cores, loo_approximation
if (loo_approximation == "none") return(rep(1L,N))

if (loo_approximation %in% c("tis", "sis")) {
draws <- thin_draws(draws, loo_approximation_draws)
draws <- .thin_draws(draws, loo_approximation_draws)
is_values <- suppressWarnings(loo.function(.llfun, data = data, draws = draws, is_method = loo_approximation))
return(is_values$pointwise[, "elpd_loo"])
}

if (loo_approximation == "waic") {
draws <- thin_draws(draws, loo_approximation_draws)
draws <- .thin_draws(draws, loo_approximation_draws)
waic_full_obj <- waic.function(.llfun, data = data, draws = draws)
return(waic_full_obj$pointwise[,"elpd_waic"])
}

# Compute the lpd or log p(y_i|y_{-i})
if (loo_approximation == "lpd") {
draws <- thin_draws(draws, loo_approximation_draws)
draws <- .thin_draws(draws, loo_approximation_draws)
lpds <- compute_lpds(N, data, draws, .llfun, cores)
return(lpds) # Use only the lpd
}
Expand All @@ -561,8 +561,8 @@ elpd_loo_approximation <- function(.llfun, data, draws, cores, loo_approximation
loo_approximation == "waic_grad_marginal" |
loo_approximation == "waic_hess") {

draws <- thin_draws(draws, loo_approximation_draws)
point_est <- compute_point_estimate(draws)
draws <- .thin_draws(draws, loo_approximation_draws)
point_est <- .compute_point_estimate(draws)
lpds <- compute_lpds(N, data, point_est, .llfun, cores)
if (loo_approximation == "plpd") return(lpds) # Use only the lpd
}
Expand All @@ -572,7 +572,7 @@ elpd_loo_approximation <- function(.llfun, data, draws, cores, loo_approximation
loo_approximation == "waic_hess") {
checkmate::assert_true(!is.null(.llgrad))

point_est <- compute_point_estimate(draws)
point_est <- .compute_point_estimate(draws)
# Compute the lpds
lpds <- compute_lpds(N, data, point_est, .llfun, cores)

Expand Down Expand Up @@ -629,7 +629,7 @@ elpd_loo_approximation <- function(.llfun, data, draws, cores, loo_approximation
#' @param draws A draws object with draws from the posterior.
#' @return A 1 by P matrix with point estimates from a draws object.
.compute_point_estimate <- function(draws) {
UseMethod("compute_point_estimate")
UseMethod(".compute_point_estimate")
}
#' @rdname dot-compute_point_estimate
#' @export
Expand All @@ -639,7 +639,7 @@ elpd_loo_approximation <- function(.llfun, data, draws, cores, loo_approximation
#' @rdname dot-compute_point_estimate
#' @export
.compute_point_estimate.default <- function(draws) {
stop("compute_point_estimate() has not been implemented for objects of class '", class(draws), "'")
stop(".compute_point_estimate() has not been implemented for objects of class '", class(draws), "'")
}

#' Thin a draws object
Expand All @@ -654,27 +654,27 @@ elpd_loo_approximation <- function(.llfun, data, draws, cores, loo_approximation
#' @param loo_approximation_draws The number of posterior draws to return (ie after thinning).
#' @return A thinned draws object.
.thin_draws <- function(draws, loo_approximation_draws) {
UseMethod("thin_draws")
UseMethod(".thin_draws")
}
#' @rdname dot-thin_draws
#' @export
.thin_draws.matrix <- function(draws, loo_approximation_draws) {
if (is.null(loo_approximation_draws)) return(draws)
checkmate::assert_int(loo_approximation_draws, lower = 1, upper = n_draws(draws), null.ok = TRUE)
S <- n_draws(draws)
checkmate::assert_int(loo_approximation_draws, lower = 1, upper = .ndraws(draws), null.ok = TRUE)
S <- .ndraws(draws)
idx <- 1:loo_approximation_draws * S %/% loo_approximation_draws
draws <- draws[idx, , drop = FALSE]
draws
}
#' @rdname dot-thin_draws
#' @export
.thin_draws.numeric <- function(draws, loo_approximation_draws) {
thin_draws.matrix(as.matrix(draws), loo_approximation_draws)
.thin_draws.matrix(as.matrix(draws), loo_approximation_draws)
}
#' @rdname dot-thin_draws
#' @export
thin_draws.default <- function(draws, loo_approximation_draws) {
stop("thin_draws() has not been implemented for objects of class '", class(draws), "'")
.thin_draws.default <- function(draws, loo_approximation_draws) {
stop(".thin_draws() has not been implemented for objects of class '", class(draws), "'")
}


Expand All @@ -689,7 +689,7 @@ thin_draws.default <- function(draws, loo_approximation_draws) {
#' @param x A draws object with posterior draws.
#' @return An integer with the number of draws.
.ndraws <- function(x) {
UseMethod("n_draws")
UseMethod(".ndraws")
}
#' @rdname dot-ndraws
#' @export
Expand All @@ -699,7 +699,7 @@ thin_draws.default <- function(draws, loo_approximation_draws) {
#' @rdname dot-ndraws
#' @export
.ndraws.default <- function(x) {
stop("n_draws() has not been implemented for objects of class '", class(x), "'")
stop(".ndraws() has not been implemented for objects of class '", class(x), "'")
}

## Subsampling -----
Expand Down
4 changes: 2 additions & 2 deletions man/dot-thin_draws.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 68700d3

Please sign in to comment.