Improved logic in stratified CV to guess class/regr
Somewhat more robust and clear logic in stratified CV to guess classification/regression settings. Allows to accomodate custom objectives (classification is assumed when number of unique values in labels <= 5).
This commit is contained in:
parent
bab7b58d94
commit
f325930bd9
@ -211,6 +211,7 @@ xgb.iter.eval <- function(booster, watchlist, iter, feval = NULL, prediction = F
|
|||||||
}
|
}
|
||||||
return(msg)
|
return(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
#------------------------------------------
|
#------------------------------------------
|
||||||
# helper functions for cross validation
|
# helper functions for cross validation
|
||||||
#
|
#
|
||||||
@ -223,11 +224,23 @@ xgb.cv.mknfold <- function(dall, nfold, param, stratified, folds) {
|
|||||||
randidx <- sample(1 : xgb.numrow(dall))
|
randidx <- sample(1 : xgb.numrow(dall))
|
||||||
if (stratified & length(y) == length(randidx)) {
|
if (stratified & length(y) == length(randidx)) {
|
||||||
y <- y[randidx]
|
y <- y[randidx]
|
||||||
# By default assume that y is a classification label,
|
#
|
||||||
# and only leave it numeric for the reg:linear objective.
|
# WARNING: some heuristic logic is employed to identify classification setting!
|
||||||
# WARNING: if there would be any other objectives with truly
|
#
|
||||||
# numerical labels, they currently would not be treated correctly!
|
# For classification, need to convert y labels to factor before making the folds,
|
||||||
if (param[['objective']] != 'reg:linear') y <- factor(y)
|
# and then do stratification by factor levels.
|
||||||
|
# For regression, leave y numeric and do stratification by quantiles.
|
||||||
|
n_uniq <- length(unique(y))
|
||||||
|
if (exists('objective', where=param)) {
|
||||||
|
# If 'objective' provided in params, assume that y is a classification label
|
||||||
|
# unless objective is reg:linear
|
||||||
|
if (param[['objective']] != 'reg:linear') y <- factor(y)
|
||||||
|
} else {
|
||||||
|
# If no 'objective' given in params, it means that user either wants to use
|
||||||
|
# the default 'reg:linear' objective or has provided a custom obj function.
|
||||||
|
# Here, assume classification setting when y has 5 or less unique values:
|
||||||
|
if (length(unique(y)) <= 5) y <- factor(y)
|
||||||
|
}
|
||||||
folds <- xgb.createFolds(y, nfold)
|
folds <- xgb.createFolds(y, nfold)
|
||||||
} else {
|
} else {
|
||||||
# make simple non-stratified folds
|
# make simple non-stratified folds
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user