@@ -46,3 +46,31 @@ test_that("gblinear works", {
|
||||
expect_equal(dim(h), c(n, ncol(dtrain) + 1))
|
||||
expect_s4_class(h, "dgCMatrix")
|
||||
})
|
||||
|
||||
test_that("gblinear early stopping works", {
|
||||
data(agaricus.train, package = 'xgboost')
|
||||
data(agaricus.test, package = 'xgboost')
|
||||
dtrain <- xgb.DMatrix(agaricus.train$data, label = agaricus.train$label)
|
||||
dtest <- xgb.DMatrix(agaricus.test$data, label = agaricus.test$label)
|
||||
|
||||
param <- list(
|
||||
objective = "binary:logistic", eval_metric = "error", booster = "gblinear",
|
||||
nthread = 2, eta = 0.8, alpha = 0.0001, lambda = 0.0001,
|
||||
updater = "coord_descent"
|
||||
)
|
||||
|
||||
es_round <- 1
|
||||
n <- 10
|
||||
booster <- xgb.train(
|
||||
param, dtrain, n, list(eval = dtest, train = dtrain), early_stopping_rounds = es_round
|
||||
)
|
||||
expect_equal(booster$best_iteration, 5)
|
||||
predt_es <- predict(booster, dtrain)
|
||||
|
||||
n <- booster$best_iteration + es_round
|
||||
booster <- xgb.train(
|
||||
param, dtrain, n, list(eval = dtest, train = dtrain), early_stopping_rounds = es_round
|
||||
)
|
||||
predt <- predict(booster, dtrain)
|
||||
expect_equal(predt_es, predt)
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user