[R] Do not convert continuous labels to factors (#6380)

* [R] Do not convert continuous labels to factors

* Address reviewer's comment
This commit is contained in:
Philip Hyunsu Cho 2020-11-17 09:19:16 -08:00 committed by GitHub
parent 3cca1c5fa1
commit e426b6e040
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 43 additions and 5 deletions

View File

@ -20,6 +20,12 @@ NVL <- function(x, val) {
stop("typeof(x) == ", typeof(x), " is not supported by NVL") stop("typeof(x) == ", typeof(x), " is not supported by NVL")
} }
# Names of the objectives that correspond to classification and ranking tasks.
# Returns a character vector: binary objectives first, then multi-class, then ranking.
.CLASSIFICATION_OBJECTIVES <- function() {
  binary_objectives <- c("binary:logistic", "binary:logitraw", "binary:hinge")
  multiclass_objectives <- c("multi:softmax", "multi:softprob")
  ranking_objectives <- c("rank:pairwise", "rank:ndcg", "rank:map")
  c(binary_objectives, multiclass_objectives, ranking_objectives)
}
# #
# Low-level functions for boosting -------------------------------------------- # Low-level functions for boosting --------------------------------------------
@ -187,6 +193,17 @@ xgb.iter.eval <- function(booster_handle, watchlist, iter, feval = NULL) {
# Helper functions for cross validation --------------------------------------- # Helper functions for cross validation ---------------------------------------
# #
# Convert the labels into a factor when the objective calls for it.
# - labels: the label vector to (possibly) convert.
# - objective_name: name of the objective, e.g. 'binary:logistic'.
# Labels become a factor only when objective_name is one of the classification
# or ranking objectives listed by .CLASSIFICATION_OBJECTIVES(); for any other
# objective they are returned unchanged.
convert.labels <- function(labels, objective_name) {
  is_categorical <- objective_name %in% .CLASSIFICATION_OBJECTIVES()
  if (!is_categorical) {
    return(labels)
  }
  as.factor(labels)
}
# Generates random (stratified if needed) CV folds # Generates random (stratified if needed) CV folds
generate.cv.folds <- function(nfold, nrows, stratified, label, params) { generate.cv.folds <- function(nfold, nrows, stratified, label, params) {
@ -206,17 +223,15 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, params) {
# and then do stratification by factor levels. # and then do stratification by factor levels.
# - For regression, leave y numeric and do stratification by quantiles. # - For regression, leave y numeric and do stratification by quantiles.
if (is.character(objective)) { if (is.character(objective)) {
# If 'objective' provided in params, assume that y is a classification label y <- convert.labels(y, params$objective)
# unless objective is reg:squarederror
if (params$objective != 'reg:squarederror')
y <- factor(y)
} else { } else {
# If no 'objective' given in params, it means that user either wants to # If no 'objective' given in params, it means that user either wants to
# use the default 'reg:squarederror' objective or has provided a custom # use the default 'reg:squarederror' objective or has provided a custom
# obj function. Here, assume classification setting when y has 5 or less # obj function. Here, assume classification setting when y has 5 or less
# unique values: # unique values:
if (length(unique(y)) <= 5) if (length(unique(y)) <= 5) {
y <- factor(y) y <- factor(y)
}
} }
folds <- xgb.createFolds(y, nfold) folds <- xgb.createFolds(y, nfold)
} else { } else {

View File

@ -410,3 +410,26 @@ test_that("check.deprecation works", {
, "\'dumm\' was partially matched to \'dummy\'") , "\'dumm\' was partially matched to \'dummy\'")
expect_equal(res, list(a = 1, DUMMY = 22)) expect_equal(res, list(a = 1, DUMMY = 22))
}) })
test_that('convert.labels works', {
  # Classification objectives: labels must come back as a factor holding the
  # same values as the input. Comparing against factor(y) rather than
  # factor(res) makes the check sensitive to wrong values, not only to class.
  y <- c(0, 1, 0, 0, 1)
  for (objective in c('binary:logistic', 'binary:logitraw', 'binary:hinge')) {
    res <- xgboost:::convert.labels(y, objective_name = objective)
    expect_s3_class(res, 'factor')
    expect_equal(res, factor(y))
  }

  # Multi-class and ranking objectives: same contract as above.
  y <- c(0, 1, 3, 2, 1, 4)
  for (objective in c('multi:softmax', 'multi:softprob', 'rank:pairwise', 'rank:ndcg',
                      'rank:map')) {
    res <- xgboost:::convert.labels(y, objective_name = objective)
    expect_s3_class(res, 'factor')
    expect_equal(res, factor(y))
  }

  # Regression-like objectives: continuous labels must pass through untouched.
  y <- c(1.2, 3.0, -1.0, 10.0)
  for (objective in c('reg:squarederror', 'reg:squaredlogerror', 'reg:logistic',
                      'reg:pseudohubererror', 'count:poisson', 'survival:cox', 'survival:aft',
                      'reg:gamma', 'reg:tweedie')) {
    res <- xgboost:::convert.labels(y, objective_name = objective)
    expect_equal(class(res), 'numeric')
    expect_equal(res, y)
  }
})