diff --git a/NAMESPACE b/NAMESPACE index bc2a6b542..d3b6ffe22 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -82,6 +82,7 @@ importFrom(checkmate,assert_data_frame) importFrom(checkmate,assert_data_table) importFrom(checkmate,assert_factor) importFrom(checkmate,assert_list) +importFrom(checkmate,assert_logical) importFrom(checkmate,assert_number) importFrom(checkmate,assert_numeric) importFrom(checkmate,check_atomic_vector) diff --git a/R/metrics-quantile.R b/R/metrics-quantile.R index db187d376..985b00002 100644 --- a/R/metrics-quantile.R +++ b/R/metrics-quantile.R @@ -91,6 +91,7 @@ #' @param count_median_twice if TRUE, count the median twice in the score #' @param na.rm if TRUE, ignore NA values when computing the score #' @importFrom stats weighted.mean +#' @importFrom checkmate assert_logical #' @return #' `wis()`: a numeric vector with WIS values of size n (one per observation), #' or a list with separate entries if `separate_results` is `TRUE`. @@ -105,6 +106,11 @@ wis <- function(observed, assert_input_quantile(observed, predicted, quantile) reformatted <- quantile_to_interval(observed, predicted, quantile) + assert_logical(separate_results, len = 1) + assert_logical(weigh, len = 1) + assert_logical(count_median_twice, len = 1) + assert_logical(na.rm, len = 1) + if (separate_results) { cols <- c("wis", "dispersion", "underprediction", "overprediction") } else { diff --git a/R/score.R b/R/score.R index 7ec7763e3..29230e6e0 100644 --- a/R/score.R +++ b/R/score.R @@ -152,18 +152,10 @@ score.scoringutils_binary <- function(data, metrics = metrics_binary, ...) { data <- remove_na_observed_predicted(data) metrics <- validate_metrics(metrics) - # Extract the arguments passed in ... - args <- list(...) - lapply(seq_along(metrics), function(i, ...) { - metric_name <- names(metrics[i]) - fun <- metrics[[i]] - matching_args <- filter_function_args(fun, args) - - data[, (metric_name) := do.call( - fun, c(list(observed, predicted), matching_args) - )] - return() - }, ...) + data <- apply_metrics( + data, metrics, + data$observed, data$predicted, ... + ) setattr(data, "metric_names", names(metrics)) @@ -180,18 +172,10 @@ score.scoringutils_point <- function(data, metrics = metrics_point, ...) { data <- remove_na_observed_predicted(data) metrics <- validate_metrics(metrics) - # Extract the arguments passed in ... - args <- list(...) - lapply(seq_along(metrics), function(i, ...) { - metric_name <- names(metrics[i]) - fun <- metrics[[i]] - matching_args <- filter_function_args(fun, args) - - data[, (metric_name) := do.call( - fun, c(list(observed, predicted), matching_args) - )] - return() - }, ...) + data <- apply_metrics( + data, metrics, + data$observed, data$predicted, ... + ) setattr(data, "metric_names", names(metrics)) @@ -206,26 +190,29 @@ score.scoringutils_sample <- function(data, metrics = metrics_sample, ...) { forecast_unit <- attr(data, "forecast_unit") metrics <- validate_metrics(metrics) - # Extract the arguments passed in ... - args <- list(...) - lapply(seq_along(metrics), function(i, ...) { - metric_name <- names(metrics[i]) - fun <- metrics[[i]] - matching_args <- filter_function_args(fun, args) + # transpose the forecasts that belong to the same forecast unit + d_transposed <- data[, .(predicted = list(predicted), + observed = unique(observed), + scoringutils_N = length(list(sample_id))), + by = forecast_unit] - data[, (metric_name) := do.call( - fun, c(list(unique(observed), t(predicted)), matching_args) - ), by = forecast_unit] - return() - }, - ...) + # split according to number of samples and do calculations for different + # sample lengths separately + d_split <- split(d_transposed, d_transposed$scoringutils_N) - data <- data[ - , lapply(.SD, unique), - .SDcols = colnames(data) %like% paste(names(metrics), collapse = "|"), - by = forecast_unit - ] + split_result <- lapply(d_split, function(data) { + # create a matrix + observed <- data$observed + predicted <- do.call(rbind, data$predicted) + data[, c("observed", "predicted", "scoringutils_N") := NULL] + data <- apply_metrics( + data, metrics, + observed, predicted, ... + ) + return(data) + }) + data <- rbindlist(split_result) setattr(data, "metric_names", names(metrics)) return(data[]) @@ -240,9 +227,6 @@ score.scoringutils_quantile <- function(data, metrics = metrics_quantile, ...) { forecast_unit <- attr(data, "forecast_unit") metrics <- validate_metrics(metrics) - # Extract the arguments passed in ... - args <- list(...) - # transpose the forecasts that belong to the same forecast unit # make sure the quantiles and predictions are ordered in the same way d_transposed <- data[, .(predicted = list(predicted[order(quantile)]), @@ -263,18 +247,10 @@ score.scoringutils_quantile <- function(data, metrics = metrics_quantile, ...) { quantile <- unlist(unique(data$quantile)) data[, c("observed", "predicted", "quantile", "scoringutils_quantile") := NULL] - # for each metric, compute score - lapply(seq_along(metrics), function(i, ...) { - metric_name <- names(metrics[i]) - fun <- metrics[[i]] - matching_args <- filter_function_args(fun, args) - - data[, eval(metric_name) := do.call( - fun, c(list(observed), list(predicted), list(quantile), matching_args) - )] - return() - }, - ...) + data <- apply_metrics( + data, metrics, + observed, predicted, quantile, ... + ) return(data) }) @@ -283,3 +259,18 @@ score.scoringutils_quantile <- function(data, metrics = metrics_quantile, ...) { return(data[]) } + +apply_metrics <- function(data, metrics, ...) { + expr <- expression( + data[, (metric_name) := do.call(run_safely, list(..., fun = fun))] + ) + lapply(seq_along(metrics), function(i, data, ...) { + metric_name <- names(metrics[i]) + fun <- metrics[[i]] + eval(expr) + }, data, ...) + return(data) +} + + + diff --git a/R/z_globalVariables.R b/R/z_globalVariables.R index 28bcfb95b..89cfc2eab 100644 --- a/R/z_globalVariables.R +++ b/R/z_globalVariables.R @@ -62,6 +62,7 @@ globalVariables(c( "rel_to_baseline", "relative_skill", "rn", + "sample_id", "scoringutils_InternalDuplicateCheck", "scoringutils_InternalNumCheck", "se_mean", diff --git a/data/metrics_quantile.rda b/data/metrics_quantile.rda index 70a00a932..b14a8321c 100644 Binary files a/data/metrics_quantile.rda and b/data/metrics_quantile.rda differ diff --git a/inst/create-list-available-forecasts.R b/inst/create-list-available-forecasts.R index fcac2950c..fc4926797 100644 --- a/inst/create-list-available-forecasts.R +++ b/inst/create-list-available-forecasts.R @@ -28,8 +28,8 @@ metrics_quantile <- list( "underprediction" = underprediction, "dispersion" = dispersion, "bias" = bias_quantile, - "coverage_50" = \(...) {run_safely(..., range = 50, fun = interval_coverage_quantile)}, - "coverage_90" = \(...) {run_safely(..., range = 90, fun = interval_coverage_quantile)}, + "coverage_50" = \(...) {do.call(interval_coverage_quantile, c(list(...), range = 50))}, + "coverage_90" = \(...) {do.call(interval_coverage_quantile, c(list(...), range = 90))}, "coverage_deviation" = interval_coverage_deviation_quantile, "ae_median" = ae_median_quantile )