Minor addition to R unit tests
This commit is contained in:
@@ -6,14 +6,15 @@ data(agaricus.train, package='xgboost')
|
||||
data(agaricus.test, package='xgboost')
|
||||
train <- agaricus.train
|
||||
test <- agaricus.test
|
||||
set.seed(1994)
|
||||
|
||||
test_that("train and predict", {
|
||||
bst <- xgboost(data = train$data, label = train$label, max.depth = 2,
|
||||
eta = 1, nthread = 2, nround = 2, objective = "binary:logistic")
|
||||
pred <- predict(bst, test$data)
|
||||
expect_equal(length(pred), 1611)
|
||||
})
|
||||
|
||||
|
||||
test_that("early stopping", {
|
||||
res <- xgb.cv(data = train$data, label = train$label, max.depth = 2, nfold = 5,
|
||||
eta = 0.3, nthread = 2, nround = 20, objective = "binary:logistic",
|
||||
@@ -23,6 +24,7 @@ test_that("early stopping", {
|
||||
eta = 0.3, nthread = 2, nround = 20, objective = "binary:logistic",
|
||||
early.stop.round = 3, maximize = FALSE)
|
||||
pred <- predict(bst, test$data)
|
||||
expect_equal(length(pred), 1611)
|
||||
})
|
||||
|
||||
test_that("save_period", {
|
||||
@@ -30,4 +32,5 @@ test_that("save_period", {
|
||||
eta = 0.3, nthread = 2, nround = 20, objective = "binary:logistic",
|
||||
save_period = 10, save_name = "xgb.model")
|
||||
pred <- predict(bst, test$data)
|
||||
expect_equal(length(pred), 1611)
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user