diff --git a/DESCRIPTION b/DESCRIPTION index b52246bfb..5377be989 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: scoringutils Title: Utilities for Scoring and Assessing Predictions -Version: 1.1.4 +Version: 1.1.5 Language: en-GB Authors@R: c( person(given = "Nikos", diff --git a/NAMESPACE b/NAMESPACE index 437b12630..f7ba26377 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -42,6 +42,7 @@ export(quantile_score) export(sample_to_quantile) export(score) export(se_mean_sample) +export(set_forecast_unit) export(squared_error) export(summarise_scores) export(summarize_scores) diff --git a/NEWS.md b/NEWS.md index e8fe3b6ca..d81135d29 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,11 @@ +# scoringutils 1.1.5 + +## Feature updates +- Added a new function, `set_forecast_unit()`, that allows the user to set the forecast unit manually. The function removes all columns that are not relevant for uniquely identifying a single forecast. If not done manually, `scoringutils` attempts to determine the unit of a single forecast automatically by simply assuming that all column names are +relevant to determine the forecast unit. This can lead to unexpected behaviour, so setting the forecast unit explicitly can help make the code easier to debug and easier to read (see issue #268). +When used as part of a workflow, `set_forecast_unit()` can be directly piped into `check_forecasts()` to +check everything is in order.
+ # scoringutils 1.1.4 ## Package updates diff --git a/R/convenience-functions.R b/R/convenience-functions.R index ba529258e..f21aeefc8 100644 --- a/R/convenience-functions.R +++ b/R/convenience-functions.R @@ -156,10 +156,6 @@ transform_forecasts <- function(data, } - - - - #' @title Log transformation with an additive shift #' #' @description Function that shifts a value by some offset and then applies the @@ -183,7 +179,6 @@ transform_forecasts <- function(data, #' # nolint #' @keywords check-forecasts #' @examples -#' #' log_shift(1:10) #' log_shift(0:9, offset = 1) #' @@ -193,9 +188,7 @@ transform_forecasts <- function(data, #' offset = 1 #' ) -log_shift <- function(x, - offset = 0, - base = exp(1)) { +log_shift <- function(x, offset = 0, base = exp(1)) { if (any(x < 0, na.rm = TRUE)) { w <- paste("Detected input values < 0.") @@ -209,3 +202,53 @@ log_shift <- function(x, } log(x + offset, base = base) } + + +#' @title Set unit of a single forecast manually +#' +#' @description Helper function to set the unit of a single forecast (i.e. the +#' combination of columns that uniquely define a single forecast) manually. +#' This simple function keeps the columns specified in `forecast_unit` (plus +#' additional protected columns, e.g. for true values, predictions or quantile +#' levels) and removes duplicate rows. +#' If not done manually, `scoringutils` attempts to determine the unit +#' of a single forecast automatically by simply assuming that all column names +#' are relevant to determine the forecast unit. This may lead to unexpected +#' behaviour, so setting the forecast unit explicitly can help make the code +#' easier to debug and easier to read. When used as part of a workflow, +#' `set_forecast_unit()` can be directly piped into `check_forecasts()` to +#' check everything is in order. +#' +#' @inheritParams score +#' @param forecast_unit Character vector with the names of the columns that +#' uniquely identify a single forecast. 
+#' @return A data.table with only those columns kept that are relevant to +#' scoring or denote the unit of a single forecast as specified by the user. +#' +#' @importFrom data.table ':=' is.data.table copy +#' @export +#' @keywords data-handling +#' @examples +#' set_forecast_unit( +#' example_quantile, +#' c("location", "target_end_date", "target_type", "horizon", "model") +#' ) + +set_forecast_unit <- function(data, forecast_unit) { + # coerce first: the `.SD` subsetting below requires a data.table + data <- data.table::as.data.table(data) + datacols <- colnames(data) + missing_cols <- forecast_unit[!(forecast_unit %in% datacols)] + if (length(missing_cols) > 0) { + warning( + "Column(s) '", + toString(missing_cols), # join names; warning() would fuse them + "' are not columns of the data and will be ignored." + ) + forecast_unit <- intersect(forecast_unit, datacols) + } + + keep_cols <- c(get_protected_columns(data), forecast_unit) + out <- unique(data[, .SD, .SDcols = keep_cols])[] + return(out) +} diff --git a/R/utils.R b/R/utils.R index ae3e6f382..029a70da8 100644 --- a/R/utils.R +++ b/R/utils.R @@ -242,12 +242,7 @@ get_target_type <- function(data) { get_forecast_unit <- function(data, prediction_type) { - protected_columns <- c( - "prediction", "true_value", "sample", "quantile", "upper", "lower", - "pit_value", - "range", "boundary", available_metrics(), - grep("coverage_", names(data), fixed = TRUE, value = TRUE) - ) + protected_columns <- get_protected_columns(data) if (!missing(prediction_type)) { if (prediction_type == "quantile") { protected_columns <- setdiff(protected_columns, "sample") @@ -256,3 +251,33 @@ get_forecast_unit <- function(data, prediction_type) { forecast_unit <- setdiff(colnames(data), protected_columns) return(forecast_unit) } + + +#' @title Get protected columns from a data frame +#' +#' @description Helper function to get the names of all columns in a data frame +#' that are protected columns.
+#' +#' @inheritParams check_forecasts +#' +#' @return A character vector with the names of protected columns in the data +#' +#' @keywords internal + +get_protected_columns <- function(data) { + + datacols <- colnames(data) + protected_columns <- c( + "prediction", "true_value", "sample", "quantile", "upper", "lower", + "pit_value", "range", "boundary", available_metrics(), + grep("coverage_", names(data), fixed = TRUE, value = TRUE) + ) + + # keep only the protected columns actually present in `data` (in the data's column order) + protected_columns <- intersect( + datacols, + protected_columns + ) + + return(protected_columns) +} diff --git a/man/get_protected_columns.Rd b/man/get_protected_columns.Rd new file mode 100644 index 000000000..79af67598 --- /dev/null +++ b/man/get_protected_columns.Rd @@ -0,0 +1,44 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/utils.R +\name{get_protected_columns} +\alias{get_protected_columns} +\title{Get protected columns from a data frame} +\usage{ +get_protected_columns(data) +} +\arguments{ +\item{data}{A data.frame or data.table with the predictions and observations. +For scoring using \code{\link[=score]{score()}}, the following columns need to be present: +\itemize{ +\item \code{true_value} - the true observed values +\item \code{prediction} - predictions or predictive samples for one +true value. (You only don't need to provide a prediction column if +you want to score quantile forecasts in a wide range format.)} +For scoring integer and continuous forecasts a \code{sample} column is needed: +\itemize{ +\item \code{sample} - an index to identify the predictive samples in the +prediction column generated by one model for one true value.
Only +necessary for continuous and integer forecasts, not for +binary predictions.} +For scoring predictions in a quantile-format forecast you should provide +a column called \code{quantile}: +\itemize{ +\item \code{quantile}: quantile to which the prediction corresponds +} + +In addition a \code{model} column is suggested and if not present this will be +flagged and added to the input data with all forecasts assigned as an +"unspecified model"). + +You can check the format of your data using \code{\link[=check_forecasts]{check_forecasts()}} and there +are examples for each format (\link{example_quantile}, \link{example_continuous}, +\link{example_integer}, and \link{example_binary}).} +} +\value{ +A character vector with the names of protected columns in the data +} +\description{ +Helper function to get the names of all columns in a data frame +that are protected columns. +} +\keyword{internal} diff --git a/man/log_shift.Rd b/man/log_shift.Rd index ddeb26ca3..66e0ffb53 100644 --- a/man/log_shift.Rd +++ b/man/log_shift.Rd @@ -26,7 +26,6 @@ natural logarithm to it. The output is computed as log(x + offset) } \examples{ - log_shift(1:10) log_shift(0:9, offset = 1) diff --git a/man/set_forecast_unit.Rd b/man/set_forecast_unit.Rd new file mode 100644 index 000000000..8482a27d7 --- /dev/null +++ b/man/set_forecast_unit.Rd @@ -0,0 +1,64 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/convenience-functions.R +\name{set_forecast_unit} +\alias{set_forecast_unit} +\title{Set unit of a single forecast manually} +\usage{ +set_forecast_unit(data, forecast_unit) +} +\arguments{ +\item{data}{A data.frame or data.table with the predictions and observations. +For scoring using \code{\link[=score]{score()}}, the following columns need to be present: +\itemize{ +\item \code{true_value} - the true observed values +\item \code{prediction} - predictions or predictive samples for one +true value. 
(You only don't need to provide a prediction column if +you want to score quantile forecasts in a wide range format.)} +For scoring integer and continuous forecasts a \code{sample} column is needed: +\itemize{ +\item \code{sample} - an index to identify the predictive samples in the +prediction column generated by one model for one true value. Only +necessary for continuous and integer forecasts, not for +binary predictions.} +For scoring predictions in a quantile-format forecast you should provide +a column called \code{quantile}: +\itemize{ +\item \code{quantile}: quantile to which the prediction corresponds +} + +In addition a \code{model} column is suggested and if not present this will be +flagged and added to the input data with all forecasts assigned as an +"unspecified model"). + +You can check the format of your data using \code{\link[=check_forecasts]{check_forecasts()}} and there +are examples for each format (\link{example_quantile}, \link{example_continuous}, +\link{example_integer}, and \link{example_binary}).} + +\item{forecast_unit}{Character vector with the names of the columns that +uniquely identify a single forecast.} +} +\value{ +A data.table with only those columns kept that are relevant to +scoring or denote the unit of a single forecast as specified by the user. +} +\description{ +Helper function to set the unit of a single forecast (i.e. the +combination of columns that uniquely define a single forecast) manually. +This simple function keeps the columns specified in \code{forecast_unit} (plus +additional protected columns, e.g. for true values, predictions or quantile +levels) and removes duplicate rows. +If not done manually, \code{scoringutils} attempts to determine the unit +of a single forecast automatically by simply assuming that all column names +are relevant to determine the forecast unit. This may lead to unexpected +behaviour, so setting the forecast unit explicitly can help make the code +easier to debug and easier to read. 
When used as part of a workflow, +\code{set_forecast_unit()} can be directly piped into \code{check_forecasts()} to +check everything is in order. +} +\examples{ +set_forecast_unit( + example_quantile, + c("location", "target_end_date", "target_type", "horizon", "model") +) +} +\keyword{data-handling} diff --git a/tests/testthat/test-convenience-functions.R b/tests/testthat/test-convenience-functions.R index b4bdecc58..ad7a40550 100644 --- a/tests/testthat/test-convenience-functions.R +++ b/tests/testthat/test-convenience-functions.R @@ -36,3 +36,49 @@ test_that("function transform_forecasts works", { expect_equal(four$prediction, compare) }) + + + +test_that("function set_forecast_unit() works", { + + # some columns in the example data have duplicated information. So we can remove + # these and see whether the result stays the same. + + scores1 <- suppressMessages(score(example_quantile)) + scores1 <- scores1[order(location, target_end_date, target_type, horizon, model), ] + + ex2 <- set_forecast_unit( + example_quantile, + c("location", "target_end_date", "target_type", "horizon", "model") + ) + scores2 <- suppressMessages(score(ex2)) + scores2 <- scores2[order(location, target_end_date, target_type, horizon, model), ] + + expect_equal(scores1$interval_score, scores2$interval_score) +}) + + +test_that("function set_forecast_unit() gives warning when column is not there", { + expect_warning( + set_forecast_unit( + example_quantile, + c("location", "target_end_date", "target_type", "horizon", "model", "test") + ), + regexp = "'test' are not columns of the data" + ) +}) + + +test_that("function get_forecast_unit() and set_forecast_unit() work together", { + + fu_set <- c("location", "target_end_date", "target_type", "horizon", "model") + + ex <- set_forecast_unit( + example_binary, + fu_set + ) + + fu_get <- get_forecast_unit(ex) + expect_equal(fu_get, fu_set) +}) + diff --git a/tests/testthat/test-utils.R b/tests/testthat/test-utils.R new file mode 100644 index 000000000..2831715a5 --- /dev/null +++
b/tests/testthat/test-utils.R @@ -0,0 +1,36 @@ +test_that("get_protected columns returns the correct result", { + + data <- example_quantile + manual <- c( + "prediction", "true_value", "sample", "quantile", "upper", "lower", + "pit_value", + "range", "boundary", available_metrics(), + grep("coverage_", names(data), fixed = TRUE, value = TRUE) + ) + manual <- intersect(manual, colnames(example_quantile)) + auto <- get_protected_columns(data) + expect_equal(sort(manual), sort(auto)) + + + data <- example_binary + manual <- c( + "prediction", "true_value", "sample", "quantile", "upper", "lower", + "pit_value", + "range", "boundary", available_metrics(), + grep("coverage_", names(data), fixed = TRUE, value = TRUE) + ) + manual <- intersect(manual, colnames(example_binary)) + auto <- get_protected_columns(data) + expect_equal(sort(manual), sort(auto)) + + data <- example_continuous + manual <- c( + "prediction", "true_value", "sample", "quantile", "upper", "lower", + "pit_value", + "range", "boundary", available_metrics(), + grep("coverage_", names(data), fixed = TRUE, value = TRUE) + ) + manual <- intersect(manual, colnames(example_continuous)) + auto <- get_protected_columns(data) + expect_equal(sort(manual), sort(auto)) +})