add early stopping to xgb.cv

This commit is contained in:
hetong007 2015-05-11 16:03:40 -07:00
parent 60d307c445
commit 83ace55f51
3 changed files with 70 additions and 3 deletions

View File

@ -54,6 +54,13 @@
#' @param folds \code{list} provides a possibility of using a list of pre-defined CV folds (each element must be a vector of fold's indices). #' @param folds \code{list} provides a possibility of using a list of pre-defined CV folds (each element must be a vector of fold's indices).
#' If folds are supplied, the nfold and stratified parameters would be ignored. #' If folds are supplied, the nfold and stratified parameters would be ignored.
#' @param verbose \code{boolean}, print the statistics during the process #' @param verbose \code{boolean}, print the statistics during the process
#' @param early_stop_round If \code{NULL}, the early stopping function is not triggered.
#' If set to an integer \code{k}, training with a validation set will stop if the performance
#' keeps getting worse consecutively for \code{k} rounds.
#' @param early.stop.round An alternative of \code{early_stop_round}.
#' @param maximize If \code{feval} and \code{early_stop_round} are set, then \code{maximize} must be set as well.
#' \code{maximize=TRUE} means the larger the evaluation score the better.
#'
#' @param ... other parameters to pass to \code{params}. #' @param ... other parameters to pass to \code{params}.
#' #'
#' @return #' @return
@ -86,7 +93,8 @@
#' #'
xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing = NULL, xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing = NULL,
prediction = FALSE, showsd = TRUE, metrics=list(), prediction = FALSE, showsd = TRUE, metrics=list(),
obj = NULL, feval = NULL, stratified = TRUE, folds = NULL, verbose = T,...) { obj = NULL, feval = NULL, stratified = TRUE, folds = NULL, verbose = T,
early_stop_round = NULL, early.stop.round = NULL, maximize = NULL, ...) {
if (typeof(params) != "list") { if (typeof(params) != "list") {
stop("xgb.cv: first argument params must be list") stop("xgb.cv: first argument params must be list")
} }
@ -110,6 +118,35 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
params <- append(params, list("eval_metric"=mc)) params <- append(params, list("eval_metric"=mc))
} }
# Early Stopping
if (is.null(early_stop_round) && !is.null(early.stop.round))
early_stop_round = early.stop.round
if (!is.null(early_stop_round)){
if (!is.null(feval) && is.null(maximize))
stop('Please set maximize to note whether the model is maximizing the evaluation or not.')
if (is.null(maximize) && is.null(params$eval_metric))
stop('Please set maximize to note whether the model is maximizing the evaluation or not.')
if (is.null(maximize))
{
if (params$eval_metric %in% c('rmse','logloss','error','merror','mlogloss')) {
maximize = FALSE
} else {
maximize = TRUE
}
}
if (maximize) {
bestScore = 0
} else {
bestScore = Inf
}
bestInd = 0
earlyStopflag = FALSE
if (length(metrics)>1)
warning('Only the first metric is used for early stopping process.')
}
xgb_folds <- xgb.cv.mknfold(dtrain, nfold, params, stratified, folds) xgb_folds <- xgb.cv.mknfold(dtrain, nfold, params, stratified, folds)
obj_type = params[['objective']] obj_type = params[['objective']]
mat_pred = FALSE mat_pred = FALSE
@ -149,6 +186,24 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
ret <- xgb.cv.aggcv(msg, showsd) ret <- xgb.cv.aggcv(msg, showsd)
history <- c(history, ret) history <- c(history, ret)
if(verbose) paste(ret, "\n", sep="") %>% cat if(verbose) paste(ret, "\n", sep="") %>% cat
# early_Stopping
if (!is.null(early_stop_round)){
score = strsplit(ret,'\\s+')[[1]][1+length(metrics)+1]
score = strsplit(score,'\\+|:')[[1]][[2]]
score = as.numeric(score)
if ((maximize && score>bestScore) || (!maximize && score<bestScore)) {
bestScore = score
bestInd = i
} else {
if (i-bestInd>early_stop_round) {
earlyStopflag = TRUE
cat('Stopping. Best iteration:',bestInd)
break
}
}
}
} }
colnames <- str_split(string = history[1], pattern = "\t")[[1]] %>% .[2:length(.)] %>% str_extract(".*:") %>% str_replace(":","") %>% str_replace("-", ".") colnames <- str_split(string = history[1], pattern = "\t")[[1]] %>% .[2:length(.)] %>% str_extract(".*:") %>% str_replace(":","") %>% str_replace("-", ".")

View File

@ -35,3 +35,5 @@ print ('start training with early Stopping setting')
# simply look at xgboost.py's implementation of train # simply look at xgboost.py's implementation of train
bst <- xgb.train(param, dtrain, num_round, watchlist, logregobj, evalerror, maximize = FALSE, bst <- xgb.train(param, dtrain, num_round, watchlist, logregobj, evalerror, maximize = FALSE,
earlyStopRound = 3) earlyStopRound = 3)
bst <- xgb.cv(param, dtrain, num_round, nfold=5, obj=logregobj, feval = evalerror,
maximize = FALSE, earlyStopRound = 3)

View File

@ -7,7 +7,8 @@
xgb.cv(params = list(), data, nrounds, nfold, label = NULL, xgb.cv(params = list(), data, nrounds, nfold, label = NULL,
missing = NULL, prediction = FALSE, showsd = TRUE, metrics = list(), missing = NULL, prediction = FALSE, showsd = TRUE, metrics = list(),
obj = NULL, feval = NULL, stratified = TRUE, folds = NULL, obj = NULL, feval = NULL, stratified = TRUE, folds = NULL,
verbose = T, ...) verbose = T, early_stop_round = NULL, early.stop.round = NULL,
maximize = NULL, ...)
} }
\arguments{ \arguments{
\item{params}{the list of parameters. Commonly used ones are: \item{params}{the list of parameters. Commonly used ones are:
@ -65,6 +66,15 @@ If folds are supplied, the nfold and stratified parameters would be ignored.}
\item{verbose}{\code{boolean}, print the statistics during the process} \item{verbose}{\code{boolean}, print the statistics during the process}
\item{early_stop_round}{If \code{NULL}, the early stopping function is not triggered.
If set to an integer \code{k}, training with a validation set will stop if the performance
keeps getting worse consecutively for \code{k} rounds.}
\item{early.stop.round}{An alternative of \code{early_stop_round}.}
\item{maximize}{If \code{feval} and \code{early_stop_round} are set, then \code{maximize} must be set as well.
\code{maximize=TRUE} means the larger the evaluation score the better.}
\item{...}{other parameters to pass to \code{params}.} \item{...}{other parameters to pass to \code{params}.}
} }
\value{ \value{