[R] Do not convert continuous labels to factors (#6380)

* [R] Do not convert continuous labels to factors

* Address reviewer's comment
This commit is contained in:
Philip Hyunsu Cho
2020-11-17 09:19:16 -08:00
committed by GitHub
parent 3cca1c5fa1
commit e426b6e040
2 changed files with 43 additions and 5 deletions

View File

@@ -20,6 +20,12 @@ NVL <- function(x, val) {
stop("typeof(x) == ", typeof(x), " is not supported by NVL")
}
# List of classification and ranking objectives
.CLASSIFICATION_OBJECTIVES <- function() {
return(c('binary:logistic', 'binary:logitraw', 'binary:hinge', 'multi:softmax',
'multi:softprob', 'rank:pairwise', 'rank:ndcg', 'rank:map'))
}
#
# Low-level functions for boosting --------------------------------------------
@@ -187,6 +193,17 @@ xgb.iter.eval <- function(booster_handle, watchlist, iter, feval = NULL) {
# Helper functions for cross validation ---------------------------------------
#
# Possibly convert the labels into factors, depending on the objective.
# The labels are converted into factors only when the given objective refers to the classification
# or ranking tasks.
convert.labels <- function(labels, objective_name) {
if (objective_name %in% .CLASSIFICATION_OBJECTIVES()) {
return(as.factor(labels))
} else {
return(labels)
}
}
# Generates random (stratified if needed) CV folds
generate.cv.folds <- function(nfold, nrows, stratified, label, params) {
@@ -206,17 +223,15 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, params) {
# and then do stratification by factor levels.
# - For regression, leave y numeric and do stratification by quantiles.
if (is.character(objective)) {
# If 'objective' provided in params, assume that y is a classification label
# unless objective is reg:squarederror
if (params$objective != 'reg:squarederror')
y <- factor(y)
y <- convert.labels(y, params$objective)
} else {
# If no 'objective' given in params, it means that user either wants to
# use the default 'reg:squarederror' objective or has provided a custom
# obj function. Here, assume classification setting when y has 5 or less
# unique values:
if (length(unique(y)) <= 5)
if (length(unique(y)) <= 5) {
y <- factor(y)
}
}
folds <- xgb.createFolds(y, nfold)
} else {