Added test for eta decay (+3 squashed commits)
Squashed commits: [9109887] Added test for eta decay(+1 squashed commit) Squashed commits: [1336bd4] Added tests for eta decay (+2 squashed commit) Squashed commits: [91aac2d] Added tests for eta decay (+1 squashed commit) Squashed commits: [3ff48e7] Added test for eta decay [6bb1eed] Rewrote Rd files [bf0dec4] Added learning_rates for diff eta in each boosting round
This commit is contained in:
@@ -2,11 +2,12 @@ context('Test models with custom objective')
|
||||
|
||||
require(xgboost)
|
||||
|
||||
data(agaricus.train, package='xgboost')
|
||||
data(agaricus.test, package='xgboost')
|
||||
dtrain <- xgb.DMatrix(agaricus.train$data, label = agaricus.train$label)
|
||||
dtest <- xgb.DMatrix(agaricus.test$data, label = agaricus.test$label)
|
||||
|
||||
test_that("custom objective works", {
|
||||
data(agaricus.train, package='xgboost')
|
||||
data(agaricus.test, package='xgboost')
|
||||
dtrain <- xgb.DMatrix(agaricus.train$data, label = agaricus.train$label)
|
||||
dtest <- xgb.DMatrix(agaricus.test$data, label = agaricus.test$label)
|
||||
|
||||
watchlist <- list(eval = dtest, train = dtrain)
|
||||
num_round <- 2
|
||||
@@ -44,4 +45,14 @@ test_that("custom objective works", {
|
||||
bst <- xgb.train(param, dtrain, num_round, watchlist)
|
||||
expect_equal(class(bst), "xgb.Booster")
|
||||
expect_equal(length(bst$raw), 1064)
|
||||
})
|
||||
})
|
||||
|
||||
test_that("different eta for each boosting round works", {
|
||||
num_round <- 2
|
||||
watchlist <- list(eval = dtest, train = dtrain)
|
||||
param <- list(max.depth=2, eta=1, nthread = 2, silent=1)
|
||||
|
||||
bst <- xgb.train(param, dtrain, num_round, watchlist, learning_rates = c(0.2, 0.3))
|
||||
})
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user