diff --git a/DESCRIPTION b/DESCRIPTION
index 1bfe4c9..5cb8d52 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -17,7 +17,7 @@ License: MIT + file LICENSE
URL: https://github.com/tidymodels/rules, https://rules.tidymodels.org/
BugReports: https://github.com/tidymodels/rules/issues
Depends:
- parsnip (>= 0.2.0),
+ parsnip (>= 0.2.1.9003),
R (>= 3.4)
Imports:
dials (>= 0.1.1.9001),
@@ -40,7 +40,8 @@ Suggests:
testthat (>= 3.0.0),
xrf (>= 0.2.0)
Remotes:
- tidymodels/dials
+ tidymodels/dials,
+ tidymodels/parsnip@rule-fit-stop-iter
Config/Needs/website:
tidyr,
tidyverse/tidytemplate,
diff --git a/NEWS.md b/NEWS.md
index e01d499..d874033 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -6,8 +6,11 @@
* The `mtry_prop` parameter was moved to the dials package and is now re-exported here for backward compatibility.
-* A bug was fixed related to `multi_predict()` with C5.0 rule-based models (#49)
+* A bug was fixed related to `multi_predict()` with C5.0 rule-based models (#49).
+* The `mtry` argument is now mapped to `colsample_bynode` rather than `colsample_bytree`. This is consistent with parsnip's interface to `xgboost` as of parsnip 0.1.6. `colsample_bytree` can still be optimized by passing it in as an engine argument to `set_engine()` (#60).
+
+* Introduced support for early stopping in `rule_fit()` via the `stop_iter` argument; see `?parsnip::details_rule_fit_xrf`. Note that `stop_iter` is a _main_ argument to `rule_fit()` and requires parsnip 1.0.0 or later.
# rules 0.2.0
diff --git a/R/rule_fit.R b/R/rule_fit.R
index 55f13f9..864f951 100644
--- a/R/rule_fit.R
+++ b/R/rule_fit.R
@@ -7,47 +7,71 @@ xrf_fit <-
max_depth = 6,
nrounds = 15,
eta = 0.3,
- colsample_bytree = 1,
+ colsample_bynode = NULL,
+ colsample_bytree = NULL,
min_child_weight = 1,
gamma = 0,
subsample = 1,
- lambda = 0.1,
+ validation = 0,
+ early_stop = NULL,
counts = TRUE,
+ event_level = c("first", "second"),
+ lambda = 0.1,
...) {
- mtry <-
- process_mtry(
+ converted <-
+ parsnip::.convert_form_to_xy_fit(
+ formula = formula,
+ data = data
+ )
+
+ prefit <-
+ parsnip::xgb_train(
+ converted$x,
+ converted$y,
+ max_depth = max_depth,
+ nrounds = nrounds,
+ eta = eta,
+ colsample_bynode = colsample_bynode,
colsample_bytree = colsample_bytree,
+ min_child_weight = min_child_weight,
+ gamma = gamma,
+ subsample = subsample,
+ validation = validation,
+ early_stop = early_stop,
+ objective = NULL,
counts = counts,
- n_predictors = get_num_terms(formula, data),
- is_missing = missing(colsample_bytree)
+ event_level = event_level
)
- args <- list(
- object = formula,
- data = rlang::expr(data),
- xgb_control =
- list(
- nrounds = nrounds,
- max_depth = max_depth,
- eta = eta,
- colsample_bytree = mtry,
- min_child_weight = min_child_weight,
- gamma = gamma,
- subsample = subsample
- )
- )
+
+ args <-
+ list(
+ object = formula,
+ data = rlang::expr(data),
+ prefit_xgb = prefit
+ )
+
dots <- rlang::enquos(...)
+
+ # ignore parsnip-protected argument
+ dots[["xgb_control"]] <- NULL
+
if (!any(names(dots) == "family")) {
info <- get_family(formula, data)
args$family <- info$fam
if (info$fam == "multinomial") {
+ # have to mock an xgb_control object for xrf for now
+ args$xgb_control <- list()
args$xgb_control$num_class <- info$classes
+ args$xgb_control$nrounds <- 10
}
}
if (length(dots) > 0) {
args <- c(args, dots)
}
+
cl <- rlang::call2(.fn = "xrf", .ns = "xrf", !!!args)
res <- rlang::eval_tidy(cl)
+
res$lambda <- lambda
res$family <- args$family
res$levels <- get_levels(formula, data)
@@ -96,15 +120,6 @@ process_mtry <- function(colsample_bytree, counts, n_predictors, is_missing) {
colsample_bytree
}
-# adapted from parsnip::max_mtry_formula
-get_num_terms <- function(formula, data) {
- preds <- stats::model.frame(formula, head(data))
- trms <- attr(preds, "terms")
- p <- ncol(attr(trms, "factors"))
-
- max(p, 1)
-}
-
get_family <- function(formula, data) {
m <- model.frame(formula, head(data))
y <- model.response(m)
diff --git a/R/rule_fit_data.R b/R/rule_fit_data.R
index 6bec0b1..06dfcbb 100644
--- a/R/rule_fit_data.R
+++ b/R/rule_fit_data.R
@@ -34,7 +34,7 @@ make_rule_fit <- function() {
model = "rule_fit",
eng = "xrf",
parsnip = "mtry",
- original = "colsample_bytree",
+ original = "colsample_bynode",
func = list(pkg = "dials", fun = "mtry"),
has_submodel = FALSE
)
@@ -71,13 +71,22 @@ make_rule_fit <- function() {
has_submodel = TRUE
)
+ parsnip::set_model_arg(
+ model = "rule_fit",
+ eng = "xrf",
+ parsnip = "stop_iter",
+ original = "early_stop",
+ func = list(pkg = "dials", fun = "stop_iter"),
+ has_submodel = FALSE
+ )
+
parsnip::set_fit(
model = "rule_fit",
eng = "xrf",
mode = "regression",
value = list(
interface = "formula",
- protect = c("formula", "data"),
+ protect = c("formula", "data", "xgb_control"),
func = c(pkg = "rules", fun = "xrf_fit"),
defaults = list()
)
@@ -119,7 +128,7 @@ make_rule_fit <- function() {
mode = "classification",
value = list(
interface = "formula",
- protect = c("formula", "data"),
+ protect = c("formula", "data", "xgb_control"),
func = c(pkg = "rules", fun = "xrf_fit"),
defaults = list()
)
diff --git a/man/rules-internal.Rd b/man/rules-internal.Rd
index 9db5f45..6c71d73 100644
--- a/man/rules-internal.Rd
+++ b/man/rules-internal.Rd
@@ -30,12 +30,16 @@ xrf_fit(
max_depth = 6,
nrounds = 15,
eta = 0.3,
- colsample_bytree = 1,
+ colsample_bynode = NULL,
+ colsample_bytree = NULL,
min_child_weight = 1,
gamma = 0,
subsample = 1,
- lambda = 0.1,
+ validation = 0,
+ early_stop = NULL,
counts = TRUE,
+ event_level = c("first", "second"),
+ lambda = 0.1,
...
)
diff --git a/man/tidy.cubist.Rd b/man/tidy.cubist.Rd
index 93dc83b..055290b 100644
--- a/man/tidy.cubist.Rd
+++ b/man/tidy.cubist.Rd
@@ -55,21 +55,8 @@ Turn rule models into tidy tibbles
\subsection{An example}{
\if{html}{\out{
}}\preformatted{library(dplyr)
-}\if{html}{\out{
}}
-
-\if{html}{\out{}}\preformatted{##
-## Attaching package: 'dplyr'
-
-## The following objects are masked from 'package:stats':
-##
-## filter, lag
-
-## The following objects are masked from 'package:base':
-##
-## intersect, setdiff, setequal, union
-}\if{html}{\out{
}}
-\if{html}{\out{}}\preformatted{data(ames, package = "modeldata")
+data(ames, package = "modeldata")
ames <-
ames \%>\%
@@ -89,18 +76,18 @@ cb_res
}\if{html}{\out{
}}
\if{html}{\out{}}\preformatted{## # A tibble: 157 × 5
-## committee rule_num rule estimate statistic
-##
-## 1 1 1 ( Central_Air == 'N' ) & ( Gr_Liv_Area…
-## 2 1 2 ( Gr_Liv_Area <= 3.0326188 ) & ( Neigh…
-## 3 1 3 ( Neighborhood \%in\% c( 'Old_Town','Ed…
-## 4 1 4 ( Neighborhood \%in\% c( 'Old_Town','Ed…
-## 5 1 5 ( Central_Air == 'N' ) & ( Gr_Liv_Area…
-## 6 1 6 ( Longitude <= -93.652023 ) & ( Neighb…
-## 7 1 7 ( Gr_Liv_Area > 3.2284005 ) & ( Neighb…
-## 8 1 8 ( Neighborhood \%in\% c( 'North_Ames','…
-## 9 1 9 ( Latitude <= 42.009399 ) & ( Neighbor…
-## 10 1 10 ( Neighborhood \%in\% c( 'College_Creek…
+## committee rule_num rule estimate statistic
+##
+## 1 1 1 ( Central_Air == 'N' ) & ( Gr_Liv_Area <= 3.228…
+## 2 1 2 ( Gr_Liv_Area <= 3.0326188 ) & ( Neighborhood …
+## 3 1 3 ( Neighborhood \%in\% c( 'Old_Town','Edwards','B…
+## 4 1 4 ( Neighborhood \%in\% c( 'Old_Town','Edwards','B…
+## 5 1 5 ( Central_Air == 'N' ) & ( Gr_Liv_Area > 3.2284…
+## 6 1 6 ( Longitude <= -93.652023 ) & ( Neighborhood \%…
+## 7 1 7 ( Gr_Liv_Area > 3.2284005 ) & ( Neighborhood \%…
+## 8 1 8 ( Neighborhood \%in\% c( 'North_Ames','Gilbert',…
+## 9 1 9 ( Latitude <= 42.009399 ) & ( Neighborhood \%in…
+## 10 1 10 ( Neighborhood \%in\% c( 'College_Creek','Somers…
## # … with 147 more rows
}\if{html}{\out{
}}
@@ -284,29 +271,29 @@ xrf_reg_fit <-
xrf_rule_res$rule[nrow(xrf_rule_res)] \%>\% rlang::parse_expr()
}\if{html}{\out{}}
-\if{html}{\out{}}\preformatted{## (Gr_Liv_Area < 3.30210185) & (Gr_Liv_Area < 3.38872266) & (Gr_Liv_Area >=
-## 2.94571471) & (Gr_Liv_Area >= 3.24870872) & (Latitude < 42.0271072) &
-## (Neighborhood_Old_Town >= -9.53674316e-07)
+\if{html}{\out{
}}\preformatted{## (Central_Air_Y >= 0.5) & (Gr_Liv_Area < 3.38872266) & (Gr_Liv_Area >=
+## 2.94571471) & (Gr_Liv_Area >= 3.24870872) & (Latitude >=
+## 42.0271072) & (Neighborhood_Old_Town >= 0.5)
}\if{html}{\out{
}}
\if{html}{\out{
}}\preformatted{xrf_col_res <- tidy(xrf_reg_fit, unit = "columns")
xrf_col_res
}\if{html}{\out{
}}
-\if{html}{\out{
}}\preformatted{## # A tibble: 149 × 3
-## rule_id term estimate
-##
-## 1 r0_1 Gr_Liv_Area -1.27e- 2
-## 2 r2_4 Gr_Liv_Area -3.92e-10
-## 3 r2_2 Gr_Liv_Area 7.59e- 3
-## 4 r2_4 Central_Air_Y -3.92e-10
-## 5 r3_5 Longitude 1.06e- 1
-## 6 r3_6 Longitude 2.65e- 2
-## 7 r3_5 Latitude 1.06e- 1
-## 8 r3_6 Latitude 2.65e- 2
-## 9 r3_5 Longitude 1.06e- 1
-## 10 r3_6 Longitude 2.65e- 2
-## # … with 139 more rows
+\if{html}{\out{}}\preformatted{## # A tibble: 417 × 3
+## rule_id term estimate
+##
+## 1 r0_1 Gr_Liv_Area -0.0138
+## 2 r2_3 Gr_Liv_Area -0.0310
+## 3 r2_2 Gr_Liv_Area 0.0127
+## 4 r2_3 Central_Air_Y -0.0310
+## 5 r3_5 Longitude 0.0859
+## 6 r3_6 Longitude 0.0171
+## 7 r3_2 Longitude -0.0109
+## 8 r3_5 Latitude 0.0859
+## 9 r3_6 Latitude 0.0171
+## 10 r3_5 Longitude 0.0859
+## # … with 407 more rows
}\if{html}{\out{
}}
}
}
diff --git a/tests/testthat/_snaps/rule-fit-regression.md b/tests/testthat/_snaps/rule-fit-regression.md
index f6ca0ef..cbceb19 100644
--- a/tests/testthat/_snaps/rule-fit-regression.md
+++ b/tests/testthat/_snaps/rule-fit-regression.md
@@ -1,32 +1,22 @@
-# rule_fit handles mtry vs mtry_prop gracefully
+# early stopping works in xrf_fit
- The supplied argument `mtry = 0.5` must be greater than or equal to 1.
-
- `mtry` is currently being interpreted as a count rather than a proportion. Supply `counts = FALSE` to `set_engine()` to supply this argument as a proportion rather than a count.
-
- See `?details_rule_fit_xrf` for more details.
-
----
-
- The supplied argument `mtry = 3` must be less than or equal to 1.
-
- `mtry` is currently being interpreted as a proportion rather than a count. Supply `counts = TRUE` to `set_engine()` to supply this argument as a count rather than a proportion.
-
- See `?details_rule_fit_xrf` for more details.
+ Code
+ suppressMessages(rf_fit_3 <- fit(rf_mod_3, mpg ~ ., data = mtcars))
+ Warning
+ `early_stop` was reduced to 4.
----
+# xrf_fit guards xgb_control
Code
- pars_fit_8 <- rule_fit(mtry = 0.5, trees = 5) %>% set_engine("xrf",
- colsample_bytree = 0.5) %>% set_mode("regression") %>% fit(Sale_Price ~
- Neighborhood + Longitude + Latitude + Gr_Liv_Area + Central_Air, data = ames_data$
- ames)
+ suppressMessages(fit(rf_mod, mpg ~ ., data = mtcars))
Warning
- The following arguments cannot be manually modified and were removed: colsample_bytree.
- Error
- The supplied argument `mtry = 0.5` must be greater than or equal to 1.
+ The following arguments cannot be manually modified and were removed: xgb_control.
+ Output
+ parsnip model object
+
+ An eXtreme RuleFit model of 7 rules.
- `mtry` is currently being interpreted as a count rather than a proportion. Supply `counts = FALSE` to `set_engine()` to supply this argument as a proportion rather than a count.
+ Original Formula:
- See `?details_rule_fit_xrf` for more details.
+ mpg ~ cyl + disp + hp + drat + wt + qsec + vs + am + gear + carb
diff --git a/tests/testthat/test-rule-fit-binomial.R b/tests/testthat/test-rule-fit-binomial.R
index 2a13b4b..98b749d 100644
--- a/tests/testthat/test-rule-fit-binomial.R
+++ b/tests/testthat/test-rule-fit-binomial.R
@@ -39,7 +39,7 @@ test_that("formula method", {
rf_pred <- predict(rf_fit, ad_data$ad_pred)
rf_prob <- predict(rf_fit, ad_data$ad_pred, type = "prob")
- expect_equal(rf_fit_exp$xgb$evaluation_log, rf_fit$fit$xgb$evaluation_log)
+ expect_equal(unname(rf_fit_exp$xgb$evaluation_log), unname(rf_fit$fit$xgb$evaluation_log))
expect_equal(names(rf_pred), ".pred_class")
expect_true(tibble::is_tibble(rf_pred))
@@ -123,7 +123,7 @@ test_that("non-formula method", {
rf_pred <- predict(rf_fit, ad_data$ad_pred)
rf_prob <- predict(rf_fit, ad_data$ad_pred, type = "prob")
- expect_equal(rf_fit_exp$xgb$evaluation_log, rf_fit$fit$xgb$evaluation_log)
+ expect_equal(unname(rf_fit_exp$xgb$evaluation_log), unname(rf_fit$fit$xgb$evaluation_log))
expect_equal(names(rf_pred), ".pred_class")
expect_true(tibble::is_tibble(rf_pred))
diff --git a/tests/testthat/test-rule-fit-multinomial.R b/tests/testthat/test-rule-fit-multinomial.R
index 204a2d3..4a20153 100644
--- a/tests/testthat/test-rule-fit-multinomial.R
+++ b/tests/testthat/test-rule-fit-multinomial.R
@@ -37,7 +37,7 @@ test_that("formula method", {
rf_pred <- predict(rf_fit, hpc_data$hpc_pred)
rf_prob <- predict(rf_fit, hpc_data$hpc_pred, type = "prob")
- expect_equal(rf_fit_exp$xgb$evaluation_log, rf_fit$fit$xgb$evaluation_log)
+ expect_equal(unname(rf_fit_exp$xgb$evaluation_log), unname(rf_fit$fit$xgb$evaluation_log))
expect_equal(names(rf_pred), ".pred_class")
expect_true(tibble::is_tibble(rf_pred))
@@ -124,7 +124,7 @@ test_that("non-formula method", {
rf_pred <- predict(rf_fit, hpc_data$hpc_pred)
rf_prob <- predict(rf_fit, hpc_data$hpc_pred, type = "prob")
- expect_equal(rf_fit_exp$xgb$evaluation_log, rf_fit$fit$xgb$evaluation_log)
+ expect_equal(unname(rf_fit_exp$xgb$evaluation_log), unname(rf_fit$fit$xgb$evaluation_log))
expect_equal(names(rf_pred), ".pred_class")
expect_true(tibble::is_tibble(rf_pred))
diff --git a/tests/testthat/test-rule-fit-regression.R b/tests/testthat/test-rule-fit-regression.R
index 27be3af..a5a764f 100644
--- a/tests/testthat/test-rule-fit-regression.R
+++ b/tests/testthat/test-rule-fit-regression.R
@@ -34,7 +34,7 @@ test_that("formula method", {
)
rf_pred <- predict(rf_fit, chi_data$chi_pred)
- expect_equal(rf_fit_exp$xgb$evaluation_log, rf_fit$fit$xgb$evaluation_log)
+ expect_equal(unname(rf_fit_exp$xgb$evaluation_log), unname(rf_fit$fit$xgb$evaluation_log))
expect_equal(names(rf_pred), ".pred")
expect_true(tibble::is_tibble(rf_pred))
expect_equal(rf_pred$.pred, unname(rf_pred_exp))
@@ -89,7 +89,7 @@ test_that("non-formula method", {
)
rf_pred <- predict(rf_fit, chi_data$chi_pred)
- expect_equal(rf_fit_exp$xgb$evaluation_log, rf_fit$fit$xgb$evaluation_log)
+ expect_equal(unname(rf_fit_exp$xgb$evaluation_log), unname(rf_fit$fit$xgb$evaluation_log))
expect_equal(rf_fit_exp$glm$model$nzero, rf_fit$fit$glm$model$nzero)
expect_equal(names(rf_pred), ".pred")
expect_true(tibble::is_tibble(rf_pred))
@@ -157,162 +157,66 @@ test_that("tidy method - regression", {
)
})
-test_that("rule_fit handles mtry vs mtry_prop gracefully", {
- skip_on_cran()
- skip_if_not_installed("xrf")
+test_that("early stopping works in xrf_fit", {
+ rf_mod_1 <-
+ rule_fit(trees = 5) %>%
+ set_engine("xrf") %>%
+ set_mode("regression")
- ames_data <- make_ames_data()
+ rf_mod_2 <-
+ rule_fit(trees = 5, stop_iter = 3) %>%
+ set_engine("xrf") %>%
+ set_mode("regression")
- # supply no mtry
- expect_error_free({
- pars_fit_1 <-
- rule_fit(trees = 5) %>%
- set_engine("xrf") %>%
- set_mode("regression") %>%
- fit(
- Sale_Price ~ Neighborhood + Longitude + Latitude +
- Gr_Liv_Area + Central_Air,
- data = ames_data$ames
- )
- })
+ rf_mod_3 <-
+ rule_fit(trees = 5, stop_iter = 5) %>%
+ set_engine("xrf") %>%
+ set_mode("regression")
- expect_equal(
- extract_fit_engine(pars_fit_1)$xgb$params$colsample_bytree,
- 1
+ expect_error_free(
+ rf_fit_1 <- fit(rf_mod_1, mpg ~ ., data = mtcars)
)
- # supply mtry = 1 (edge cases)
- expect_error_free({
- pars_fit_2 <-
- rule_fit(mtry = 5, trees = 5) %>%
- set_engine("xrf", counts = TRUE) %>%
- set_mode("regression") %>%
- fit(
- Sale_Price ~ Neighborhood + Longitude + Latitude +
- Gr_Liv_Area + Central_Air,
- data = ames_data$ames
- )
- })
-
- expect_equal(
- extract_fit_engine(pars_fit_2)$xgb$params$colsample_bytree,
- 1
+ expect_error_free(
+ rf_fit_2 <- fit(rf_mod_2, mpg ~ ., data = mtcars)
)
- expect_error_free({
- pars_fit_3 <-
- rule_fit(mtry = 1, trees = 5) %>%
- set_engine("xrf", counts = FALSE) %>%
- set_mode("regression") %>%
- fit(
- Sale_Price ~ Neighborhood + Longitude + Latitude +
- Gr_Liv_Area + Central_Air,
- data = ames_data$ames
- )
- })
-
- expect_equal(
- extract_fit_engine(pars_fit_3)$xgb$params$colsample_bytree,
- 1
+ expect_snapshot(
+ suppressMessages(
+ rf_fit_3 <- fit(rf_mod_3, mpg ~ ., data = mtcars)
+ )
)
- # supply a count (with default counts = TRUE)
- expect_error_free({
- pars_fit_4 <-
- rule_fit(mtry = 5, trees = 5) %>%
- set_engine("xrf") %>%
- set_mode("regression") %>%
- fit(
- Sale_Price ~ Neighborhood + Longitude + Latitude +
- Gr_Liv_Area + Central_Air,
- data = ames_data$ames
- )
- })
+ expect_true( is.null(rf_fit_1$fit$xgb$best_iteration))
+ expect_true(!is.null(rf_fit_2$fit$xgb$best_iteration))
+ expect_true(!is.null(rf_fit_3$fit$xgb$best_iteration))
+})
- expect_equal(
- extract_fit_engine(pars_fit_4)$xgb$params$colsample_bytree,
- 1
- )
+test_that("xrf_fit is sensitive to glm_control", {
+ rf_mod <-
+ rule_fit(trees = 3) %>%
+ set_engine("xrf", glm_control = list(type.measure = "deviance", nfolds = 8)) %>%
+ set_mode("regression")
- # supply a proportion when count expected
- expect_snapshot_error({
- pars_fit_5 <-
- rule_fit(mtry = .5, trees = 5) %>%
- set_engine("xrf") %>%
- set_mode("regression") %>%
- fit(
- Sale_Price ~ Neighborhood + Longitude + Latitude +
- Gr_Liv_Area + Central_Air,
- data = ames_data$ames
- )
- })
-
- # supply a count when proportion expected
- expect_snapshot_error({
- pars_fit_6 <-
- rule_fit(mtry = 3, trees = 5) %>%
- set_engine("xrf", counts = FALSE) %>%
- set_mode("regression") %>%
- fit(
- Sale_Price ~ Neighborhood + Longitude + Latitude +
- Gr_Liv_Area + Central_Air,
- data = ames_data$ames
- )
- })
-
- expect_warning({
- pars_fit_7 <-
- rule_fit(trees = 5) %>%
- set_engine("xrf", colsample_bytree = .5) %>%
- set_mode("regression") %>%
- fit(
- Sale_Price ~ Neighborhood + Longitude + Latitude +
- Gr_Liv_Area + Central_Air,
- data = ames_data$ames
- )},
- "manually modified and were removed: colsample_bytree."
+ expect_error_free(
+ rf_fit_1 <- fit(rf_mod, mpg ~ ., data = mtcars)
)
- expect_equal(
- extract_fit_engine(pars_fit_7)$xgb$params$colsample_bytree,
- 1
- )
+ rf_fit_1_call_args <- rlang::call_args(rf_fit_1$fit$glm$model$call)
- # supply both feature fraction and mtry
- expect_snapshot({
- pars_fit_8 <-
- rule_fit(mtry = .5, trees = 5) %>%
- set_engine("xrf", colsample_bytree = .5) %>%
- set_mode("regression") %>%
- fit(
- Sale_Price ~ Neighborhood + Longitude + Latitude +
- Gr_Liv_Area + Central_Air,
- data = ames_data$ames
- )},
- error = TRUE
- )
+ expect_equal(rf_fit_1_call_args$nfolds, 8)
+ expect_equal(rf_fit_1_call_args$type.measure, "deviance")
+})
- expect_warning({
- pars_fit_9 <-
- rule_fit(mtry = 5, trees = 5) %>%
- set_engine("xrf", colsample_bytree = .5) %>%
- set_mode("regression") %>%
- fit(
- Sale_Price ~ Neighborhood + Longitude + Latitude +
- Gr_Liv_Area + Central_Air,
- data = ames_data$ames
- )},
- "manually modified and were removed: colsample_bytree."
- )
+test_that("xrf_fit guards xgb_control", {
+ rf_mod <-
+ rule_fit(trees = 3) %>%
+ set_engine("xrf", xgb_control = list(nrounds = 3)) %>%
+ set_mode("regression")
- expect_equal(
- extract_fit_engine(pars_fit_9)$xgb$params$colsample_bytree,
- 1
+ expect_snapshot(
+ suppressMessages(
+ fit(rf_mod, mpg ~ ., data = mtcars)
+ )
)
-
- # internal helper works as expected
- expect_equal(get_num_terms(mpg ~ ., mtcars), ncol(mtcars) - 1)
- expect_equal(get_num_terms(mpg ~ . + disp*drat, mtcars), ncol(mtcars))
- expect_equal(get_num_terms(mpg ~ disp, mtcars), 1)
- expect_equal(get_num_terms(mpg ~ NULL, mtcars), 1)
})