From c355ab65edd596597a3eacdec7caa62ecaa25c23 Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Wed, 22 Apr 2020 02:19:09 +0800
Subject: [PATCH] Enable parameter validation for R. (#5569)

* Enable parameter validation for R.

* Add test.
---
 R-package/R/xgb.train.R               |  2 ++
 R-package/tests/testthat/test_basic.R | 45 +++++++++++++++++++++++----
 R-package/tests/testthat/test_glm.R   |  2 +-
 doc/parameter.rst                     |  2 +-
 4 files changed, 43 insertions(+), 8 deletions(-)

diff --git a/R-package/R/xgb.train.R b/R-package/R/xgb.train.R
index b2b226559..8733bcce4 100644
--- a/R-package/R/xgb.train.R
+++ b/R-package/R/xgb.train.R
@@ -291,8 +291,10 @@ xgb.train <- function(params = list(), data, nrounds, watchlist = list(),
     callbacks <- add.cb(callbacks, cb.early.stop(early_stopping_rounds,
                                                  maximize = maximize, verbose = verbose))
   }
+  # Sort the callbacks into categories
   cb <- categorize.callbacks(callbacks)
 
+  params['validate_parameters'] <- TRUE
   if (!is.null(params[['seed']])) {
     warning("xgb.train: `seed` is ignored in R package. Use `set.seed()` instead.")
   }
diff --git a/R-package/tests/testthat/test_basic.R b/R-package/tests/testthat/test_basic.R
index 7d6918027..b23e4dd70 100644
--- a/R-package/tests/testthat/test_basic.R
+++ b/R-package/tests/testthat/test_basic.R
@@ -35,6 +35,40 @@ test_that("train and predict binary classification", {
   expect_lt(abs(err_pred1 - err_log), 10e-6)
 })
 
+test_that("parameter validation works", {
+  p <- list(foo = "bar")
+  nrounds = 1
+  set.seed(1994)
+
+  d <- cbind(
+    x1 = rnorm(10),
+    x2 = rnorm(10),
+    x3 = rnorm(10))
+  y <- d[,"x1"] + d[,"x2"]^2 +
+    ifelse(d[,"x3"] > .5, d[,"x3"]^2, 2^d[,"x3"]) +
+    rnorm(10)
+  dtrain <- xgb.DMatrix(data=d, info = list(label=y))
+
+  correct <- function() {
+    params <- list(max_depth = 2, booster = "dart",
+                   rate_drop = 0.5, one_drop = TRUE,
+                   objective = "reg:squarederror")
+    xgb.train(params = params, data = dtrain, nrounds = nrounds)
+  }
+  expect_silent(correct())
+  incorrect <- function() {
+    params <- list(max_depth = 2, booster = "dart",
+                   rate_drop = 0.5, one_drop = TRUE,
+                   objective = "reg:squarederror",
+                   foo = "bar", bar = "foo")
+    output <- capture.output(
+      xgb.train(params = params, data = dtrain, nrounds = nrounds))
+    print(output)
+  }
+  expect_output(incorrect(), "bar, foo")
+})
+
+
 test_that("dart prediction works", {
   nrounds = 32
   set.seed(1994)
@@ -68,7 +102,6 @@ test_that("dart prediction works", {
       one_drop = TRUE,
       nthread = 1,
       tree_method= "exact",
-      verbosity = 3,
       objective = "reg:squarederror"
     ),
     data = dtrain,
@@ -324,15 +357,15 @@ test_that("colsample_bytree works", {
  test_y <- as.numeric(rowSums(test_x) > 0)
  colnames(train_x) <- paste0("Feature_", sprintf("%03d", 1:100))
  colnames(test_x) <- paste0("Feature_", sprintf("%03d", 1:100))
-  dtrain <- xgb.DMatrix(train_x, label = train_y) 
+  dtrain <- xgb.DMatrix(train_x, label = train_y)
   dtest <- xgb.DMatrix(test_x, label = test_y)
   watchlist <- list(train = dtrain, eval = dtest)
-  # Use colsample_bytree = 0.01, so that roughly one out of 100 features is
-  # chosen for each tree
-  param <- list(max_depth = 2, eta = 0, verbosity = 0, nthread = 2,
+  ## Use colsample_bytree = 0.01, so that roughly one out of 100 features is chosen for
+  ## each tree
+  param <- list(max_depth = 2, eta = 0, nthread = 2,
                 colsample_bytree = 0.01, objective = "binary:logistic",
                 eval_metric = "auc")
-  set.seed(2) 
+  set.seed(2)
   bst <- xgb.train(param, dtrain, nrounds = 100, watchlist, verbose = 0)
   xgb.importance(model = bst)
   # If colsample_bytree works properly, a variety of features should be used
diff --git a/R-package/tests/testthat/test_glm.R b/R-package/tests/testthat/test_glm.R
index 7293e1ada..9b4aa73ad 100644
--- a/R-package/tests/testthat/test_glm.R
+++ b/R-package/tests/testthat/test_glm.R
@@ -40,7 +40,7 @@ test_that("gblinear works", {
   expect_lt(bst$evaluation_log$eval_error[2], ERR_UL)
 
   bst <- xgb.train(param, dtrain, n, watchlist, verbose = VERB, feature_selector = 'thrifty',
-                   top_n = 50, callbacks = list(cb.gblinear.history(sparse = TRUE)))
+                   top_k = 50, callbacks = list(cb.gblinear.history(sparse = TRUE)))
   expect_lt(bst$evaluation_log$eval_error[n], ERR_UL)
   h <- xgb.gblinear.history(bst)
   expect_equal(dim(h), c(n, ncol(dtrain) + 1))
diff --git a/doc/parameter.rst b/doc/parameter.rst
index a7c0479ae..99c70dcac 100644
--- a/doc/parameter.rst
+++ b/doc/parameter.rst
@@ -30,7 +30,7 @@ General Parameters
   is displayed as warning message. If there's unexpected behaviour, please try to
   increase value of verbosity.
 
-* ``validate_parameters`` [default to false, except for Python interface]
+* ``validate_parameters`` [default to false, except for Python and R interface]
 
   - When set to True, XGBoost will perform validation of input parameters to check whether
     a parameter is used or not. The feature is still experimental. It's expected to have
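
For reference, below is a minimal sketch (not part of the patch itself) of how the
always-on `validate_parameters` behaviour surfaces to R users once this change is
applied. The synthetic data and the bogus `foo` parameter are illustrative only and
mirror the new test above; with validation enabled, XGBoost warns about parameter
names it does not recognise instead of silently ignoring them.

    library(xgboost)

    set.seed(1994)
    x <- matrix(rnorm(100 * 3), ncol = 3)
    y <- rnorm(100)
    dtrain <- xgb.DMatrix(data = x, label = y)

    # `foo` is not an XGBoost parameter; with validation on, the booster
    # flags it as possibly unused when training starts.
    params <- list(max_depth = 2, objective = "reg:squarederror", foo = "bar")
    bst <- xgb.train(params = params, data = dtrain, nrounds = 1)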