[R] Do not convert continuous labels to factors (#6380)
* [R] Do not convert continuous labels to factors * Address reviewer's comment
This commit is contained in:
committed by
GitHub
parent
3cca1c5fa1
commit
e426b6e040
@@ -20,6 +20,12 @@ NVL <- function(x, val) {
|
||||
stop("typeof(x) == ", typeof(x), " is not supported by NVL")
|
||||
}
|
||||
|
||||
# List of classification and ranking objectives
|
||||
.CLASSIFICATION_OBJECTIVES <- function() {
|
||||
return(c('binary:logistic', 'binary:logitraw', 'binary:hinge', 'multi:softmax',
|
||||
'multi:softprob', 'rank:pairwise', 'rank:ndcg', 'rank:map'))
|
||||
}
|
||||
|
||||
|
||||
#
|
||||
# Low-level functions for boosting --------------------------------------------
|
||||
@@ -187,6 +193,17 @@ xgb.iter.eval <- function(booster_handle, watchlist, iter, feval = NULL) {
|
||||
# Helper functions for cross validation ---------------------------------------
|
||||
#
|
||||
|
||||
# Possibly convert the labels into factors, depending on the objective.
|
||||
# The labels are converted into factors only when the given objective refers to the classification
|
||||
# or ranking tasks.
|
||||
convert.labels <- function(labels, objective_name) {
|
||||
if (objective_name %in% .CLASSIFICATION_OBJECTIVES()) {
|
||||
return(as.factor(labels))
|
||||
} else {
|
||||
return(labels)
|
||||
}
|
||||
}
|
||||
|
||||
# Generates random (stratified if needed) CV folds
|
||||
generate.cv.folds <- function(nfold, nrows, stratified, label, params) {
|
||||
|
||||
@@ -206,17 +223,15 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, params) {
|
||||
# and then do stratification by factor levels.
|
||||
# - For regression, leave y numeric and do stratification by quantiles.
|
||||
if (is.character(objective)) {
|
||||
# If 'objective' provided in params, assume that y is a classification label
|
||||
# unless objective is reg:squarederror
|
||||
if (params$objective != 'reg:squarederror')
|
||||
y <- factor(y)
|
||||
y <- convert.labels(y, params$objective)
|
||||
} else {
|
||||
# If no 'objective' given in params, it means that user either wants to
|
||||
# use the default 'reg:squarederror' objective or has provided a custom
|
||||
# obj function. Here, assume classification setting when y has 5 or less
|
||||
# unique values:
|
||||
if (length(unique(y)) <= 5)
|
||||
if (length(unique(y)) <= 5) {
|
||||
y <- factor(y)
|
||||
}
|
||||
}
|
||||
folds <- xgb.createFolds(y, nfold)
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user