913: Functions for PIT histograms #949

Open: wants to merge 10 commits into base: main
20 changes: 6 additions & 14 deletions NAMESPACE
@@ -21,9 +21,9 @@ S3method(get_metrics,forecast_point)
S3method(get_metrics,forecast_quantile)
S3method(get_metrics,forecast_sample)
S3method(get_metrics,scores)
S3method(get_pit,default)
S3method(get_pit,forecast_quantile)
S3method(get_pit,forecast_sample)
S3method(get_pit_histogram,default)
S3method(get_pit_histogram,forecast_quantile)
S3method(get_pit_histogram,forecast_sample)
S3method(head,forecast)
S3method(print,forecast)
S3method(score,default)
@@ -56,7 +56,7 @@ export(get_forecast_counts)
export(get_forecast_unit)
export(get_metrics)
export(get_pairwise_comparisons)
export(get_pit)
export(get_pit_histogram)
export(interval_coverage)
export(is_forecast)
export(is_forecast_binary)
@@ -72,13 +72,12 @@ export(mad_sample)
export(new_forecast)
export(overprediction_quantile)
export(overprediction_sample)
export(pit_sample)
export(pit_histogram_sample)
export(plot_correlations)
export(plot_forecast_counts)
export(plot_heatmap)
export(plot_interval_coverage)
export(plot_pairwise_comparisons)
export(plot_pit)
export(plot_quantile_coverage)
export(plot_wis)
export(quantile_score)
@@ -115,9 +114,7 @@ importFrom(checkmate,assert_vector)
importFrom(checkmate,check_atomic_vector)
importFrom(checkmate,check_function)
importFrom(checkmate,check_matrix)
importFrom(checkmate,check_number)
importFrom(checkmate,check_numeric)
importFrom(checkmate,check_set_equal)
importFrom(checkmate,check_vector)
importFrom(checkmate,test_atomic_vector)
importFrom(checkmate,test_list)
@@ -138,6 +135,7 @@ importFrom(data.table,as.data.table)
importFrom(data.table,copy)
importFrom(data.table,data.table)
importFrom(data.table,dcast)
importFrom(data.table,fcase)
importFrom(data.table,is.data.table)
importFrom(data.table,melt)
importFrom(data.table,nafill)
@@ -150,16 +148,13 @@ importFrom(data.table,setorderv)
importFrom(ggplot2,.data)
importFrom(ggplot2,`%+replace%`)
importFrom(ggplot2,aes)
importFrom(ggplot2,after_stat)
importFrom(ggplot2,coord_cartesian)
importFrom(ggplot2,coord_flip)
importFrom(ggplot2,element_blank)
importFrom(ggplot2,element_line)
importFrom(ggplot2,element_text)
importFrom(ggplot2,facet_grid)
importFrom(ggplot2,facet_wrap)
importFrom(ggplot2,geom_col)
importFrom(ggplot2,geom_histogram)
importFrom(ggplot2,geom_line)
importFrom(ggplot2,geom_linerange)
importFrom(ggplot2,geom_polygon)
@@ -175,7 +170,6 @@ importFrom(ggplot2,scale_fill_gradient)
importFrom(ggplot2,scale_fill_gradient2)
importFrom(ggplot2,scale_fill_manual)
importFrom(ggplot2,scale_y_continuous)
importFrom(ggplot2,stat)
importFrom(ggplot2,theme)
importFrom(ggplot2,theme_light)
importFrom(ggplot2,theme_minimal)
@@ -187,9 +181,7 @@ importFrom(purrr,partial)
importFrom(scoringRules,crps_sample)
importFrom(scoringRules,dss_sample)
importFrom(scoringRules,logs_sample)
importFrom(stats,as.formula)
importFrom(stats,cor)
importFrom(stats,density)
importFrom(stats,mad)
importFrom(stats,median)
importFrom(stats,na.omit)
3 changes: 2 additions & 1 deletion NEWS.md
@@ -20,6 +20,7 @@ of our [original](https://doi.org/10.48550/arXiv.2205.07090) `scoringutils` pape
- Users can now also use their own scoring rules (making use of the `metrics` argument, which takes in a named list of functions). Default scoring rules can be accessed using the function `get_metrics()`, which is a generic with S3 methods for each forecast type. It returns a named list of scoring rules suitable for the respective forecast object. For example, you could call `get_metrics(example_quantile)`. Column names of scores in the output of `score()` correspond to the names of the scoring rules (i.e. the names of the functions in the list of metrics).
- Instead of supplying arguments to `score()` to manipulate individual scoring rules users should now manipulate the metric list being supplied using `purrr::partial()` and `select_metric()`. See `?score()` for more information.
- the CRPS is now reported as decomposition into dispersion, overprediction and underprediction.
- functionality to calculate the Probability Integral Transform (PIT) has been deprecated and replaced by functionality to calculate PIT histograms, using the `get_pit_histogram()` function; as part of this change, nonrandomised PITs can now be calculated for count data, and this is done by default

### Creating a forecast object
- The `as_forecast_<type>()` functions create a forecast object and validate it. They also allow users to rename/specify required columns and specify the forecast unit in a single step, taking over the functionality of `set_forecast_unit()` in most cases. See `?as_forecast()` for more information.
@@ -73,7 +74,6 @@ of our [original](https://doi.org/10.48550/arXiv.2205.07090) `scoringutils` pape
- Renamed `interval_coverage_quantile()` to `interval_coverage()`.
- "range" was consistently renamed to "interval_range" in the code. The "range"-format (which was mostly used internally) was renamed to "interval"-format
- Renamed `correlation()` to `get_correlations()` and `plot_correlation()` to `plot_correlations()`
- `pit()` was renamed to `get_pit()` and converted to an S3 method.

### Deleted functions
- Removed abs_error and squared_error from the package in favour of `Metrics::ae` and `Metrics::se`.`get_duplicate_forecasts()` now sorts outputs according to the forecast unit, making it easier to spot duplicates. In addition, there is a `counts` option that allows the user to display the number of duplicates for each forecast unit, rather than the raw duplicated rows.
@@ -84,6 +84,7 @@ of our [original](https://doi.org/10.48550/arXiv.2205.07090) `scoringutils` pape
- Removed `interval_coverage_sample()` as users are now expected to convert to a quantile format first before scoring.
- Function `set_forecast_unit()` was deleted. Instead there is now a `forecast_unit` argument in `as_forecast_<type>()` as well as in `get_duplicate_forecasts()`.
- Removed `interval_coverage_dev_quantile()`. Users can still access the difference between nominal and actual interval coverage using `get_coverage()`.
- `pit()`, `pit_sample()` and `plot_pit()` have been removed and replaced by functionality to create PIT histograms (`pit_histogram_sample()` and `get_pit_histogram()`)

### Function changes
- `bias_quantile()` changed the way it handles forecasts where the median is missing: The median is now imputed by linear interpolation between the innermost quantiles. Previously, we imputed the median by simply taking the mean of the innermost quantiles.
40 changes: 34 additions & 6 deletions R/class-forecast-quantile.R
@@ -175,27 +175,55 @@ get_metrics.forecast_quantile <- function(x, select = NULL, exclude = NULL, ...)
}


#' @rdname get_pit
#' @rdname get_pit_histogram
#' @importFrom stats na.omit
#' @importFrom data.table `:=` as.data.table
#' @export
get_pit.forecast_quantile <- function(forecast, by, ...) {
get_pit_histogram.forecast_quantile <- function(forecast, num_bins = "auto",
breaks = NULL, by, ...) {
forecast <- clean_forecast(forecast, copy = TRUE, na.omit = TRUE)
forecast <- as.data.table(forecast)
present_quantiles <- unique(c(0, forecast$quantile_level, 1))
present_quantiles <- round(present_quantiles, 10)

if (!is.null(breaks)) {
quantiles <- unique(c(0, breaks, 1))
} else if (is.null(num_bins) || num_bins == "auto") {
quantiles <- present_quantiles
} else {
quantiles <- seq(0, 1, 1 / num_bins)
}
## avoid rounding errors
quantiles <- round(quantiles, 10)
Review comment (Contributor):
Do we have to round twice? (we round in line 187 as well)

diffs <- round(diff(quantiles), 10)

if (length(setdiff(quantiles, present_quantiles)) > 0) {
cli::cli_warn(c(
"!" = "Some requested quantiles are missing in the forecast.",
"i" = "The PIT histogram will be based on the quantiles present in the forecast." # nolint
))
}

forecast <- forecast[quantile_level %in% quantiles]
forecast[, quantile_coverage := (observed <= predicted)]

quantile_coverage <-
forecast[, .(quantile_coverage = mean(quantile_coverage)),
by = c(unique(c(by, "quantile_level")))]
quantile_coverage <- quantile_coverage[

bins <- sprintf("[%s,%s)", quantiles[-length(quantiles)], quantiles[-1])
mids <- (quantiles[-length(quantiles)] + quantiles[-1]) / 2

pit_histogram <- quantile_coverage[
order(quantile_level),
.(
quantile_level = c(quantile_level, 1),
pit_value = diff(c(0, quantile_coverage, 1))
density = diff(c(0, quantile_coverage, 1)) / diffs,
bin = bins,
mid = mids
),
by = c(get_forecast_unit(quantile_coverage))
]
return(quantile_coverage[])
return(pit_histogram[])
}
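
For readers following the diff, the arithmetic behind the quantile-based histogram can be sketched in base R. This is only an illustration on made-up numbers, not the package code: `quantile_level` and `coverage` are hypothetical inputs, where `coverage` stands for the share of observations at or below each predicted quantile.

```r
# Hypothetical toy inputs: empirical coverage at each quantile level, i.e. the
# share of observations with observed <= predicted quantile.
quantile_level <- c(0.25, 0.5, 0.75)
coverage <- c(0.2, 0.5, 0.9)

# Bin edges are the quantile levels padded with 0 and 1.
edges <- c(0, quantile_level, 1)
widths <- diff(edges)

# Share of observations falling into each bin, divided by the bin width so
# that the bars integrate to 1 even when bins have unequal widths.
bin_mass <- diff(c(0, coverage, 1))
density <- bin_mass / widths

pit_hist <- data.frame(
  bin = sprintf("[%s,%s)", edges[-length(edges)], edges[-1]),
  mid = (edges[-length(edges)] + edges[-1]) / 2,
  density = density
)
print(pit_hist)
```

A perfectly calibrated forecast would give a density of 1 in every bin. Dividing by the bin width is what distinguishes the new `density` column from the old `pit_value` differences.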


52 changes: 43 additions & 9 deletions R/class-forecast-sample.R
@@ -165,31 +165,65 @@ get_metrics.forecast_sample <- function(x, select = NULL, exclude = NULL, ...) {
}


#' @rdname get_pit
#' @importFrom stats na.omit
#' @rdname get_pit_histogram
#' @param integers How to handle integer forecasts (count data). This is based
#' on methods described in Czado et al. (2007). If "nonrandom" (default), the
#' function will use the non-randomised PIT method. If "random", it will use
#' the randomised PIT method. If "ignore", it will treat integer forecasts as
#' if they were continuous.
#' @importFrom data.table `:=` as.data.table dcast
#' @inheritParams pit_sample
#' @inheritParams pit_histogram_sample
#' @seealso [pit_histogram_sample()]
#' @export
get_pit.forecast_sample <- function(forecast, by, n_replicates = 100, ...) {
get_pit_histogram.forecast_sample <- function(forecast, num_bins = "auto",
breaks = NULL, by, integers = c(
"nonrandom", "random", "ignore"
), n_replicates = 100, ...) {
integers <- match.arg(integers)

forecast <- clean_forecast(forecast, copy = TRUE, na.omit = TRUE)
forecast <- as.data.table(forecast)

assert_number(n_replicates)

if (!is.null(breaks)) {
quantiles <- unique(c(0, breaks, 1))
} else if (is.null(num_bins) || num_bins == "auto") {
quantiles <- seq(0, 1, 1 / 10)
} else {
quantiles <- seq(0, 1, 1 / num_bins)
}

# if prediction type is not quantile, calculate PIT values based on samples
forecast_wide <- data.table::dcast(
forecast,
... ~ paste0("InternalSampl_", sample_id),
value.var = "predicted"
)

pit <- forecast_wide[, .(pit_value = pit_sample(
observed = observed,
predicted = as.matrix(.SD)
)),
bins <- sprintf("[%s,%s)", quantiles[-length(quantiles)], quantiles[-1])
mids <- (quantiles[-length(quantiles)] + quantiles[-1]) / 2

if (missing(n_replicates) && integers != "random") {
n_replicates <- NULL
}

pit_histogram <- forecast_wide[, .(
density = pit_histogram_sample(
observed = observed,
predicted = as.matrix(.SD),
quantiles = quantiles,
integers = integers,
n_replicates = n_replicates
),
bin = bins,
mid = mids
),
by = by,
.SDcols = grepl("InternalSampl_", names(forecast_wide), fixed = TRUE)
]

return(pit[])
return(pit_histogram[])
}
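
The body of `pit_histogram_sample()` is not part of this excerpt. As a rough sketch of what the "nonrandom" option refers to, the non-randomised PIT for count data can be computed from predictive samples as below. The helper `nonrandom_pit_density()` is hypothetical and written only for illustration; it assumes one row of `predicted` per observation and may differ from the package implementation in its edge-case handling.

```r
# Hypothetical helper, NOT the package's pit_histogram_sample().
# observed:  vector of n integer observations
# predicted: n x m matrix of predictive samples (one row per observation)
# quantiles: increasing bin edges from 0 to 1
nonrandom_pit_density <- function(observed, predicted, quantiles) {
  # Empirical predictive CDF at the observation (P_x) and just below it.
  p_x <- rowMeans(predicted <= observed)
  p_xm1 <- rowMeans(predicted <= observed - 1)

  # Conditional PIT CDF F(u) per observation, averaged over observations.
  mean_cdf <- vapply(quantiles, function(u) {
    f_u <- ifelse(
      p_x > p_xm1,
      # Linear between P_{x-1} and P_x, clipped to [0, 1].
      pmin(1, pmax(0, (u - p_xm1) / (p_x - p_xm1))),
      # Degenerate case (no predictive mass on the observation): step function.
      as.numeric(u >= p_x & u > 0)
    )
    mean(f_u)
  }, numeric(1))

  # Probability mass per bin divided by bin width.
  diff(mean_cdf) / diff(quantiles)
}

# One observation (x = 1) with four predictive samples.
nonrandom_pit_density(
  observed = 1,
  predicted = matrix(c(0, 1, 1, 2), nrow = 1),
  quantiles = seq(0, 1, by = 0.25)
)
```

Unlike the randomised variant, this needs no `n_replicates`, which is consistent with the diff setting `n_replicates` to `NULL` unless `integers = "random"`.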


2 changes: 1 addition & 1 deletion R/get-coverage.R
@@ -112,7 +112,7 @@ get_coverage <- function(forecast, by = "model") {
#' Default is "model".
#' @return ggplot object with a plot of interval coverage
#' @importFrom ggplot2 ggplot scale_colour_manual scale_fill_manual .data
#' facet_wrap facet_grid geom_polygon geom_line
#' facet_wrap facet_grid geom_polygon geom_line xlab ylab
#' @importFrom checkmate assert_subset
#' @importFrom data.table dcast
#' @export
53 changes: 53 additions & 0 deletions R/get-pit-histogram.R
@@ -0,0 +1,53 @@
#' @title Probability integral transformation histogram
#'
#' @description
#' Generate a Probability Integral Transformation (PIT) histogram for
#' validated forecast objects.
#'
#' @inherit score params
#' @param num_bins The number of bins in the PIT histogram, default is "auto".
#' When `num_bins == "auto"`, a histogram will be created with either 10 bins,
#' or with a bin for each available quantile in case the forecasts are in a
#' quantile-based format.
#' You can control the number of bins by supplying a number. This is fine for
#' sample-based PIT histograms, but may fail for quantile-based formats. In this
#' case it is preferred to supply explicit break points using the `breaks`
#' argument.
#' @param breaks Numeric vector with the break points for the bins in the
#' PIT histogram. This is preferred when creating a PIT histogram based on
#' quantile-based data. Default is `NULL` and breaks will be determined by
#' `num_bins`. If `breaks` is used, `num_bins` will be ignored.
#' @param by Character vector with the columns according to which the
#' PIT values shall be grouped. If you e.g. have the columns 'model' and
#' 'location' in the input data and want to have a PIT histogram for
#' every model and location, specify `by = c("model", "location")`.
#' @inheritParams pit_histogram_sample
#' @return A data.table with density values for each bin in the PIT histogram.
#' @examples
#' example <- as_forecast_sample(example_sample_continuous)
#' result <- get_pit_histogram(example, by = "model")
#'
#' # example with quantile data
#' example <- as_forecast_quantile(example_quantile)
#' result <- get_pit_histogram(example, by = "model")
#' @export
#' @keywords scoring
#' @references
#' Sebastian Funk, Anton Camacho, Adam J. Kucharski, Rachel Lowe,
#' Rosalind M. Eggo, W. John Edmunds (2019) Assessing the performance of
#' real-time epidemic forecasts: A case study of Ebola in the Western Area
#' region of Sierra Leone, 2014-15, \doi{10.1371/journal.pcbi.1006785}
get_pit_histogram <- function(forecast, num_bins = "auto", breaks = NULL, by,
...) {
UseMethod("get_pit_histogram")
}


#' @rdname get_pit_histogram
#' @importFrom cli cli_abort
#' @export
get_pit_histogram.default <- function(forecast, num_bins, breaks, by, ...) {
cli_abort(c(
"!" = "The input needs to be a valid forecast object represented as quantiles or samples." # nolint
))
}
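
The precedence between `num_bins` and `breaks` described in the documentation above can be summarised in a small stand-alone sketch. `resolve_bin_edges()` and its `auto_edges` argument are hypothetical names used only here; the methods in this PR inline the same logic rather than calling a shared helper.

```r
# Hypothetical helper mirroring the documented precedence:
# `breaks` wins over `num_bins`; "auto" falls back to `auto_edges`
# (10 equal bins for samples, the available quantile levels for quantiles).
resolve_bin_edges <- function(num_bins = "auto", breaks = NULL,
                              auto_edges = seq(0, 1, by = 0.1)) {
  if (!is.null(breaks)) {
    edges <- c(0, breaks, 1)
  } else if (is.null(num_bins) || identical(num_bins, "auto")) {
    edges <- c(0, auto_edges, 1)
  } else {
    edges <- seq(0, 1, length.out = num_bins + 1)
  }
  # Round before deduplicating so that edges built by seq() can be matched
  # against quantile levels despite floating point error.
  sort(unique(round(edges, 10)))
}

resolve_bin_edges(breaks = c(0.25, 0.5)) # explicit breaks, padded with 0 and 1
resolve_bin_edges(num_bins = 4)          # four equal-width bins
length(resolve_bin_edges())              # "auto": 10 bins, so 11 edges
```

Sorting the edges is an addition of this sketch; the code in the diff relies on `breaks` being supplied in increasing order.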