Fix r early stop with custom objective. (#5923)

* Specify `ntreelimit`.
This commit is contained in:
Jiaming Yuan 2020-07-23 03:28:17 +08:00 committed by GitHub
parent 30363d9c35
commit bc1d3ee230
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 12 additions and 4 deletions

View File

@ -145,7 +145,8 @@ xgb.iter.update <- function(booster_handle, dtrain, iter, obj = NULL) {
if (is.null(obj)) { if (is.null(obj)) {
.Call(XGBoosterUpdateOneIter_R, booster_handle, as.integer(iter), dtrain) .Call(XGBoosterUpdateOneIter_R, booster_handle, as.integer(iter), dtrain)
} else { } else {
pred <- predict(booster_handle, dtrain, outputmargin = TRUE, training = TRUE) pred <- predict(booster_handle, dtrain, outputmargin = TRUE, training = TRUE,
ntreelimit = 0)
gpair <- obj(pred, dtrain) gpair <- obj(pred, dtrain)
.Call(XGBoosterBoostOneIter_R, booster_handle, dtrain, gpair$grad, gpair$hess) .Call(XGBoosterBoostOneIter_R, booster_handle, dtrain, gpair$grad, gpair$hess)
} }
@ -172,7 +173,7 @@ xgb.iter.eval <- function(booster_handle, watchlist, iter, feval = NULL) {
} else { } else {
res <- sapply(seq_along(watchlist), function(j) { res <- sapply(seq_along(watchlist), function(j) {
w <- watchlist[[j]] w <- watchlist[[j]]
preds <- predict(booster_handle, w) # predict using all trees preds <- predict(booster_handle, w, ntreelimit = 0) # predict using all trees
eval_res <- feval(preds, w) eval_res <- feval(preds, w)
out <- eval_res$value out <- eval_res$value
names(out) <- paste0(evnames[j], "-", eval_res$metric) names(out) <- paste0(evnames[j], "-", eval_res$metric)

View File

@ -20,7 +20,7 @@ logregobj <- function(preds, dtrain) {
evalerror <- function(preds, dtrain) { evalerror <- function(preds, dtrain) {
labels <- getinfo(dtrain, "label") labels <- getinfo(dtrain, "label")
err <- as.numeric(sum(labels != (preds > 0))) / length(labels) err <- as.numeric(sum(labels != (preds > 0.5))) / length(labels)
return(list(metric = "error", value = err)) return(list(metric = "error", value = err))
} }
@ -43,6 +43,13 @@ test_that("custom objective in CV works", {
expect_lt(cv$evaluation_log[num_round, test_error_mean], 0.03) expect_lt(cv$evaluation_log[num_round, test_error_mean], 0.03)
}) })
test_that("custom objective with early stop works", {
bst <- xgb.train(param, dtrain, 10, watchlist)
expect_equal(class(bst), "xgb.Booster")
train_log <- bst$evaluation_log$train_error
expect_true(all(diff(train_log)) <= 0)
})
test_that("custom objective using DMatrix attr works", { test_that("custom objective using DMatrix attr works", {
attr(dtrain, 'label') <- getinfo(dtrain, 'label') attr(dtrain, 'label') <- getinfo(dtrain, 'label')

View File

@ -210,7 +210,7 @@ class TestModels(unittest.TestCase):
def evalerror(preds, dtrain): def evalerror(preds, dtrain):
labels = dtrain.get_label() labels = dtrain.get_label()
return 'error', float(sum(labels != (preds > 0.0))) / len(labels) return 'error', float(sum(labels != (preds > 0.5))) / len(labels)
# test custom_objective in training # test custom_objective in training
bst = xgb.train(param, dtrain, num_round, watchlist, logregobj, evalerror) bst = xgb.train(param, dtrain, num_round, watchlist, logregobj, evalerror)