Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ S3method(print,scoringutils_check)
S3method(quantile_to_interval,data.frame)
S3method(quantile_to_interval,numeric)
S3method(score,default)
S3method(score,scoringutils_binary)
S3method(score,scoringutils_point)
S3method(score,scoringutils_quantile)
S3method(score,scoringutils_sample)
S3method(validate_forecast,scoringutils_binary)
S3method(validate_forecast,scoringutils_point)
S3method(validate_forecast,scoringutils_quantile)
S3method(validate_forecast,scoringutils_sample)
S3method(score,forecast_binary)
S3method(score,forecast_point)
S3method(score,forecast_quantile)
S3method(score,forecast_sample)
S3method(validate_forecast,forecast_binary)
S3method(validate_forecast,forecast_point)
S3method(validate_forecast,forecast_quantile)
S3method(validate_forecast,forecast_sample)
export(abs_error)
export(add_coverage)
export(add_pairwise_comparison)
Expand Down Expand Up @@ -41,7 +41,7 @@ export(mad_sample)
export(make_NA)
export(make_na)
export(merge_pred_and_obs)
export(new_scoringutils)
export(new_forecast)
export(overprediction)
export(pairwise_comparison)
export(pit)
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ The update introduces breaking changes. If you want to keep using the older vers

## Package updates
- In `score()`, required columns "true_value" and "prediction" were renamed and replaced by required columns "observed" and "predicted". Scoring functions now also use the function arguments "observed" and "predicted" everywhere consistently.
- The overall scoring workflow was updated. `score()` is now a generic function that dispatches the correct method based on the forecast type. forecast types currently supported are "binary", "point", "sample" and "quantile" with corresponding classes "forecast_binary", "forecast_point", "forecast_sample" and "forecast_quantile". An object of class `forecast_*` can be created using the function `as_forecast()`, which also replaces the previous function `check_forecasts()` (see more information below).
- Scoring functions received a consistent interface and input checks:
- metrics for binary forecasts:
- `observed`: factor with exactly 2 levels
Expand Down
2 changes: 1 addition & 1 deletion R/get_-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ get_duplicate_forecasts <- function(data, forecast_unit = NULL) {

#' @title Get a list of all attributes of a scoringutils object
#'
#' @param object A object of class `scoringutils_`
#' @param object A object of class `forecast_`
#'
#' @return A named list with the attributes of that object.
#' @keywords internal
Expand Down
14 changes: 8 additions & 6 deletions R/score.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,13 @@ score <- function(data, ...) {
score.default <- function(data, ...) {
assert(check_data_columns(data))
forecast_type <- get_forecast_type(data)
data <- new_scoringutils(data, paste0("scoringutils_", forecast_type))
data <- new_forecast(data, paste0("forecast_", forecast_type))
score(data, ...)
}

#' @rdname score
#' @export
score.scoringutils_binary <- function(data, metrics = metrics_binary, ...) {
score.forecast_binary <- function(data, metrics = metrics_binary, ...) {
data <- validate_forecast(data)
data <- remove_na_observed_predicted(data)
metrics <- validate_metrics(metrics)
Expand All @@ -96,7 +96,7 @@ score.scoringutils_binary <- function(data, metrics = metrics_binary, ...) {
#' @importFrom Metrics se ae ape
#' @rdname score
#' @export
score.scoringutils_point <- function(data, metrics = metrics_point, ...) {
score.forecast_point <- function(data, metrics = metrics_point, ...) {
data <- validate_forecast(data)
data <- remove_na_observed_predicted(data)
metrics <- validate_metrics(metrics)
Expand All @@ -113,7 +113,7 @@ score.scoringutils_point <- function(data, metrics = metrics_point, ...) {

#' @rdname score
#' @export
score.scoringutils_sample <- function(data, metrics = metrics_sample, ...) {
score.forecast_sample <- function(data, metrics = metrics_sample, ...) {
data <- validate_forecast(data)
data <- remove_na_observed_predicted(data)
forecast_unit <- attr(data, "forecast_unit")
Expand Down Expand Up @@ -150,7 +150,7 @@ score.scoringutils_sample <- function(data, metrics = metrics_sample, ...) {
#' @importFrom data.table `:=` as.data.table rbindlist %like%
#' @rdname score
#' @export
score.scoringutils_quantile <- function(data, metrics = metrics_quantile, ...) {
score.forecast_quantile <- function(data, metrics = metrics_quantile, ...) {
data <- validate_forecast(data)
data <- remove_na_observed_predicted(data)
forecast_unit <- attr(data, "forecast_unit")
Expand All @@ -176,7 +176,9 @@ score.scoringutils_quantile <- function(data, metrics = metrics_quantile, ...) {
observed <- data$observed
predicted <- do.call(rbind, data$predicted)
quantile <- unlist(unique(data$quantile))
data[, c("observed", "predicted", "quantile", "scoringutils_quantile") := NULL]
data[, c(
"observed", "predicted", "quantile", "scoringutils_quantile"
) := NULL]

data <- apply_metrics(
data, metrics,
Expand Down
31 changes: 16 additions & 15 deletions R/validate.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
#' quantile-based) from the input data (using the function
#' [get_forecast_type()]. It then constructs an object of the
#' appropriate class (`forecast_binary`, `forecast_point`, `forecast_sample`, or
#' `forecast_quantile`, using the function [new_scoringutils()]).
#' `forecast_quantile`, using the function [new_forecast()]).
#' Lastly, it calls [as_forecast()] on the object to make sure it conforms with
#' the required input formats.
#' @inheritParams score
#' @inheritSection forecast_types Forecast types and input format
#' @return Depending on the forecast type, an object of class
#' `scoringutils_binary`, `scoringutils_point`, `scoringutils_sample` or
#' `scoringutils_quantile`.
#' `forecast_binary`, `forecast_point`, `forecast_sample` or
#' `forecast_quantile`.
#' @export
#' @keywords check-forecasts
#' @examples
Expand All @@ -32,7 +32,7 @@ as_forecast.default <- function(data, ...) {
forecast_type <- get_forecast_type(data)

# construct class
data <- new_scoringutils(data, paste0("scoringutils_", forecast_type))
data <- new_forecast(data, paste0("forecast_", forecast_type))

# validate class
validate_forecast(data)
Expand All @@ -48,8 +48,8 @@ as_forecast.default <- function(data, ...) {
#' @inheritParams score
#' @inheritSection forecast_types Forecast types and input format
#' @return Depending on the forecast type, an object of class
#' `scoringutils_binary`, `scoringutils_point`, `scoringutils_sample` or
#' `scoringutils_quantile`.
#' `forecast_binary`, `forecast_point`, `forecast_sample` or
#' `forecast_quantile`.
#' @importFrom data.table ':=' is.data.table
#' @importFrom checkmate assert_data_frame
#' @export
Expand All @@ -62,10 +62,9 @@ validate_forecast <- function(data, ...) {
}


#' @rdname validate
#' @export
#' @keywords check-forecasts
validate_forecast.scoringutils_binary <- function(data, ...) {
validate_forecast.forecast_binary <- function(data, ...) {
data <- validate_general(data)

columns_correct <- test_columns_not_present(data, c("sample_id", "quantile"))
Expand All @@ -83,10 +82,10 @@ validate_forecast.scoringutils_binary <- function(data, ...) {
return(data[])
}

#' @rdname validate

#' @export
#' @keywords check-forecasts
validate_forecast.scoringutils_point <- function(data, ...) {
validate_forecast.forecast_point <- function(data, ...) {
data <- validate_general(data)

input_check <- check_input_point(data$observed, data$predicted)
Expand All @@ -98,22 +97,24 @@ validate_forecast.scoringutils_point <- function(data, ...) {
return(data[])
}

#' @rdname validate

#' @export
validate_forecast.scoringutils_quantile <- function(data, ...) {
#' @keywords check-forecasts
validate_forecast.forecast_quantile <- function(data, ...) {
data <- validate_general(data)
assert_numeric(data$quantile, lower = 0, upper = 1)
return(data[])
}

#' @rdname validate

#' @export
#' @keywords check-forecasts
validate_forecast.scoringutils_sample <- function(data, ...) {
validate_forecast.forecast_sample <- function(data, ...) {
data <- validate_general(data)
return(data[])
}


#' @title Apply scoringutls input checks that are the same across forecast types
#'
#' @description
Expand Down Expand Up @@ -181,7 +182,7 @@ validate_general <- function(data) {
#' @return An object of the class indicated by `classname`
#' @export
#' @keywords internal
new_scoringutils <- function(data, classname) {
new_forecast <- function(data, classname) {
data <- as.data.table(data)
data <- assure_model_column(data)
class(data) <- c(classname, class(data))
Expand Down
6 changes: 3 additions & 3 deletions man/as_forecast.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/get_scoringutils_attributes.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions man/new_scoringutils.Rd → man/new_forecast.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 8 additions & 8 deletions man/score.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/validate_forecast.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion tests/testthat/test-check_forecasts.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
test_that("as_forecast() function works", {
check <- suppressMessages(as_forecast(example_quantile))
expect_s3_class(check, "scoringutils_quantile")
expect_s3_class(check, "forecast_quantile")
})

test_that("as_forecast() function has an error for empty data.frame", {
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test-metrics-point.R
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,9 @@ test_that("abs error is correct, point and median forecasts same", {
observations = truth_scoringutils
)

data_scoringutils_point <- data_scoringutils[type == "point"][, quantile := NULL]
data_forecast_point <- data_scoringutils[type == "point"][, quantile := NULL]

eval <- score(data = data_scoringutils_point)
eval <- score(data = data_forecast_point)
eval <- summarise_scores(eval,
by = c(
"location", "target_end_date",
Expand Down
6 changes: 3 additions & 3 deletions tests/testthat/test-score.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ test_that("function produces output for a binary case", {
expect_true("brier_score" %in% names(eval))
})

test_that("score.scoringutils_binary() errors with only NA values", {
test_that("score.forecast_binary() errors with only NA values", {
only_nas <- copy(example_binary)[, predicted := NA_real_]
expect_error(
score(only_nas),
Expand Down Expand Up @@ -156,7 +156,7 @@ test_that("Changing metrics names works", {
})


test_that("score.scoringutils_point() errors with only NA values", {
test_that("score.forecast_point() errors with only NA values", {
only_nas <- copy(example_point)[, predicted := NA_real_]
expect_error(
score(only_nas),
Expand Down Expand Up @@ -239,7 +239,7 @@ test_that("WIS is the same with other metrics omitted or included", {
})


test_that("score.scoringutils_quantile() errors with only NA values", {
test_that("score.forecast_quantile() errors with only NA values", {
only_nas <- copy(example_quantile)[, predicted := NA_real_]
expect_error(
score(only_nas),
Expand Down