TST: Added more checks for testing custom objective
This commit is contained in:
parent
886955148d
commit
3a49e1bdb1
@ -28,6 +28,8 @@ test_that("custom objective works", {
|
|||||||
objective=logregobj, eval_metric=evalerror)
|
objective=logregobj, eval_metric=evalerror)
|
||||||
|
|
||||||
bst <- xgb.train(param, dtrain, num_round, watchlist)
|
bst <- xgb.train(param, dtrain, num_round, watchlist)
|
||||||
|
expect_equal(class(bst), "xgb.Booster")
|
||||||
|
expect_equal(length(bst$raw), 1064)
|
||||||
attr(dtrain, 'label') <- getinfo(dtrain, 'label')
|
attr(dtrain, 'label') <- getinfo(dtrain, 'label')
|
||||||
|
|
||||||
logregobjattr <- function(preds, dtrain) {
|
logregobjattr <- function(preds, dtrain) {
|
||||||
@ -40,4 +42,6 @@ test_that("custom objective works", {
|
|||||||
param <- list(max.depth=2, eta=1, nthread = 2, silent=1,
|
param <- list(max.depth=2, eta=1, nthread = 2, silent=1,
|
||||||
objective=logregobjattr, eval_metric=evalerror)
|
objective=logregobjattr, eval_metric=evalerror)
|
||||||
bst <- xgb.train(param, dtrain, num_round, watchlist)
|
bst <- xgb.train(param, dtrain, num_round, watchlist)
|
||||||
|
expect_equal(class(bst), "xgb.Booster")
|
||||||
|
expect_equal(length(bst$raw), 1064)
|
||||||
})
|
})
|
||||||
Loading…
x
Reference in New Issue
Block a user