diff --git a/NAMESPACE b/NAMESPACE
index 08e1dddc7..e99821a05 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -4,6 +4,7 @@ S3method(augment,coxph_exploratory)
 S3method(augment,glm_exploratory)
 S3method(augment,lm_exploratory)
 S3method(augment,multinom)
+S3method(augment,prophet_exploratory)
 S3method(augment,randomForest)
 S3method(augment,randomForest.classification)
 S3method(augment,randomForest.formula)
diff --git a/R/prophet.R b/R/prophet.R
index 36dd21531..e7c25e926 100644
--- a/R/prophet.R
+++ b/R/prophet.R
@@ -783,7 +783,7 @@ do_prophet_ <- function(df, time_col, value_col = NULL, periods = 10, time_unit
       else {
         regressor_name_map <- regressor_final_output_cols
         names(regressor_name_map) <- regressor_output_cols
-        model <- list(result=ret, model=m, test_mode=test_mode, value_col=value_col, regressor_name_map=regressor_name_map)
+        model <- list(result=ret, model=m, test_mode=test_mode, time_col=time_col, value_col=value_col, regressor_name_map=regressor_name_map)
         class(model) <- c("prophet_exploratory", class(model))
         model
       }
@@ -884,3 +884,75 @@ tidy.prophet_exploratory <- function(x, type="result") {
     res
   }
 }
+
+#' Augment data with forecasts from a prophet_exploratory model.
+#'
+#' @param x A prophet_exploratory model object (a list carrying model,
+#'   result and time_col), or an object of class "error".
+#' @param data Not used for prediction. With or without it, the result
+#'   stored in the model object is returned when newdata is NULL.
+#' @param newdata Optional data frame that contains the time column the
+#'   model was built with. When given, forecasts are computed for its rows.
+#' @param data_type Currently unused.
+#' @param ... Currently unused.
+#' @return When newdata is given, newdata with the columns forecasted_value,
+#'   forecasted_value_high and forecasted_value_low added, or an empty data
+#'   frame if every time value is NA. When newdata is NULL, the result stored
+#'   in the model object. For an "error" object, a one-column data frame
+#'   (Note) holding the error message.
+#' @export
+augment.prophet_exploratory <- function(x, data = NULL, newdata = NULL, data_type = "training", ...) {
+  if (inherits(x, "error")) {
+    return(data.frame(Note = x$message))
+  }
+
+  if (is.null(newdata)) {
+    # Without newdata, return the forecast result stored in the model object.
+    return(x$result)
+  }
+
+  time_col <- x$time_col
+  # TODO: Avoid column name conflict with the original data.
+  predicted_value_col <- "forecasted_value"
+  predicted_value_high_col <- "forecasted_value_high"
+  predicted_value_low_col <- "forecasted_value_low"
+
+  original_data <- newdata
+
+  # Only the time column is passed to the model.
+  # NOTE(review): a model trained with extra regressors would presumably need
+  # those columns too - confirm.
+  cleaned_data <- original_data %>% dplyr::select(!!rlang::sym(time_col))
+
+  # Rows with NA in the time column are excluded from the prediction and
+  # put back afterwards.
+  na_row_numbers <- ranger.find_na(time_col, cleaned_data)
+  if (length(na_row_numbers) > 0) {
+    # drop=FALSE is necessary to keep the data frame structure.
+    cleaned_data <- cleaned_data[-na_row_numbers, , drop=FALSE]
+  }
+
+  if (nrow(cleaned_data) == 0) {
+    return(data.frame())
+  }
+
+  # The model requires the time column to be named "ds".
+  if (time_col != "ds") {
+    cleaned_data <- cleaned_data %>% dplyr::rename(ds = !!rlang::sym(time_col))
+  }
+
+  # Predict on time-ordered rows and map the results back to the input order,
+  # so that each forecast stays on its own row even if predict() returns the
+  # rows sorted by ds.
+  time_order <- order(cleaned_data$ds)
+  predicted_data <- stats::predict(x$model, cleaned_data[time_order, , drop=FALSE])
+  input_order <- order(time_order)
+
+  # Put the forecasts back, re-inserting the rows removed because of NA.
+  # yhat_upper is the high end of the forecast interval and yhat_lower the low end.
+  original_data[[predicted_value_col]] <- restore_na(predicted_data$yhat[input_order], na_row_numbers)
+  original_data[[predicted_value_high_col]] <- restore_na(predicted_data$yhat_upper[input_order], na_row_numbers)
+  original_data[[predicted_value_low_col]] <- restore_na(predicted_data$yhat_lower[input_order], na_row_numbers)
+
+  original_data
+}
diff --git a/tests/testthat/test_prophet_5.R b/tests/testthat/test_prophet_5.R
new file mode 100644
index 000000000..ac6534e63
--- /dev/null
+++ b/tests/testthat/test_prophet_5.R
@@ -0,0 +1,29 @@
+context("test prophet functions - Holiday Country Names, Repeat By")
+
+set.seed(1)
+
+test_that("augment.prophet_exploratory", {
+  # Create training data.
+  history <- data.frame(x = seq(as.Date('2015-01-01'), as.Date('2016-01-01'), by = 'd'),
+                        y = sin(1:366/200) + rnorm(366)/10)
+
+  # Create test data within the same date range.
+  # It includes an NA and drops one date.
+  # is.na(x) is needed in the filter; otherwise the NA row is dropped too.
+  testdata <- history %>%
+    select(x) %>%
+    mutate(x = if_else(x == as.Date('2015-01-05'), as.Date(NA), x)) %>%
+    filter(is.na(x) | x != as.Date('2015-01-10'))
+  expect_true(any(is.na(testdata$x)))
+
+  model.df <- history %>% do_prophet(time=x, value=y, output="model", periods=0)
+
+  ret <- broom::augment(model.df$model[[1]], newdata=testdata)
+
+  expect_equal(nrow(ret), nrow(testdata))
+  expect_true("forecasted_value" %in% colnames(ret))
+  expect_true("forecasted_value_high" %in% colnames(ret))
+  expect_true("forecasted_value_low" %in% colnames(ret))
+  # The high end of the interval must not be below the low end.
+  expect_true(all(ret$forecasted_value_high >= ret$forecasted_value_low, na.rm = TRUE))
+})