Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions R/score.R
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ score.forecast_binary <- function(data, metrics = metrics_binary, ...) {
data <- remove_na_observed_predicted(data)
metrics <- validate_metrics(metrics)

data <- apply_metrics(
data <- apply_rules(
data, metrics,
data$observed, data$predicted, ...
)
Expand All @@ -101,7 +101,7 @@ score.forecast_point <- function(data, metrics = metrics_point, ...) {
data <- remove_na_observed_predicted(data)
metrics <- validate_metrics(metrics)

data <- apply_metrics(
data <- apply_rules(
data, metrics,
data$observed, data$predicted, ...
)
Expand Down Expand Up @@ -135,7 +135,7 @@ score.forecast_sample <- function(data, metrics = metrics_sample, ...) {
predicted <- do.call(rbind, data$predicted)
data[, c("observed", "predicted", "scoringutils_N") := NULL]

data <- apply_metrics(
data <- apply_rules(
data, metrics,
observed, predicted, ...
)
Expand Down Expand Up @@ -180,7 +180,7 @@ score.forecast_quantile <- function(data, metrics = metrics_quantile, ...) {
"observed", "predicted", "quantile", "scoringutils_quantile"
) := NULL]

data <- apply_metrics(
data <- apply_rules(
data, metrics,
observed, predicted, quantile, ...
)
Expand All @@ -193,7 +193,18 @@ score.forecast_quantile <- function(data, metrics = metrics_quantile, ...) {
return(data[])
}

apply_metrics <- function(data, metrics, ...) {

#' @title Apply A List Of Functions To A Data Table Of Forecasts
#' @description This helper function applies scoring rules (stored as a list of
#' functions) to a data table of forecasts. `apply_rules` is used within
#' `score()` to apply all scoring rules to the data.
#' Scoring rules are wrapped in [run_safely()] to catch errors and to make
#' sure that only arguments are passed to the scoring rule that are actually
#' accepted by it.
#' @inheritParams score
#' @return A data table with the forecasts and the calculated metrics
#' @keywords internal
apply_rules <- function(data, metrics, ...) {
expr <- expression(
data[, (metric_name) := do.call(run_safely, list(..., fun = fun))]
)
Expand Down
30 changes: 30 additions & 0 deletions man/apply_rules.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

41 changes: 28 additions & 13 deletions tests/testthat/test-score.R
Original file line number Diff line number Diff line change
Expand Up @@ -277,16 +277,31 @@ test_that("function throws an error if data is missing", {
expect_error(suppressMessages(score(data = NULL)))
})

# test_that(
# "score() can support a sample column when a quantile forecast is used", {
# ex <- example_quantile[!is.na(quantile)][1:200, ]
# ex <- rbind(
# data.table::copy(ex)[, sample_id := 1],
# ex[, sample_id := 2]
# )
# scores <- suppressWarnings(score(ex))
# expect_snapshot(summarise_scores(
# summarise_scores(scores, by = "model"), by = "model",
# fun = signif, digits = 2
# ))
# })
# =============================================================================
# `apply_rules()`
# =============================================================================

test_that("apply_rules() works", {

dt <- data.table::data.table(x = 1:10)
scoringutils:::apply_rules(
data = dt, metrics = list("test" = function(x) x + 1),
dt$x
)
expect_equal(dt$test, 2:11)

# additional named argument works
expect_no_condition(
scoringutils:::apply_rules(
data = dt, metrics = list("test" = function(x) x + 1),
dt$x, y = dt$test)
)

# additional unnamed argument does not work

expect_warning(
scoringutils:::apply_rules(
data = dt, metrics = list("test" = function(x) x + 1),
dt$x, dt$test)
)
})