Merge pull request #914 from catena/master

R: fix "bestInd" and add "best_ntreelimit" to xgb.Booster
This commit is contained in:
Tong He 2016-06-13 11:13:00 -07:00 committed by GitHub
commit 9cb872b879
2 changed files with 10 additions and 5 deletions

View File

@ -196,9 +196,9 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
score <- as.numeric(score) score <- as.numeric(score)
if ( (maximize && score > bestScore) || (!maximize && score < bestScore)) { if ( (maximize && score > bestScore) || (!maximize && score < bestScore)) {
bestScore <- score bestScore <- score
bestInd <- i bestInd <- i - 1
} else { } else {
if (i - bestInd >= early.stop.round) { if (i - bestInd > early.stop.round) {
earlyStopflag <- TRUE earlyStopflag <- TRUE
cat('Stopping. Best iteration:', bestInd, '\n') cat('Stopping. Best iteration:', bestInd, '\n')
break break
@ -211,7 +211,7 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
for (k in 1:nfold) { for (k in 1:nfold) {
fd <- xgb_folds[[k]] fd <- xgb_folds[[k]]
if (!is.null(early.stop.round) && earlyStopflag) { if (!is.null(early.stop.round) && earlyStopflag) {
res <- xgb.iter.eval(fd$booster, fd$watchlist, bestInd - 1, feval, prediction) res <- xgb.iter.eval(fd$booster, fd$watchlist, bestInd, feval, prediction)
} else { } else {
res <- xgb.iter.eval(fd$booster, fd$watchlist, nrounds - 1, feval, prediction) res <- xgb.iter.eval(fd$booster, fd$watchlist, nrounds - 1, feval, prediction)
} }

View File

@ -208,10 +208,10 @@ xgb.train <- function(params=list(), data, nrounds, watchlist = list(),
score <- as.numeric(score) score <- as.numeric(score)
if ( (maximize && score > bestScore) || (!maximize && score < bestScore)) { if ( (maximize && score > bestScore) || (!maximize && score < bestScore)) {
bestScore <- score bestScore <- score
bestInd <- i bestInd <- i - 1
} else { } else {
earlyStopflag = TRUE earlyStopflag = TRUE
if (i - bestInd >= early.stop.round) { if (i - bestInd > early.stop.round) {
cat('Stopping. Best iteration:', bestInd, '\n') cat('Stopping. Best iteration:', bestInd, '\n')
break break
} }
@ -229,6 +229,11 @@ xgb.train <- function(params=list(), data, nrounds, watchlist = list(),
if (!is.null(early.stop.round)) { if (!is.null(early.stop.round)) {
bst$bestScore <- bestScore bst$bestScore <- bestScore
bst$bestInd <- bestInd bst$bestInd <- bestInd
if (!is.null(params$num_parallel_tree)) {
bst$best_ntreelimit <- (bst$bestInd + 1) * params$num_parallel_tree
} else {
bst$best_ntreelimit <- bst$bestInd + 1
}
} }
attr(bst, "call") <- fit.call attr(bst, "call") <- fit.call