Enable parameter validation for R. (#5569)
* Enable parameter validation for R. * Add test.
This commit is contained in:
parent
564b22cee5
commit
c355ab65ed
@ -291,8 +291,10 @@ xgb.train <- function(params = list(), data, nrounds, watchlist = list(),
|
||||
callbacks <- add.cb(callbacks, cb.early.stop(early_stopping_rounds,
|
||||
maximize = maximize, verbose = verbose))
|
||||
}
|
||||
|
||||
# Sort the callbacks into categories
|
||||
cb <- categorize.callbacks(callbacks)
|
||||
params['validate_parameters'] <- TRUE
|
||||
if (!is.null(params[['seed']])) {
|
||||
warning("xgb.train: `seed` is ignored in R package. Use `set.seed()` instead.")
|
||||
}
|
||||
|
||||
@ -35,6 +35,40 @@ test_that("train and predict binary classification", {
|
||||
expect_lt(abs(err_pred1 - err_log), 10e-6)
|
||||
})
|
||||
|
||||
test_that("parameter validation works", {
|
||||
p <- list(foo = "bar")
|
||||
nrounds = 1
|
||||
set.seed(1994)
|
||||
|
||||
d <- cbind(
|
||||
x1 = rnorm(10),
|
||||
x2 = rnorm(10),
|
||||
x3 = rnorm(10))
|
||||
y <- d[,"x1"] + d[,"x2"]^2 +
|
||||
ifelse(d[,"x3"] > .5, d[,"x3"]^2, 2^d[,"x3"]) +
|
||||
rnorm(10)
|
||||
dtrain <- xgb.DMatrix(data=d, info = list(label=y))
|
||||
|
||||
correct <- function() {
|
||||
params <- list(max_depth = 2, booster = "dart",
|
||||
rate_drop = 0.5, one_drop = TRUE,
|
||||
objective = "reg:squarederror")
|
||||
xgb.train(params = params, data = dtrain, nrounds = nrounds)
|
||||
}
|
||||
expect_silent(correct())
|
||||
incorrect <- function() {
|
||||
params <- list(max_depth = 2, booster = "dart",
|
||||
rate_drop = 0.5, one_drop = TRUE,
|
||||
objective = "reg:squarederror",
|
||||
foo = "bar", bar = "foo")
|
||||
output <- capture.output(
|
||||
xgb.train(params = params, data = dtrain, nrounds = nrounds))
|
||||
print(output)
|
||||
}
|
||||
expect_output(incorrect(), "bar, foo")
|
||||
})
|
||||
|
||||
|
||||
test_that("dart prediction works", {
|
||||
nrounds = 32
|
||||
set.seed(1994)
|
||||
@ -68,7 +102,6 @@ test_that("dart prediction works", {
|
||||
one_drop = TRUE,
|
||||
nthread = 1,
|
||||
tree_method= "exact",
|
||||
verbosity = 3,
|
||||
objective = "reg:squarederror"
|
||||
),
|
||||
data = dtrain,
|
||||
@ -324,15 +357,15 @@ test_that("colsample_bytree works", {
|
||||
test_y <- as.numeric(rowSums(test_x) > 0)
|
||||
colnames(train_x) <- paste0("Feature_", sprintf("%03d", 1:100))
|
||||
colnames(test_x) <- paste0("Feature_", sprintf("%03d", 1:100))
|
||||
dtrain <- xgb.DMatrix(train_x, label = train_y)
|
||||
dtrain <- xgb.DMatrix(train_x, label = train_y)
|
||||
dtest <- xgb.DMatrix(test_x, label = test_y)
|
||||
watchlist <- list(train = dtrain, eval = dtest)
|
||||
# Use colsample_bytree = 0.01, so that roughly one out of 100 features is
|
||||
# chosen for each tree
|
||||
param <- list(max_depth = 2, eta = 0, verbosity = 0, nthread = 2,
|
||||
## Use colsample_bytree = 0.01, so that roughly one out of 100 features is chosen for
|
||||
## each tree
|
||||
param <- list(max_depth = 2, eta = 0, nthread = 2,
|
||||
colsample_bytree = 0.01, objective = "binary:logistic",
|
||||
eval_metric = "auc")
|
||||
set.seed(2)
|
||||
set.seed(2)
|
||||
bst <- xgb.train(param, dtrain, nrounds = 100, watchlist, verbose = 0)
|
||||
xgb.importance(model = bst)
|
||||
# If colsample_bytree works properly, a variety of features should be used
|
||||
|
||||
@ -40,7 +40,7 @@ test_that("gblinear works", {
|
||||
expect_lt(bst$evaluation_log$eval_error[2], ERR_UL)
|
||||
|
||||
bst <- xgb.train(param, dtrain, n, watchlist, verbose = VERB, feature_selector = 'thrifty',
|
||||
top_n = 50, callbacks = list(cb.gblinear.history(sparse = TRUE)))
|
||||
top_k = 50, callbacks = list(cb.gblinear.history(sparse = TRUE)))
|
||||
expect_lt(bst$evaluation_log$eval_error[n], ERR_UL)
|
||||
h <- xgb.gblinear.history(bst)
|
||||
expect_equal(dim(h), c(n, ncol(dtrain) + 1))
|
||||
|
||||
@ -30,7 +30,7 @@ General Parameters
|
||||
is displayed as warning message. If there's unexpected behaviour, please try to
|
||||
increase value of verbosity.
|
||||
|
||||
* ``validate_parameters`` [default to false, except for Python interface]
|
||||
* ``validate_parameters`` [default to false, except for Python and R interface]
|
||||
|
||||
- When set to True, XGBoost will perform validation of input parameters to check whether
|
||||
a parameter is used or not. The feature is still experimental. It's expected to have
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user