add demo for early_stopping in R

This commit is contained in:
hetong007 2015-05-06 15:14:29 -07:00
parent 0f182b0b66
commit 419e4dbda6
2 changed files with 29 additions and 50 deletions

View File

@ -139,11 +139,12 @@ xgb.train <- function(params=list(), data, nrounds, watchlist = list(),
params = append(params, list(...)) params = append(params, list(...))
# Early stopping # Early stopping
if (!is.null(feval) && is.null(maximize) && !is.null(earlyStopRound)) if (!is.null(earlyStopRound)){
if (!is.null(feval) && is.null(maximize))
stop('Please set maximize to note whether the model is maximizing the evaluation or not.') stop('Please set maximize to note whether the model is maximizing the evaluation or not.')
if (length(watchlist) == 0 && !is.null(earlyStopRound)) if (length(watchlist) == 0)
stop('For early stopping you need at least one set in watchlist.') stop('For early stopping you need at least one set in watchlist.')
if (is.null(maximize) && is.null(params$eval_metric) && !is.null(earlyStopRound)) if (is.null(maximize) && is.null(params$eval_metric))
stop('Please set maximize to note whether the model is maximizing the evaluation or not.') stop('Please set maximize to note whether the model is maximizing the evaluation or not.')
if (is.null(maximize)) if (is.null(maximize))
{ {
@ -162,8 +163,10 @@ xgb.train <- function(params=list(), data, nrounds, watchlist = list(),
bestInd = 0 bestInd = 0
earlyStopflag = FALSE earlyStopflag = FALSE
if (length(watchlist)>1 && !is.null(earlyStopRound)) if (length(watchlist)>1)
warning('Only the first data set in watchlist is used for early stopping process.') warning('Only the first data set in watchlist is used for early stopping process.')
}
handle <- xgb.Booster(params, append(watchlist, dtrain)) handle <- xgb.Booster(params, append(watchlist, dtrain))
bst <- xgb.handleToBooster(handle) bst <- xgb.handleToBooster(handle)
@ -174,8 +177,7 @@ xgb.train <- function(params=list(), data, nrounds, watchlist = list(),
cat(paste(msg, "\n", sep="")) cat(paste(msg, "\n", sep=""))
if (!is.null(earlyStopRound)) if (!is.null(earlyStopRound))
{ {
score = strsplit(msg,'\\s+')[[1]][1] score = strsplit(msg,':|\\s+')[[1]][3]
score = strsplit(score,':')[[1]][2]
score = as.numeric(score) score = as.numeric(score)
if ((maximize && score>bestScore) || (!maximize && score<bestScore)) { if ((maximize && score>bestScore) || (!maximize && score<bestScore)) {
bestScore = score bestScore = score
@ -183,15 +185,13 @@ xgb.train <- function(params=list(), data, nrounds, watchlist = list(),
} else { } else {
if (i-bestInd>earlyStopRound) { if (i-bestInd>earlyStopRound) {
earlyStopflag = TRUE earlyStopflag = TRUE
}
}
}
}
if (earlyStopflag) {
cat('Stopping. Best iteration:',bestInd) cat('Stopping. Best iteration:',bestInd)
break break
} }
} }
}
}
}
bst <- xgb.Booster.check(bst) bst <- xgb.Booster.check(bst)
if (!is.null(earlyStopRound)) { if (!is.null(earlyStopRound)) {
bst$bestScore = bestScore bst$bestScore = bestScore

View File

@ -30,29 +30,8 @@ evalerror <- function(preds, dtrain) {
err <- as.numeric(sum(labels != (preds > 0)))/length(labels) err <- as.numeric(sum(labels != (preds > 0)))/length(labels)
return(list(metric = "error", value = err)) return(list(metric = "error", value = err))
} }
print ('start training with user customized objective') print ('start training with early Stopping setting')
# training with customized objective, we can also do step by step training # training with customized objective, we can also do step by step training
# simply look at xgboost.py's implementation of train # simply look at xgboost.py's implementation of train
bst <- xgb.train(param, dtrain, num_round, watchlist, logregobj, evalerror, maximize = FALSE, bst <- xgb.train(param, dtrain, num_round, watchlist, logregobj, evalerror, maximize = FALSE,
earlyStopRound = 3) earlyStopRound = 3)
#
# Sometimes an objective needs information beyond the properties of a DMatrix
# that getinfo can return. Such extra information can be attached to the
# DMatrix as R attributes.
# Here the label vector is copied into a 'label' attribute of dtrain; label is
# only an example, any value can be stored this way.
attr(dtrain, 'label') <- getinfo(dtrain, 'label')
# Customized objective that reads what it needs from an attribute stored on
# dtrain instead of calling getinfo. A customized evaluation function can
# read attributes in the same way.
#
# preds:  numeric vector of raw (pre-sigmoid) scores
# dtrain: object carrying the labels in its 'label' attribute
# Returns a list with elements grad and hess.
logregobjattr <- function(preds, dtrain) {
  # labels are taken from the attribute, not from getinfo
  truth <- attr(dtrain, 'label')
  # logistic transform of the raw scores
  prob <- 1/(1 + exp(-preds))
  list(grad = prob - truth, hess = prob * (1 - prob))
}
print ('start training with user customized objective, with additional attributes in DMatrix')
# Train with the attribute-reading objective (logregobjattr) and the custom
# evaluation function (evalerror), both passed positionally after the watchlist.
# maximize = FALSE because evalerror reports an error rate, where lower is better.
# earlyStopRound = 3 enables the early stopping logic in xgb.train.
# Step-by-step training is also possible; see xgboost.py's implementation of train.
bst <- xgb.train(param, dtrain, num_round, watchlist, logregobjattr, evalerror, maximize = FALSE,
earlyStopRound = 3)