From d7d554f35124ceca554e515f98e877e4e2fec566 Mon Sep 17 00:00:00 2001 From: simonpcouch Date: Mon, 6 Jun 2022 16:21:50 -0400 Subject: [PATCH 1/4] prefit xgboost model in `xrf_fit()` --- DESCRIPTION | 5 +- R/rule_fit.R | 69 ++++--- R/rule_fit_data.R | 15 +- man/rules-internal.Rd | 8 +- man/tidy.cubist.Rd | 73 +++---- tests/testthat/_snaps/rule-fit-regression.md | 32 ---- tests/testthat/test-rule-fit-binomial.R | 4 +- tests/testthat/test-rule-fit-multinomial.R | 4 +- tests/testthat/test-rule-fit-regression.R | 188 +++++-------------- 9 files changed, 140 insertions(+), 258 deletions(-) delete mode 100644 tests/testthat/_snaps/rule-fit-regression.md diff --git a/DESCRIPTION b/DESCRIPTION index 1bfe4c9..5cb8d52 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -17,7 +17,7 @@ License: MIT + file LICENSE URL: https://github.com/tidymodels/rules, https://rules.tidymodels.org/ BugReports: https://github.com/tidymodels/rules/issues Depends: - parsnip (>= 0.2.0), + parsnip (>= 0.2.1.9003), R (>= 3.4) Imports: dials (>= 0.1.1.9001), @@ -40,7 +40,8 @@ Suggests: testthat (>= 3.0.0), xrf (>= 0.2.0) Remotes: - tidymodels/dials + tidymodels/dials, + tidymodels/parsnip@rule-fit-stop-iter Config/Needs/website: tidyr, tidyverse/tidytemplate, diff --git a/R/rule_fit.R b/R/rule_fit.R index 55f13f9..78924c5 100644 --- a/R/rule_fit.R +++ b/R/rule_fit.R @@ -7,47 +7,67 @@ xrf_fit <- max_depth = 6, nrounds = 15, eta = 0.3, - colsample_bytree = 1, + colsample_bynode = NULL, + colsample_bytree = NULL, min_child_weight = 1, gamma = 0, subsample = 1, - lambda = 0.1, + validation = 0, + early_stop = NULL, counts = TRUE, + event_level = c("first", "second"), + lambda = 0.1, ...) 
{ - mtry <- - process_mtry( + converted <- + parsnip::.convert_form_to_xy_fit( + formula = formula, + data = data + ) + + prefit <- + parsnip::xgb_train( + converted$x, + converted$y, + max_depth = max_depth, + nrounds = nrounds, + eta = eta, + colsample_bynode = colsample_bynode, colsample_bytree = colsample_bytree, + min_child_weight = min_child_weight, + gamma = gamma, + subsample = subsample, + validation = validation, + early_stop = early_stop, + objective = NULL, counts = counts, - n_predictors = get_num_terms(formula, data), - is_missing = missing(colsample_bytree) + event_level = event_level ) - args <- list( - object = formula, - data = rlang::expr(data), - xgb_control = - list( - nrounds = nrounds, - max_depth = max_depth, - eta = eta, - colsample_bytree = mtry, - min_child_weight = min_child_weight, - gamma = gamma, - subsample = subsample - ) - ) + + args <- + list( + object = formula, + data = rlang::expr(data), + prefit_xgb = prefit + ) + dots <- rlang::enquos(...) if (!any(names(dots) == "family")) { info <- get_family(formula, data) args$family <- info$fam if (info$fam == "multinomial") { + # have to mock an xgb_control object for xrf for now + args$xgb_control <- list() args$xgb_control$num_class <- info$classes + args$xgb_control$nrounds <- 10 } } if (length(dots) > 0) { args <- c(args, dots) } + cl <- rlang::call2(.fn = "xrf", .ns = "xrf", !!!args) res <- rlang::eval_tidy(cl) + res$lambda <- lambda res$family <- args$family res$levels <- get_levels(formula, data) @@ -96,15 +116,6 @@ process_mtry <- function(colsample_bytree, counts, n_predictors, is_missing) { colsample_bytree } -# adapted from parsnip::max_mtry_formula -get_num_terms <- function(formula, data) { - preds <- stats::model.frame(formula, head(data)) - trms <- attr(preds, "terms") - p <- ncol(attr(trms, "factors")) - - max(p, 1) -} - get_family <- function(formula, data) { m <- model.frame(formula, head(data)) y <- model.response(m) diff --git a/R/rule_fit_data.R b/R/rule_fit_data.R 
index 6bec0b1..06dfcbb 100644 --- a/R/rule_fit_data.R +++ b/R/rule_fit_data.R @@ -34,7 +34,7 @@ make_rule_fit <- function() { model = "rule_fit", eng = "xrf", parsnip = "mtry", - original = "colsample_bytree", + original = "colsample_bynode", func = list(pkg = "dials", fun = "mtry"), has_submodel = FALSE ) @@ -71,13 +71,22 @@ make_rule_fit <- function() { has_submodel = TRUE ) + parsnip::set_model_arg( + model = "rule_fit", + eng = "xrf", + parsnip = "stop_iter", + original = "early_stop", + func = list(pkg = "dials", fun = "stop_iter"), + has_submodel = FALSE + ) + parsnip::set_fit( model = "rule_fit", eng = "xrf", mode = "regression", value = list( interface = "formula", - protect = c("formula", "data"), + protect = c("formula", "data", "xgb_control"), func = c(pkg = "rules", fun = "xrf_fit"), defaults = list() ) @@ -119,7 +128,7 @@ make_rule_fit <- function() { mode = "classification", value = list( interface = "formula", - protect = c("formula", "data"), + protect = c("formula", "data", "xgb_control"), func = c(pkg = "rules", fun = "xrf_fit"), defaults = list() ) diff --git a/man/rules-internal.Rd b/man/rules-internal.Rd index 9db5f45..6c71d73 100644 --- a/man/rules-internal.Rd +++ b/man/rules-internal.Rd @@ -30,12 +30,16 @@ xrf_fit( max_depth = 6, nrounds = 15, eta = 0.3, - colsample_bytree = 1, + colsample_bynode = NULL, + colsample_bytree = NULL, min_child_weight = 1, gamma = 0, subsample = 1, - lambda = 0.1, + validation = 0, + early_stop = NULL, counts = TRUE, + event_level = c("first", "second"), + lambda = 0.1, ... ) diff --git a/man/tidy.cubist.Rd b/man/tidy.cubist.Rd index 93dc83b..055290b 100644 --- a/man/tidy.cubist.Rd +++ b/man/tidy.cubist.Rd @@ -55,21 +55,8 @@ Turn rule models into tidy tibbles \subsection{An example}{ \if{html}{\out{
}}\preformatted{library(dplyr) -}\if{html}{\out{
}} - -\if{html}{\out{
}}\preformatted{## -## Attaching package: 'dplyr' - -## The following objects are masked from 'package:stats': -## -## filter, lag - -## The following objects are masked from 'package:base': -## -## intersect, setdiff, setequal, union -}\if{html}{\out{
}} -\if{html}{\out{
}}\preformatted{data(ames, package = "modeldata") +data(ames, package = "modeldata") ames <- ames \%>\% @@ -89,18 +76,18 @@ cb_res }\if{html}{\out{
}} \if{html}{\out{
}}\preformatted{## # A tibble: 157 × 5 -## committee rule_num rule estimate statistic -## -## 1 1 1 ( Central_Air == 'N' ) & ( Gr_Liv_Area… -## 2 1 2 ( Gr_Liv_Area <= 3.0326188 ) & ( Neigh… -## 3 1 3 ( Neighborhood \%in\% c( 'Old_Town','Ed… -## 4 1 4 ( Neighborhood \%in\% c( 'Old_Town','Ed… -## 5 1 5 ( Central_Air == 'N' ) & ( Gr_Liv_Area… -## 6 1 6 ( Longitude <= -93.652023 ) & ( Neighb… -## 7 1 7 ( Gr_Liv_Area > 3.2284005 ) & ( Neighb… -## 8 1 8 ( Neighborhood \%in\% c( 'North_Ames','… -## 9 1 9 ( Latitude <= 42.009399 ) & ( Neighbor… -## 10 1 10 ( Neighborhood \%in\% c( 'College_Creek… +## committee rule_num rule estimate statistic +## +## 1 1 1 ( Central_Air == 'N' ) & ( Gr_Liv_Area <= 3.228… +## 2 1 2 ( Gr_Liv_Area <= 3.0326188 ) & ( Neighborhood … +## 3 1 3 ( Neighborhood \%in\% c( 'Old_Town','Edwards','B… +## 4 1 4 ( Neighborhood \%in\% c( 'Old_Town','Edwards','B… +## 5 1 5 ( Central_Air == 'N' ) & ( Gr_Liv_Area > 3.2284… +## 6 1 6 ( Longitude <= -93.652023 ) & ( Neighborhood \%… +## 7 1 7 ( Gr_Liv_Area > 3.2284005 ) & ( Neighborhood \%… +## 8 1 8 ( Neighborhood \%in\% c( 'North_Ames','Gilbert',… +## 9 1 9 ( Latitude <= 42.009399 ) & ( Neighborhood \%in… +## 10 1 10 ( Neighborhood \%in\% c( 'College_Creek','Somers… ## # … with 147 more rows }\if{html}{\out{
}} @@ -284,29 +271,29 @@ xrf_reg_fit <- xrf_rule_res$rule[nrow(xrf_rule_res)] \%>\% rlang::parse_expr() }\if{html}{\out{}} -\if{html}{\out{
}}\preformatted{## (Gr_Liv_Area < 3.30210185) & (Gr_Liv_Area < 3.38872266) & (Gr_Liv_Area >= -## 2.94571471) & (Gr_Liv_Area >= 3.24870872) & (Latitude < 42.0271072) & -## (Neighborhood_Old_Town >= -9.53674316e-07) +\if{html}{\out{
}}\preformatted{## (Central_Air_Y >= 0.5) & (Gr_Liv_Area < 3.38872266) & (Gr_Liv_Area >= +## 2.94571471) & (Gr_Liv_Area >= 3.24870872) & (Latitude >= +## 42.0271072) & (Neighborhood_Old_Town >= 0.5) }\if{html}{\out{
}} \if{html}{\out{
}}\preformatted{xrf_col_res <- tidy(xrf_reg_fit, unit = "columns") xrf_col_res }\if{html}{\out{
}} -\if{html}{\out{
}}\preformatted{## # A tibble: 149 × 3 -## rule_id term estimate -## -## 1 r0_1 Gr_Liv_Area -1.27e- 2 -## 2 r2_4 Gr_Liv_Area -3.92e-10 -## 3 r2_2 Gr_Liv_Area 7.59e- 3 -## 4 r2_4 Central_Air_Y -3.92e-10 -## 5 r3_5 Longitude 1.06e- 1 -## 6 r3_6 Longitude 2.65e- 2 -## 7 r3_5 Latitude 1.06e- 1 -## 8 r3_6 Latitude 2.65e- 2 -## 9 r3_5 Longitude 1.06e- 1 -## 10 r3_6 Longitude 2.65e- 2 -## # … with 139 more rows +\if{html}{\out{
}}\preformatted{## # A tibble: 417 × 3 +## rule_id term estimate +## +## 1 r0_1 Gr_Liv_Area -0.0138 +## 2 r2_3 Gr_Liv_Area -0.0310 +## 3 r2_2 Gr_Liv_Area 0.0127 +## 4 r2_3 Central_Air_Y -0.0310 +## 5 r3_5 Longitude 0.0859 +## 6 r3_6 Longitude 0.0171 +## 7 r3_2 Longitude -0.0109 +## 8 r3_5 Latitude 0.0859 +## 9 r3_6 Latitude 0.0171 +## 10 r3_5 Longitude 0.0859 +## # … with 407 more rows }\if{html}{\out{
}} } } diff --git a/tests/testthat/_snaps/rule-fit-regression.md b/tests/testthat/_snaps/rule-fit-regression.md deleted file mode 100644 index f6ca0ef..0000000 --- a/tests/testthat/_snaps/rule-fit-regression.md +++ /dev/null @@ -1,32 +0,0 @@ -# rule_fit handles mtry vs mtry_prop gracefully - - The supplied argument `mtry = 0.5` must be greater than or equal to 1. - - `mtry` is currently being interpreted as a count rather than a proportion. Supply `counts = FALSE` to `set_engine()` to supply this argument as a proportion rather than a count. - - See `?details_rule_fit_xrf` for more details. - ---- - - The supplied argument `mtry = 3` must be less than or equal to 1. - - `mtry` is currently being interpreted as a proportion rather than a count. Supply `counts = TRUE` to `set_engine()` to supply this argument as a count rather than a proportion. - - See `?details_rule_fit_xrf` for more details. - ---- - - Code - pars_fit_8 <- rule_fit(mtry = 0.5, trees = 5) %>% set_engine("xrf", - colsample_bytree = 0.5) %>% set_mode("regression") %>% fit(Sale_Price ~ - Neighborhood + Longitude + Latitude + Gr_Liv_Area + Central_Air, data = ames_data$ - ames) - Warning - The following arguments cannot be manually modified and were removed: colsample_bytree. - Error - The supplied argument `mtry = 0.5` must be greater than or equal to 1. - - `mtry` is currently being interpreted as a count rather than a proportion. Supply `counts = FALSE` to `set_engine()` to supply this argument as a proportion rather than a count. - - See `?details_rule_fit_xrf` for more details. 
- diff --git a/tests/testthat/test-rule-fit-binomial.R b/tests/testthat/test-rule-fit-binomial.R index 2a13b4b..98b749d 100644 --- a/tests/testthat/test-rule-fit-binomial.R +++ b/tests/testthat/test-rule-fit-binomial.R @@ -39,7 +39,7 @@ test_that("formula method", { rf_pred <- predict(rf_fit, ad_data$ad_pred) rf_prob <- predict(rf_fit, ad_data$ad_pred, type = "prob") - expect_equal(rf_fit_exp$xgb$evaluation_log, rf_fit$fit$xgb$evaluation_log) + expect_equal(unname(rf_fit_exp$xgb$evaluation_log), unname(rf_fit$fit$xgb$evaluation_log)) expect_equal(names(rf_pred), ".pred_class") expect_true(tibble::is_tibble(rf_pred)) @@ -123,7 +123,7 @@ test_that("non-formula method", { rf_pred <- predict(rf_fit, ad_data$ad_pred) rf_prob <- predict(rf_fit, ad_data$ad_pred, type = "prob") - expect_equal(rf_fit_exp$xgb$evaluation_log, rf_fit$fit$xgb$evaluation_log) + expect_equal(unname(rf_fit_exp$xgb$evaluation_log), unname(rf_fit$fit$xgb$evaluation_log)) expect_equal(names(rf_pred), ".pred_class") expect_true(tibble::is_tibble(rf_pred)) diff --git a/tests/testthat/test-rule-fit-multinomial.R b/tests/testthat/test-rule-fit-multinomial.R index 204a2d3..4a20153 100644 --- a/tests/testthat/test-rule-fit-multinomial.R +++ b/tests/testthat/test-rule-fit-multinomial.R @@ -37,7 +37,7 @@ test_that("formula method", { rf_pred <- predict(rf_fit, hpc_data$hpc_pred) rf_prob <- predict(rf_fit, hpc_data$hpc_pred, type = "prob") - expect_equal(rf_fit_exp$xgb$evaluation_log, rf_fit$fit$xgb$evaluation_log) + expect_equal(unname(rf_fit_exp$xgb$evaluation_log), unname(rf_fit$fit$xgb$evaluation_log)) expect_equal(names(rf_pred), ".pred_class") expect_true(tibble::is_tibble(rf_pred)) @@ -124,7 +124,7 @@ test_that("non-formula method", { rf_pred <- predict(rf_fit, hpc_data$hpc_pred) rf_prob <- predict(rf_fit, hpc_data$hpc_pred, type = "prob") - expect_equal(rf_fit_exp$xgb$evaluation_log, rf_fit$fit$xgb$evaluation_log) + expect_equal(unname(rf_fit_exp$xgb$evaluation_log), 
unname(rf_fit$fit$xgb$evaluation_log)) expect_equal(names(rf_pred), ".pred_class") expect_true(tibble::is_tibble(rf_pred)) diff --git a/tests/testthat/test-rule-fit-regression.R b/tests/testthat/test-rule-fit-regression.R index 27be3af..02dbbd6 100644 --- a/tests/testthat/test-rule-fit-regression.R +++ b/tests/testthat/test-rule-fit-regression.R @@ -34,7 +34,7 @@ test_that("formula method", { ) rf_pred <- predict(rf_fit, chi_data$chi_pred) - expect_equal(rf_fit_exp$xgb$evaluation_log, rf_fit$fit$xgb$evaluation_log) + expect_equal(unname(rf_fit_exp$xgb$evaluation_log), unname(rf_fit$fit$xgb$evaluation_log)) expect_equal(names(rf_pred), ".pred") expect_true(tibble::is_tibble(rf_pred)) expect_equal(rf_pred$.pred, unname(rf_pred_exp)) @@ -89,7 +89,7 @@ test_that("non-formula method", { ) rf_pred <- predict(rf_fit, chi_data$chi_pred) - expect_equal(rf_fit_exp$xgb$evaluation_log, rf_fit$fit$xgb$evaluation_log) + expect_equal(unname(rf_fit_exp$xgb$evaluation_log), unname(rf_fit$fit$xgb$evaluation_log)) expect_equal(rf_fit_exp$glm$model$nzero, rf_fit$fit$glm$model$nzero) expect_equal(names(rf_pred), ".pred") expect_true(tibble::is_tibble(rf_pred)) @@ -157,162 +157,64 @@ test_that("tidy method - regression", { ) }) -test_that("rule_fit handles mtry vs mtry_prop gracefully", { - skip_on_cran() - skip_if_not_installed("xrf") +test_that("early stopping works in xrf_fit", { + rf_mod_1 <- + rule_fit(trees = 5) %>% + set_engine("xrf") %>% + set_mode("regression") - ames_data <- make_ames_data() + rf_mod_2 <- + rule_fit(trees = 5, stop_iter = 3) %>% + set_engine("xrf") %>% + set_mode("regression") - # supply no mtry - expect_error_free({ - pars_fit_1 <- - rule_fit(trees = 5) %>% - set_engine("xrf") %>% - set_mode("regression") %>% - fit( - Sale_Price ~ Neighborhood + Longitude + Latitude + - Gr_Liv_Area + Central_Air, - data = ames_data$ames - ) - }) + rf_mod_3 <- + rule_fit(trees = 5, stop_iter = 5) %>% + set_engine("xrf") %>% + set_mode("regression") - expect_equal( - 
extract_fit_engine(pars_fit_1)$xgb$params$colsample_bytree, - 1 + expect_error_free( + rf_fit_1 <- fit(rf_mod_1, mpg ~ ., data = mtcars) ) - # supply mtry = 1 (edge cases) - expect_error_free({ - pars_fit_2 <- - rule_fit(mtry = 5, trees = 5) %>% - set_engine("xrf", counts = TRUE) %>% - set_mode("regression") %>% - fit( - Sale_Price ~ Neighborhood + Longitude + Latitude + - Gr_Liv_Area + Central_Air, - data = ames_data$ames - ) - }) - - expect_equal( - extract_fit_engine(pars_fit_2)$xgb$params$colsample_bytree, - 1 + expect_error_free( + rf_fit_2 <- fit(rf_mod_2, mpg ~ ., data = mtcars) ) - expect_error_free({ - pars_fit_3 <- - rule_fit(mtry = 1, trees = 5) %>% - set_engine("xrf", counts = FALSE) %>% - set_mode("regression") %>% - fit( - Sale_Price ~ Neighborhood + Longitude + Latitude + - Gr_Liv_Area + Central_Air, - data = ames_data$ames - ) - }) - - expect_equal( - extract_fit_engine(pars_fit_3)$xgb$params$colsample_bytree, - 1 + expect_warning( + rf_fit_3 <- fit(rf_mod_3, mpg ~ ., data = mtcars), + "\\`early_stop\\` was reduced to 4" ) - # supply a count (with default counts = TRUE) - expect_error_free({ - pars_fit_4 <- - rule_fit(mtry = 5, trees = 5) %>% - set_engine("xrf") %>% - set_mode("regression") %>% - fit( - Sale_Price ~ Neighborhood + Longitude + Latitude + - Gr_Liv_Area + Central_Air, - data = ames_data$ames - ) - }) + expect_true( is.null(rf_fit_1$fit$xgb$best_iteration)) + expect_true(!is.null(rf_fit_2$fit$xgb$best_iteration)) + expect_true(!is.null(rf_fit_3$fit$xgb$best_iteration)) +}) - expect_equal( - extract_fit_engine(pars_fit_4)$xgb$params$colsample_bytree, - 1 - ) +test_that("xrf_fit is sensitive to glm_control", { + rf_mod <- + rule_fit(trees = 3) %>% + set_engine("xrf", glm_control = list(type.measure = "deviance", nfolds = 8)) %>% + set_mode("regression") - # supply a proportion when count expected - expect_snapshot_error({ - pars_fit_5 <- - rule_fit(mtry = .5, trees = 5) %>% - set_engine("xrf") %>% - set_mode("regression") %>% - fit( - 
Sale_Price ~ Neighborhood + Longitude + Latitude + - Gr_Liv_Area + Central_Air, - data = ames_data$ames - ) - }) - - # supply a count when proportion expected - expect_snapshot_error({ - pars_fit_6 <- - rule_fit(mtry = 3, trees = 5) %>% - set_engine("xrf", counts = FALSE) %>% - set_mode("regression") %>% - fit( - Sale_Price ~ Neighborhood + Longitude + Latitude + - Gr_Liv_Area + Central_Air, - data = ames_data$ames - ) - }) - - expect_warning({ - pars_fit_7 <- - rule_fit(trees = 5) %>% - set_engine("xrf", colsample_bytree = .5) %>% - set_mode("regression") %>% - fit( - Sale_Price ~ Neighborhood + Longitude + Latitude + - Gr_Liv_Area + Central_Air, - data = ames_data$ames - )}, - "manually modified and were removed: colsample_bytree." + expect_error_free( + rf_fit_1 <- fit(rf_mod, mpg ~ ., data = mtcars) ) - expect_equal( - extract_fit_engine(pars_fit_7)$xgb$params$colsample_bytree, - 1 - ) + rf_fit_1_call_args <- rlang::call_args(rf_fit_1$fit$glm$model$call) - # supply both feature fraction and mtry - expect_snapshot({ - pars_fit_8 <- - rule_fit(mtry = .5, trees = 5) %>% - set_engine("xrf", colsample_bytree = .5) %>% - set_mode("regression") %>% - fit( - Sale_Price ~ Neighborhood + Longitude + Latitude + - Gr_Liv_Area + Central_Air, - data = ames_data$ames - )}, - error = TRUE - ) + expect_equal(rf_fit_1_call_args$nfolds, 8) + expect_equal(rf_fit_1_call_args$type.measure, "deviance") +}) - expect_warning({ - pars_fit_9 <- - rule_fit(mtry = 5, trees = 5) %>% - set_engine("xrf", colsample_bytree = .5) %>% - set_mode("regression") %>% - fit( - Sale_Price ~ Neighborhood + Longitude + Latitude + - Gr_Liv_Area + Central_Air, - data = ames_data$ames - )}, - "manually modified and were removed: colsample_bytree." 
- ) +test_that("xrf_fit guards xgb_control", { + rf_mod <- + rule_fit(trees = 3) %>% + set_engine("xrf", xgb_control = list(nrounds = 3)) %>% + set_mode("regression") - expect_equal( - extract_fit_engine(pars_fit_9)$xgb$params$colsample_bytree, - 1 + expect_warning( + fit(rf_mod, mpg ~ ., data = mtcars), + "and were removed: xgb_control" ) - - # internal helper works as expected - expect_equal(get_num_terms(mpg ~ ., mtcars), ncol(mtcars) - 1) - expect_equal(get_num_terms(mpg ~ . + disp*drat, mtcars), ncol(mtcars)) - expect_equal(get_num_terms(mpg ~ disp, mtcars), 1) - expect_equal(get_num_terms(mpg ~ NULL, mtcars), 1) }) From 05beb3d86a256cc870e8473bf324a745fdd565ef Mon Sep 17 00:00:00 2001 From: simonpcouch Date: Mon, 6 Jun 2022 16:50:48 -0400 Subject: [PATCH 2/4] add NEWS bullets --- NEWS.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/NEWS.md b/NEWS.md index e01d499..d874033 100644 --- a/NEWS.md +++ b/NEWS.md @@ -6,8 +6,11 @@ * The `mtry_prop` parameter was moved to the dials package and is now re-exported here for backward compatibility. -* A bug was fixed related to `multi_predict()` with C5.0 rule-based models (#49) +* A bug was fixed related to `multi_predict()` with C5.0 rule-based models (#49). +* The `mtry` argument is now mapped to `colsample_bynode` rather than `colsample_bytree`. This is consistent with parsnip's interface to `xgboost` as of parsnip 0.1.6. `colsample_bytree` can still be optimized by passing it in as an engine argument to `set_engine()` (#60). + +* Introduced support for early stopping in `rule_fit()` via the `stop_iter` argument. See `parsnip::details_rule_fit_xrf`. Note that this is a _main_ argument to `rule_fit()` requiring parsnip 1.0.0. 
# rules 0.2.0 From 1cca1f903c7f375d62bb8469ffd44d50f77476a2 Mon Sep 17 00:00:00 2001 From: simonpcouch Date: Mon, 6 Jun 2022 17:01:54 -0400 Subject: [PATCH 3/4] ignore parsnip-protected argument --- R/rule_fit.R | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/R/rule_fit.R b/R/rule_fit.R index 78924c5..864f951 100644 --- a/R/rule_fit.R +++ b/R/rule_fit.R @@ -51,6 +51,10 @@ xrf_fit <- ) dots <- rlang::enquos(...) + + # ignore parsnip-protected argument + dots[["xgb_control"]] <- NULL + if (!any(names(dots) == "family")) { info <- get_family(formula, data) args$family <- info$fam From f6404dea1c84cf0d469d7b8fb5dbab61cce2ebf6 Mon Sep 17 00:00:00 2001 From: simonpcouch Date: Wed, 8 Jun 2022 14:38:22 -0400 Subject: [PATCH 4/4] `expect_warning()` -> `expect_snapshot()` --- tests/testthat/_snaps/rule-fit-regression.md | 22 ++++++++++++++++++++ tests/testthat/test-rule-fit-regression.R | 14 +++++++------ 2 files changed, 30 insertions(+), 6 deletions(-) create mode 100644 tests/testthat/_snaps/rule-fit-regression.md diff --git a/tests/testthat/_snaps/rule-fit-regression.md b/tests/testthat/_snaps/rule-fit-regression.md new file mode 100644 index 0000000..cbceb19 --- /dev/null +++ b/tests/testthat/_snaps/rule-fit-regression.md @@ -0,0 +1,22 @@ +# early stopping works in xrf_fit + + Code + suppressMessages(rf_fit_3 <- fit(rf_mod_3, mpg ~ ., data = mtcars)) + Warning + `early_stop` was reduced to 4. + +# xrf_fit guards xgb_control + + Code + suppressMessages(fit(rf_mod, mpg ~ ., data = mtcars)) + Warning + The following arguments cannot be manually modified and were removed: xgb_control. + Output + parsnip model object + + An eXtreme RuleFit model of 7 rules. 
+ + Original Formula: + + mpg ~ cyl + disp + hp + drat + wt + qsec + vs + am + gear + carb + diff --git a/tests/testthat/test-rule-fit-regression.R b/tests/testthat/test-rule-fit-regression.R index 02dbbd6..a5a764f 100644 --- a/tests/testthat/test-rule-fit-regression.R +++ b/tests/testthat/test-rule-fit-regression.R @@ -181,9 +181,10 @@ test_that("early stopping works in xrf_fit", { rf_fit_2 <- fit(rf_mod_2, mpg ~ ., data = mtcars) ) - expect_warning( - rf_fit_3 <- fit(rf_mod_3, mpg ~ ., data = mtcars), - "\\`early_stop\\` was reduced to 4" + expect_snapshot( + suppressMessages( + rf_fit_3 <- fit(rf_mod_3, mpg ~ ., data = mtcars) + ) ) expect_true( is.null(rf_fit_1$fit$xgb$best_iteration)) @@ -213,8 +214,9 @@ test_that("xrf_fit guards xgb_control", { set_engine("xrf", xgb_control = list(nrounds = 3)) %>% set_mode("regression") - expect_warning( - fit(rf_mod, mpg ~ ., data = mtcars), - "and were removed: xgb_control" + expect_snapshot( + suppressMessages( + fit(rf_mod, mpg ~ ., data = mtcars) + ) ) })