diff --git a/DESCRIPTION b/DESCRIPTION index 8c77e9321..ceaa46161 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -56,7 +56,7 @@ LazyData: true Imports: checkmate, cli, - data.table (>= 1.16.0), + data.table (>= 1.17.0), ggplot2 (>= 3.4.0), methods, purrr, diff --git a/NAMESPACE b/NAMESPACE index ca5534b46..0659e3d61 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -38,6 +38,7 @@ S3method(get_pit_histogram,forecast_quantile) S3method(get_pit_histogram,forecast_sample) S3method(head,forecast) S3method(print,forecast) +S3method(print,scores) S3method(score,default) S3method(score,forecast_binary) S3method(score,forecast_multivariate_point) diff --git a/R/class-forecast.R b/R/class-forecast.R index 8ec258401..b862fbceb 100644 --- a/R/class-forecast.R +++ b/R/class-forecast.R @@ -271,16 +271,20 @@ is_forecast <- function(x) { # where we used data.table := operator which will turn x into out before we # arrive to this function. is_dt_force_print <- identical(x, out) && ...length() == 1 - # ...length() as it still returns 1 in x[] and then skips validations in - # undesired situation if we set ...length() > 1 + + # Detect in-place modification via `:=`. When `:=` is used, data.table + # modifies x in place so x and out are identical. We distinguish from x[] + # (force-print) by checking ...length(): x[] has ...length() == 1, while + # := has ...length() > 1. We skip validation for := since the user is just + # modifying a column and the autoprint suppression is handled by + # print.forecast()'s shouldPrint() check. + # See https://github.com/epiforecasts/scoringutils/issues/935 + is_inplace_modify <- identical(x, out) && ...length() > 1 + # is.data.table: when [.data.table returns an atomic vector, it's clear it # cannot be a valid forecast object, and it is likely intended by the user - # in addition, we also check for a maximum length. The reason is that - # print.data.table will internally subset the data.table before printing. - # this subsetting triggers the validation, which is not desired in this case. - # this is a hack and ideally, we'd do things differently. - if (nrow(out) > 30 && is.data.table(out) && !is_dt_force_print) { + if (is.data.table(out) && !is_dt_force_print && !is_inplace_modify) { # check whether subset object passes validation validation <- try( assert_forecast(forecast = out, verbose = FALSE), @@ -290,8 +294,7 @@ is_forecast <- function(x) { cli_warn( c( `!` = "Error in validating forecast object: {validation}.", - i = "Note this error is sometimes related to `data.table`s `print`. - Run {.help [{.fun assert_forecast}](scoringutils::assert_forecast)} + i = "Run {.help [{.fun assert_forecast}](scoringutils::assert_forecast)} to confirm. To get rid of this warning entirely, call `as.data.table()` on the forecast object." ) @@ -408,6 +411,17 @@ tail.forecast <- function(x, ...) { #' print(dat) print.forecast <- function(x, ...) { + # Suppress autoprinting during data.table `:=` operations. + # When `:=` modifies a data.table in place, R's autoprint mechanism triggers + # the print method. data.table tracks this via an internal shouldPrint() + # function which returns FALSE when `:=` was just used. We check this early + # to avoid printing the forecast header in that case. + # See https://github.com/epiforecasts/scoringutils/issues/935 + shouldPrint <- utils::getFromNamespace("shouldPrint", "data.table") # nolint: object_name_linter + if (!shouldPrint(x)) { + return(invisible(x)) + } + # get forecast type, forecast unit and score columns forecast_type <- try( do.call(get_forecast_type, list(forecast = x)), diff --git a/R/class-scores.R b/R/class-scores.R index bc6e40d37..f52470c57 100644 --- a/R/class-scores.R +++ b/R/class-scores.R @@ -71,6 +71,27 @@ assert_scores <- function(scores) { } +#' @title Print a scores object +#' @description +#' Prints a `scores` object. Suppresses autoprinting during data.table +#' `:=` operations by checking data.table's internal `shouldPrint()` flag. +#' @param x A `scores` object +#' @param ... Additional arguments for [print()]. +#' @returns Returns `x` invisibly. +#' @export +#' @keywords gain-insights +print.scores <- function(x, ...) { + # Suppress autoprinting during data.table `:=` operations. + # See https://github.com/epiforecasts/scoringutils/issues/935 + shouldPrint <- utils::getFromNamespace("shouldPrint", "data.table") # nolint: object_name_linter + if (!shouldPrint(x)) { + return(invisible(x)) + } + NextMethod() + return(invisible(x)) +} + + #' @title Get names of the metrics that were used for scoring #' @description #' When applying a scoring rule via [score()], the names of the scoring rules diff --git a/man/print.scores.Rd b/man/print.scores.Rd new file mode 100644 index 000000000..a02d29fa0 --- /dev/null +++ b/man/print.scores.Rd @@ -0,0 +1,21 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/class-scores.R +\name{print.scores} +\alias{print.scores} +\title{Print a scores object} +\usage{ +\method{print}{scores}(x, ...) +} +\arguments{ +\item{x}{A \code{scores} object} + +\item{...}{Additional arguments for \code{\link[=print]{print()}}.} +} +\value{ +Returns \code{x} invisibly. +} +\description{ +Prints a \code{scores} object. Suppresses autoprinting during data.table +\verb{:=} operations by checking data.table's internal \code{shouldPrint()} flag. +} +\keyword{gain-insights} diff --git a/tests/testthat/test-class-forecast-point.R b/tests/testthat/test-class-forecast-point.R index ed33a421f..9ead1c10d 100644 --- a/tests/testthat/test-class-forecast-point.R +++ b/tests/testthat/test-class-forecast-point.R @@ -29,11 +29,9 @@ test_that("assert_forecast.forecast_point() works as expected", { test <- na.omit(data.table::as.data.table(example_point)) test <- as_forecast_point(test) - # expect an error if column is changed to character after initial validation. - expect_warning( - test[, "predicted" := as.character(predicted)], - "Input looks like a point forecast, but found the following issue" - ) + # := skips validation (to avoid spurious autoprinting, see #935), + # but assert_forecast() should catch the invalid state afterwards. + test[, "predicted" := as.character(predicted)] expect_error( assert_forecast(test), "Input looks like a point forecast, but found the following issue" diff --git a/tests/testthat/test-class-forecast.R b/tests/testthat/test-class-forecast.R index 92f9ca2db..3601be18f 100644 --- a/tests/testthat/test-class-forecast.R +++ b/tests/testthat/test-class-forecast.R @@ -155,8 +155,6 @@ test_that("print() throws the expected messages", { test <- data.table::copy(example_point) class(test) <- c("point", "forecast", "data.table", "data.frame") - # note that since introducing a length maximum for validation to be triggered, - # we don't throw a warning automatically anymore suppressMessages( expect_message( capture.output(print(test)), @@ -174,6 +172,114 @@ test_that("print() throws the expected messages", { }) +# ============================================================================== +# Autoprint suppression during := operations (issue #935) +# ============================================================================== + +test_that(":= on forecast objects does not trigger spurious printing", { + # This is the core issue from #935: modifying a column via := should + # not print the forecast object + ex <- data.table::copy(example_quantile) + output <- capture.output(ex[, model := paste(model, "a")]) + expect_identical(output, character(0)) +}) + +test_that(":= adding a new column to forecast objects does not print", { + ex <- data.table::copy(example_quantile) + output <- capture.output(ex[, new_col := "test"]) + expect_identical(output, character(0)) +}) + +test_that(":= on different forecast types does not trigger printing", { + # Test across all forecast types to ensure the fix is comprehensive + forecast_objects <- list( + binary = data.table::copy(example_binary), + quantile = data.table::copy(example_quantile), + point = data.table::copy(example_point), + sample_continuous = data.table::copy(example_sample_continuous), + sample_discrete = data.table::copy(example_sample_discrete) + ) + + for (name in names(forecast_objects)) { + ex <- forecast_objects[[name]] + output <- capture.output(ex[, model := paste(model, "a")]) + expect_identical( + output, character(0), + label = paste("Spurious printing for forecast type:", name) + ) + } +}) + +test_that("multiple sequential := operations do not trigger printing", { + ex <- data.table::copy(example_quantile) + output <- capture.output({ + ex[, model := paste(model, "a")] + ex[, new_col1 := 1] + ex[, new_col2 := "test"] + }) + expect_identical(output, character(0)) +}) + +test_that("explicit print() still works after := suppression", { + # After :=, data.table sets a flag that suppresses the next print. + # In interactive R, autoprint consumes this flag. In non-interactive + # contexts (testthat, scripts), we consume it manually with x[]. + ex <- data.table::copy(example_quantile) + ex[, model := paste(model, "a")] + invisible(capture.output(suppressMessages(ex[]))) # consume shouldPrint flag + + output <- capture.output(suppressMessages(print(ex))) + expect_gt(length(output), 0) +}) + +test_that("print() on forecast objects still shows header and data", { + ex <- as_forecast_quantile(na.omit(example_quantile)) + + messages <- capture.output(print(ex), type = "message") + output <- capture.output(suppressMessages(print(ex))) + + # Header should contain forecast type and unit info + header_text <- paste(messages, collapse = " ") + expect_true(grepl("Forecast type", header_text, fixed = TRUE)) + expect_true(grepl("Forecast unit", header_text, fixed = TRUE)) + + # Data should be printed + expect_gt(length(output), 0) +}) + +test_that("x[] force-print still works on forecast objects", { + # x[] is data.table's force-print syntax, should still produce output + ex <- as_forecast_quantile(na.omit(example_quantile)) + output <- capture.output(suppressMessages(ex[])) + expect_gt(length(output), 0) +}) + +test_that(":= on scores objects does not trigger spurious printing", { + scores <- score(example_quantile) + output <- capture.output(scores[, test := 3]) + expect_identical(output, character(0)) +}) + +test_that("[.forecast() validates subsets regardless of size", { + # After removing the 30-row hack, validation should trigger for + # any size subset that breaks the forecast contract + test <- na.omit(data.table::copy(example_quantile)) + + # Small subset (previously skipped validation due to nrow <= 30 hack) + small_test <- test[1:20] + expect_warning( + local(small_test[, colnames(small_test) != "observed", with = FALSE]), + "Error in validating" + ) + + # Large subset (was already validated before) + expect_warning( + local(test[, colnames(test) != "observed", with = FALSE]), + "Error in validating" + ) +}) + + # ============================================================================== # check_number_per_forecast() # nolint: commented_code_linter # ==============================================================================