From 661c062bd927ddf9f99452d2dd61f98dbea9fe86 Mon Sep 17 00:00:00 2001 From: catena Date: Sat, 5 Mar 2016 17:35:42 +0530 Subject: [PATCH] add best_ntreelimit attribute --- R-package/R/xgb.cv.R | 6 +++--- R-package/R/xgb.train.R | 9 +++++++-- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/R-package/R/xgb.cv.R b/R-package/R/xgb.cv.R index e3faf33a0..c61cdbc5b 100644 --- a/R-package/R/xgb.cv.R +++ b/R-package/R/xgb.cv.R @@ -196,9 +196,9 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing = score <- as.numeric(score) if ( (maximize && score > bestScore) || (!maximize && score < bestScore)) { bestScore <- score - bestInd <- i + bestInd <- i - 1 } else { - if (i - bestInd >= early.stop.round) { + if (i - bestInd > early.stop.round) { earlyStopflag <- TRUE cat('Stopping. Best iteration:', bestInd, '\n') break @@ -211,7 +211,7 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing = for (k in 1:nfold) { fd <- xgb_folds[[k]] if (!is.null(early.stop.round) && earlyStopflag) { - res <- xgb.iter.eval(fd$booster, fd$watchlist, bestInd - 1, feval, prediction) + res <- xgb.iter.eval(fd$booster, fd$watchlist, bestInd, feval, prediction) } else { res <- xgb.iter.eval(fd$booster, fd$watchlist, nrounds - 1, feval, prediction) } diff --git a/R-package/R/xgb.train.R b/R-package/R/xgb.train.R index 92b5aea10..3868ddf89 100644 --- a/R-package/R/xgb.train.R +++ b/R-package/R/xgb.train.R @@ -208,10 +208,10 @@ xgb.train <- function(params=list(), data, nrounds, watchlist = list(), score <- as.numeric(score) if ( (maximize && score > bestScore) || (!maximize && score < bestScore)) { bestScore <- score - bestInd <- i + bestInd <- i - 1 } else { earlyStopflag = TRUE - if (i - bestInd >= early.stop.round) { + if (i - bestInd > early.stop.round) { cat('Stopping. Best iteration:', bestInd, '\n') break } @@ -229,6 +229,11 @@ xgb.train <- function(params=list(), data, nrounds, watchlist = list(), if (!is.null(early.stop.round)) { bst$bestScore <- bestScore bst$bestInd <- bestInd + if (!is.null(params$num_parallel_tree)) { + bst$best_ntreelimit <- (bst$bestInd + 1) * params$num_parallel_tree + } else { + bst$best_ntreelimit <- bst$bestInd + 1 + } } attr(bst, "call") <- fit.call