Merge pull request #914 from catena/master
R: fix "bestInd" and add "best_ntreelimit" to xgb.Booster
This commit is contained in:
commit
9cb872b879
@ -196,9 +196,9 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
|
|||||||
score <- as.numeric(score)
|
score <- as.numeric(score)
|
||||||
if ( (maximize && score > bestScore) || (!maximize && score < bestScore)) {
|
if ( (maximize && score > bestScore) || (!maximize && score < bestScore)) {
|
||||||
bestScore <- score
|
bestScore <- score
|
||||||
bestInd <- i
|
bestInd <- i - 1
|
||||||
} else {
|
} else {
|
||||||
if (i - bestInd >= early.stop.round) {
|
if (i - bestInd > early.stop.round) {
|
||||||
earlyStopflag <- TRUE
|
earlyStopflag <- TRUE
|
||||||
cat('Stopping. Best iteration:', bestInd, '\n')
|
cat('Stopping. Best iteration:', bestInd, '\n')
|
||||||
break
|
break
|
||||||
@ -211,7 +211,7 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
|
|||||||
for (k in 1:nfold) {
|
for (k in 1:nfold) {
|
||||||
fd <- xgb_folds[[k]]
|
fd <- xgb_folds[[k]]
|
||||||
if (!is.null(early.stop.round) && earlyStopflag) {
|
if (!is.null(early.stop.round) && earlyStopflag) {
|
||||||
res <- xgb.iter.eval(fd$booster, fd$watchlist, bestInd - 1, feval, prediction)
|
res <- xgb.iter.eval(fd$booster, fd$watchlist, bestInd, feval, prediction)
|
||||||
} else {
|
} else {
|
||||||
res <- xgb.iter.eval(fd$booster, fd$watchlist, nrounds - 1, feval, prediction)
|
res <- xgb.iter.eval(fd$booster, fd$watchlist, nrounds - 1, feval, prediction)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -208,10 +208,10 @@ xgb.train <- function(params=list(), data, nrounds, watchlist = list(),
|
|||||||
score <- as.numeric(score)
|
score <- as.numeric(score)
|
||||||
if ( (maximize && score > bestScore) || (!maximize && score < bestScore)) {
|
if ( (maximize && score > bestScore) || (!maximize && score < bestScore)) {
|
||||||
bestScore <- score
|
bestScore <- score
|
||||||
bestInd <- i
|
bestInd <- i - 1
|
||||||
} else {
|
} else {
|
||||||
earlyStopflag = TRUE
|
earlyStopflag = TRUE
|
||||||
if (i - bestInd >= early.stop.round) {
|
if (i - bestInd > early.stop.round) {
|
||||||
cat('Stopping. Best iteration:', bestInd, '\n')
|
cat('Stopping. Best iteration:', bestInd, '\n')
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@ -229,6 +229,11 @@ xgb.train <- function(params=list(), data, nrounds, watchlist = list(),
|
|||||||
if (!is.null(early.stop.round)) {
|
if (!is.null(early.stop.round)) {
|
||||||
bst$bestScore <- bestScore
|
bst$bestScore <- bestScore
|
||||||
bst$bestInd <- bestInd
|
bst$bestInd <- bestInd
|
||||||
|
if (!is.null(params$num_parallel_tree)) {
|
||||||
|
bst$best_ntreelimit <- (bst$bestInd + 1) * params$num_parallel_tree
|
||||||
|
} else {
|
||||||
|
bst$best_ntreelimit <- bst$bestInd + 1
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
attr(bst, "call") <- fit.call
|
attr(bst, "call") <- fit.call
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user