Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: scoringutils
Title: Utilities for Scoring and Assessing Predictions
Version: 1.1.4
Version: 1.1.5
Language: en-GB
Authors@R: c(
person(given = "Nikos",
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ export(quantile_score)
export(sample_to_quantile)
export(score)
export(se_mean_sample)
export(set_forecast_unit)
export(squared_error)
export(summarise_scores)
export(summarize_scores)
Expand Down
8 changes: 8 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
# scoringutils 1.1.5

## Feature updates
- Added a new function, `set_forecast_unit()` that allows the user to set the forecast unit manually. The function removes all columns that are not relevant for uniquely identifying a single forecast. If not done manually, `scoringutils` attempts to determine the unit of a single automatically by simply assuming that all column names are
relevant to determine the forecast unit. This can lead to unexpected behaviour, so setting the forecast unit explicitly can help make the code easier to debug and easier to read (see issue #268).
When used as part of a workflow, `set_forecast_unit()` can be directly piped into `check_forecasts()` to
check everything is in order.

# scoringutils 1.1.4

## Package updates
Expand Down
59 changes: 51 additions & 8 deletions R/convenience-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,6 @@ transform_forecasts <- function(data,
}






#' @title Log transformation with an additive shift
#'
#' @description Function that shifts a value by some offset and then applies the
Expand All @@ -183,7 +179,6 @@ transform_forecasts <- function(data,
#' <https://www.medrxiv.org/content/10.1101/2023.01.23.23284722v1> # nolint
#' @keywords check-forecasts
#' @examples
#'
#' log_shift(1:10)
#' log_shift(0:9, offset = 1)
#'
Expand All @@ -193,9 +188,7 @@ transform_forecasts <- function(data,
#' offset = 1
#' )

log_shift <- function(x,
offset = 0,
base = exp(1)) {
log_shift <- function(x, offset = 0, base = exp(1)) {

if (any(x < 0, na.rm = TRUE)) {
w <- paste("Detected input values < 0.")
Expand All @@ -209,3 +202,53 @@ log_shift <- function(x,
}
log(x + offset, base = base)
}


#' @title Set unit of a single forecast manually
#'
#' @description Helper function to set the unit of a single forecast (i.e. the
#' combination of columns that uniquely define a single forecast) manually.
#' This simple function keeps the columns specified in `forecast_unit` (plus
#' additional protected columns, e.g. for true values, predictions or quantile
#' levels) and removes duplicate rows.
#' If not done manually, `scoringutils` attempts to determine the unit
#' of a single forecast automatically by simply assuming that all column names
#' are relevant to determine the forecast unit. This may lead to unexpected
#' behaviour, so setting the forecast unit explicitly can help make the code
#' easier to debug and easier to read. When used as part of a workflow,
#' `set_forecast_unit()` can be directly piped into `check_forecasts()` to
#' check everything is in order.
#'
#' @inheritParams score
#' @param forecast_unit Character vector with the names of the columns that
#' uniquely identify a single forecast.
#' @return A data.table with only those columns kept that are relevant to
#' scoring or denote the unit of a single forecast as specified by the user.
#'
#' @importFrom data.table ':=' is.data.table copy
#' @export
#' @keywords data-handling
#' @examples
#' set_forecast_unit(
#' example_quantile,
#' c("location", "target_end_date", "target_type", "horizon", "model")
#' )

set_forecast_unit <- function(data, forecast_unit) {

datacols <- colnames(data)
missing <- forecast_unit[!(forecast_unit %in% datacols)]

if (length(missing) > 0) {
warning(
"Column(s) '",
missing,
"' are not columns of the data and will be ignored."
)
forecast_unit <- intersect(forecast_unit, datacols)
}

keep_cols <- c(get_protected_columns(data), forecast_unit)
out <- unique(data[, .SD, .SDcols = keep_cols])[]
return(out)
}
37 changes: 31 additions & 6 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -242,12 +242,7 @@ get_target_type <- function(data) {

get_forecast_unit <- function(data, prediction_type) {

protected_columns <- c(
"prediction", "true_value", "sample", "quantile", "upper", "lower",
"pit_value",
"range", "boundary", available_metrics(),
grep("coverage_", names(data), fixed = TRUE, value = TRUE)
)
protected_columns <- get_protected_columns(data)
if (!missing(prediction_type)) {
if (prediction_type == "quantile") {
protected_columns <- setdiff(protected_columns, "sample")
Expand All @@ -256,3 +251,33 @@ get_forecast_unit <- function(data, prediction_type) {
forecast_unit <- setdiff(colnames(data), protected_columns)
return(forecast_unit)
}


#' @title Get protected columns from a data frame
#'
#' @description Helper function to get the names of all columns in a data frame
#' that are protected columns.
#'
#' @inheritParams check_forecasts
#'
#' @return A character vector with the names of protected columns in the data
#'
#' @keywords internal

get_protected_columns <- function(data) {

datacols <- colnames(data)
protected_columns <- c(
"prediction", "true_value", "sample", "quantile", "upper", "lower",
"pit_value", "range", "boundary", available_metrics(),
grep("coverage_", names(data), fixed = TRUE, value = TRUE)
)

# only return protected columns that are present
protected_columns <- intersect(
datacols,
protected_columns
)

return(protected_columns)
}
44 changes: 44 additions & 0 deletions man/get_protected_columns.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion man/log_shift.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

64 changes: 64 additions & 0 deletions man/set_forecast_unit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

46 changes: 46 additions & 0 deletions tests/testthat/test-convenience-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,49 @@ test_that("function transform_forecasts works", {

expect_equal(four$prediction, compare)
})



test_that("function set_forecast_unit() works", {

# some columns in the example data have duplicated information. So we can remove
# these and see whether the result stays the same.

scores1 <- suppressMessages(score(example_quantile))
scores1 <- scores1[order(location, target_end_date, target_type, horizon, model), ]

ex2 <- set_forecast_unit(
example_quantile,
c("location", "target_end_date", "target_type", "horizon", "model")
)
scores2 <- suppressMessages(score(ex2))
scores2 <- scores2[order(location, target_end_date, target_type, horizon, model), ]

expect_equal(scores1$interval_score, scores2$interval_score)
})


test_that("function set_forecast_unit() gives warning when column is not there", {

expect_warning(
set_forecast_unit(
example_quantile,
c("location", "target_end_date", "target_type", "horizon", "model", "test")
)
)
})


test_that("function get_forecast_unit() and set_forecast_unit() work together", {

fu_set <- c("location", "target_end_date", "target_type", "horizon", "model")

ex <- set_forecast_unit(
example_binary,
fu_set
)

fu_get <- get_forecast_unit(ex)
expect_equal(fu_set, fu_get)
})

36 changes: 36 additions & 0 deletions tests/testthat/test-utils.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
test_that("get_protected columns returns the correct result", {

data <- example_quantile
manual <- protected_columns <- c(
"prediction", "true_value", "sample", "quantile", "upper", "lower",
"pit_value",
"range", "boundary", available_metrics(),
grep("coverage_", names(data), fixed = TRUE, value = TRUE)
)
manual <- intersect(manual, colnames(example_quantile))
auto <- get_protected_columns(data)
expect_equal(sort(manual), sort(auto))


data <- example_binary
manual <- protected_columns <- c(
"prediction", "true_value", "sample", "quantile", "upper", "lower",
"pit_value",
"range", "boundary", available_metrics(),
grep("coverage_", names(data), fixed = TRUE, value = TRUE)
)
manual <- intersect(manual, colnames(example_binary))
auto <- get_protected_columns(data)
expect_equal(sort(manual), sort(auto))

data <- example_continuous
manual <- protected_columns <- c(
"prediction", "true_value", "sample", "quantile", "upper", "lower",
"pit_value",
"range", "boundary", available_metrics(),
grep("coverage_", names(data), fixed = TRUE, value = TRUE)
)
manual <- intersect(manual, colnames(example_continuous))
auto <- get_protected_columns(data)
expect_equal(sort(manual), sort(auto))
})