add demo for early_stopping in R
This commit is contained in:
@@ -139,31 +139,34 @@ xgb.train <- function(params=list(), data, nrounds, watchlist = list(),
|
||||
params = append(params, list(...))
|
||||
|
||||
# Early stopping
|
||||
if (!is.null(feval) && is.null(maximize) && !is.null(earlyStopRound))
|
||||
stop('Please set maximize to note whether the model is maximizing the evaluation or not.')
|
||||
if (length(watchlist) == 0 && !is.null(earlyStopRound))
|
||||
stop('For early stopping you need at least one set in watchlist.')
|
||||
if (is.null(maximize) && is.null(params$eval_metric) && !is.null(earlyStopRound))
|
||||
stop('Please set maximize to note whether the model is maximizing the evaluation or not.')
|
||||
if (is.null(maximize))
|
||||
{
|
||||
if (params$eval_metric %in% c('rmse','logloss','error','merror','mlogloss')) {
|
||||
maximize = FALSE
|
||||
} else {
|
||||
maximize = TRUE
|
||||
if (!is.null(earlyStopRound)){
|
||||
if (!is.null(feval) && is.null(maximize))
|
||||
stop('Please set maximize to note whether the model is maximizing the evaluation or not.')
|
||||
if (length(watchlist) == 0)
|
||||
stop('For early stopping you need at least one set in watchlist.')
|
||||
if (is.null(maximize) && is.null(params$eval_metric))
|
||||
stop('Please set maximize to note whether the model is maximizing the evaluation or not.')
|
||||
if (is.null(maximize))
|
||||
{
|
||||
if (params$eval_metric %in% c('rmse','logloss','error','merror','mlogloss')) {
|
||||
maximize = FALSE
|
||||
} else {
|
||||
maximize = TRUE
|
||||
}
|
||||
}
|
||||
|
||||
if (maximize) {
|
||||
bestScore = 0
|
||||
} else {
|
||||
bestScore = Inf
|
||||
}
|
||||
bestInd = 0
|
||||
earlyStopflag = FALSE
|
||||
|
||||
if (length(watchlist)>1)
|
||||
warning('Only the first data set in watchlist is used for early stopping process.')
|
||||
}
|
||||
|
||||
if (maximize) {
|
||||
bestScore = 0
|
||||
} else {
|
||||
bestScore = Inf
|
||||
}
|
||||
bestInd = 0
|
||||
earlyStopflag = FALSE
|
||||
|
||||
if (length(watchlist)>1 && !is.null(earlyStopRound))
|
||||
warning('Only the first data set in watchlist is used for early stopping process.')
|
||||
|
||||
handle <- xgb.Booster(params, append(watchlist, dtrain))
|
||||
bst <- xgb.handleToBooster(handle)
|
||||
@@ -174,8 +177,7 @@ xgb.train <- function(params=list(), data, nrounds, watchlist = list(),
|
||||
cat(paste(msg, "\n", sep=""))
|
||||
if (!is.null(earlyStopRound))
|
||||
{
|
||||
score = strsplit(msg,'\\s+')[[1]][1]
|
||||
score = strsplit(score,':')[[1]][2]
|
||||
score = strsplit(msg,':|\\s+')[[1]][3]
|
||||
score = as.numeric(score)
|
||||
if ((maximize && score>bestScore) || (!maximize && score<bestScore)) {
|
||||
bestScore = score
|
||||
@@ -183,14 +185,12 @@ xgb.train <- function(params=list(), data, nrounds, watchlist = list(),
|
||||
} else {
|
||||
if (i-bestInd>earlyStopRound) {
|
||||
earlyStopflag = TRUE
|
||||
cat('Stopping. Best iteration:',bestInd)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (earlyStopflag) {
|
||||
cat('Stopping. Best iteration:',bestInd)
|
||||
break
|
||||
}
|
||||
}
|
||||
bst <- xgb.Booster.check(bst)
|
||||
if (!is.null(earlyStopRound)) {
|
||||
|
||||
Reference in New Issue
Block a user