Enable parameter validation for R. (#5569)

* Enable parameter validation for R.

* Add test.
This commit is contained in:
Jiaming Yuan 2020-04-22 02:19:09 +08:00 committed by GitHub
parent 564b22cee5
commit c355ab65ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 43 additions and 8 deletions

View File

@ -291,8 +291,10 @@ xgb.train <- function(params = list(), data, nrounds, watchlist = list(),
callbacks <- add.cb(callbacks, cb.early.stop(early_stopping_rounds, callbacks <- add.cb(callbacks, cb.early.stop(early_stopping_rounds,
maximize = maximize, verbose = verbose)) maximize = maximize, verbose = verbose))
} }
# Sort the callbacks into categories # Sort the callbacks into categories
cb <- categorize.callbacks(callbacks) cb <- categorize.callbacks(callbacks)
params['validate_parameters'] <- TRUE
if (!is.null(params[['seed']])) { if (!is.null(params[['seed']])) {
warning("xgb.train: `seed` is ignored in R package. Use `set.seed()` instead.") warning("xgb.train: `seed` is ignored in R package. Use `set.seed()` instead.")
} }

View File

@ -35,6 +35,40 @@ test_that("train and predict binary classification", {
expect_lt(abs(err_pred1 - err_log), 10e-6) expect_lt(abs(err_pred1 - err_log), 10e-6)
}) })
test_that("parameter validation works", {
  # Verifies two things about xgb.train:
  #   1. a call that passes only the dart/regression parameters listed in
  #      `correct()` produces no output, warnings or messages;
  #   2. a call that additionally passes the made-up parameters `foo` and
  #      `bar` produces output that mentions them as "bar, foo".
  nrounds <- 1
  set.seed(1994)

  # Small synthetic regression data set: 10 rows, three random features.
  # The rnorm() calls are kept in this exact order so the seeded data is
  # reproducible.
  d <- cbind(
    x1 = rnorm(10),
    x2 = rnorm(10),
    x3 = rnorm(10))
  y <- d[, "x1"] + d[, "x2"]^2 +
    ifelse(d[, "x3"] > .5, d[, "x3"]^2, 2^d[, "x3"]) +
    rnorm(10)
  dtrain <- xgb.DMatrix(data = d, info = list(label = y))

  # Training with this parameter set must stay completely silent.
  correct <- function() {
    params <- list(max_depth = 2, booster = "dart",
                   rate_drop = 0.5, one_drop = TRUE,
                   objective = "reg:squarederror")
    xgb.train(params = params, data = dtrain, nrounds = nrounds)
  }
  expect_silent(correct())

  # Same parameters plus two bogus entries. Whatever xgb.train writes to
  # standard output is captured and then re-printed, so that
  # expect_output() can match against it.
  incorrect <- function() {
    params <- list(max_depth = 2, booster = "dart",
                   rate_drop = 0.5, one_drop = TRUE,
                   objective = "reg:squarederror",
                   foo = "bar", bar = "foo")
    output <- capture.output(
      xgb.train(params = params, data = dtrain, nrounds = nrounds))
    print(output)
  }
  expect_output(incorrect(), "bar, foo")
})
test_that("dart prediction works", { test_that("dart prediction works", {
nrounds = 32 nrounds = 32
set.seed(1994) set.seed(1994)
@ -68,7 +102,6 @@ test_that("dart prediction works", {
one_drop = TRUE, one_drop = TRUE,
nthread = 1, nthread = 1,
tree_method= "exact", tree_method= "exact",
verbosity = 3,
objective = "reg:squarederror" objective = "reg:squarederror"
), ),
data = dtrain, data = dtrain,
@ -324,15 +357,15 @@ test_that("colsample_bytree works", {
test_y <- as.numeric(rowSums(test_x) > 0) test_y <- as.numeric(rowSums(test_x) > 0)
colnames(train_x) <- paste0("Feature_", sprintf("%03d", 1:100)) colnames(train_x) <- paste0("Feature_", sprintf("%03d", 1:100))
colnames(test_x) <- paste0("Feature_", sprintf("%03d", 1:100)) colnames(test_x) <- paste0("Feature_", sprintf("%03d", 1:100))
dtrain <- xgb.DMatrix(train_x, label = train_y) dtrain <- xgb.DMatrix(train_x, label = train_y)
dtest <- xgb.DMatrix(test_x, label = test_y) dtest <- xgb.DMatrix(test_x, label = test_y)
watchlist <- list(train = dtrain, eval = dtest) watchlist <- list(train = dtrain, eval = dtest)
# Use colsample_bytree = 0.01, so that roughly one out of 100 features is ## Use colsample_bytree = 0.01, so that roughly one out of 100 features is chosen for
# chosen for each tree ## each tree
param <- list(max_depth = 2, eta = 0, verbosity = 0, nthread = 2, param <- list(max_depth = 2, eta = 0, nthread = 2,
colsample_bytree = 0.01, objective = "binary:logistic", colsample_bytree = 0.01, objective = "binary:logistic",
eval_metric = "auc") eval_metric = "auc")
set.seed(2) set.seed(2)
bst <- xgb.train(param, dtrain, nrounds = 100, watchlist, verbose = 0) bst <- xgb.train(param, dtrain, nrounds = 100, watchlist, verbose = 0)
xgb.importance(model = bst) xgb.importance(model = bst)
# If colsample_bytree works properly, a variety of features should be used # If colsample_bytree works properly, a variety of features should be used

View File

@ -40,7 +40,7 @@ test_that("gblinear works", {
expect_lt(bst$evaluation_log$eval_error[2], ERR_UL) expect_lt(bst$evaluation_log$eval_error[2], ERR_UL)
bst <- xgb.train(param, dtrain, n, watchlist, verbose = VERB, feature_selector = 'thrifty', bst <- xgb.train(param, dtrain, n, watchlist, verbose = VERB, feature_selector = 'thrifty',
top_n = 50, callbacks = list(cb.gblinear.history(sparse = TRUE))) top_k = 50, callbacks = list(cb.gblinear.history(sparse = TRUE)))
expect_lt(bst$evaluation_log$eval_error[n], ERR_UL) expect_lt(bst$evaluation_log$eval_error[n], ERR_UL)
h <- xgb.gblinear.history(bst) h <- xgb.gblinear.history(bst)
expect_equal(dim(h), c(n, ncol(dtrain) + 1)) expect_equal(dim(h), c(n, ncol(dtrain) + 1))

View File

@ -30,7 +30,7 @@ General Parameters
is displayed as warning message. If there's unexpected behaviour, please try to is displayed as warning message. If there's unexpected behaviour, please try to
increase value of verbosity. increase value of verbosity.
* ``validate_parameters`` [default to false, except for Python interface] * ``validate_parameters`` [default to false, except for Python and R interface]
- When set to True, XGBoost will perform validation of input parameters to check whether - When set to True, XGBoost will perform validation of input parameters to check whether
a parameter is used or not. The feature is still experimental. It's expected to have a parameter is used or not. The feature is still experimental. It's expected to have