add demo for early_stopping in R

This commit is contained in:
hetong007
2015-05-06 15:14:29 -07:00
parent 0f182b0b66
commit 419e4dbda6
2 changed files with 29 additions and 50 deletions

View File

@@ -139,31 +139,34 @@ xgb.train <- function(params=list(), data, nrounds, watchlist = list(),
params = append(params, list(...))
# Early stopping
if (!is.null(feval) && is.null(maximize) && !is.null(earlyStopRound))
stop('Please set maximize to note whether the model is maximizing the evaluation or not.')
if (length(watchlist) == 0 && !is.null(earlyStopRound))
stop('For early stopping you need at least one set in watchlist.')
if (is.null(maximize) && is.null(params$eval_metric) && !is.null(earlyStopRound))
stop('Please set maximize to note whether the model is maximizing the evaluation or not.')
if (is.null(maximize))
{
if (params$eval_metric %in% c('rmse','logloss','error','merror','mlogloss')) {
maximize = FALSE
} else {
maximize = TRUE
if (!is.null(earlyStopRound)){
if (!is.null(feval) && is.null(maximize))
stop('Please set maximize to note whether the model is maximizing the evaluation or not.')
if (length(watchlist) == 0)
stop('For early stopping you need at least one set in watchlist.')
if (is.null(maximize) && is.null(params$eval_metric))
stop('Please set maximize to note whether the model is maximizing the evaluation or not.')
if (is.null(maximize))
{
if (params$eval_metric %in% c('rmse','logloss','error','merror','mlogloss')) {
maximize = FALSE
} else {
maximize = TRUE
}
}
if (maximize) {
bestScore = 0
} else {
bestScore = Inf
}
bestInd = 0
earlyStopflag = FALSE
if (length(watchlist)>1)
warning('Only the first data set in watchlist is used for early stopping process.')
}
if (maximize) {
bestScore = 0
} else {
bestScore = Inf
}
bestInd = 0
earlyStopflag = FALSE
if (length(watchlist)>1 && !is.null(earlyStopRound))
warning('Only the first data set in watchlist is used for early stopping process.')
handle <- xgb.Booster(params, append(watchlist, dtrain))
bst <- xgb.handleToBooster(handle)
@@ -174,8 +177,7 @@ xgb.train <- function(params=list(), data, nrounds, watchlist = list(),
cat(paste(msg, "\n", sep=""))
if (!is.null(earlyStopRound))
{
score = strsplit(msg,'\\s+')[[1]][1]
score = strsplit(score,':')[[1]][2]
score = strsplit(msg,':|\\s+')[[1]][3]
score = as.numeric(score)
if ((maximize && score>bestScore) || (!maximize && score<bestScore)) {
bestScore = score
@@ -183,14 +185,12 @@ xgb.train <- function(params=list(), data, nrounds, watchlist = list(),
} else {
if (i-bestInd>earlyStopRound) {
earlyStopflag = TRUE
cat('Stopping. Best iteration:',bestInd)
break
}
}
}
}
if (earlyStopflag) {
cat('Stopping. Best iteration:',bestInd)
break
}
}
bst <- xgb.Booster.check(bst)
if (!is.null(earlyStopRound)) {