[R] Do not convert continuous labels to factors (#6380)
* [R] Do not convert continuous labels to factors * Address reviewer's comment
This commit is contained in:
parent
3cca1c5fa1
commit
e426b6e040
@ -20,6 +20,12 @@ NVL <- function(x, val) {
|
||||
stop("typeof(x) == ", typeof(x), " is not supported by NVL")
|
||||
}
|
||||
|
||||
# List of classification and ranking objectives
|
||||
.CLASSIFICATION_OBJECTIVES <- function() {
|
||||
return(c('binary:logistic', 'binary:logitraw', 'binary:hinge', 'multi:softmax',
|
||||
'multi:softprob', 'rank:pairwise', 'rank:ndcg', 'rank:map'))
|
||||
}
|
||||
|
||||
|
||||
#
|
||||
# Low-level functions for boosting --------------------------------------------
|
||||
@ -187,6 +193,17 @@ xgb.iter.eval <- function(booster_handle, watchlist, iter, feval = NULL) {
|
||||
# Helper functions for cross validation ---------------------------------------
|
||||
#
|
||||
|
||||
# Possibly convert the labels into factors, depending on the objective.
|
||||
# The labels are converted into factors only when the given objective refers to the classification
|
||||
# or ranking tasks.
|
||||
convert.labels <- function(labels, objective_name) {
|
||||
if (objective_name %in% .CLASSIFICATION_OBJECTIVES()) {
|
||||
return(as.factor(labels))
|
||||
} else {
|
||||
return(labels)
|
||||
}
|
||||
}
|
||||
|
||||
# Generates random (stratified if needed) CV folds
|
||||
generate.cv.folds <- function(nfold, nrows, stratified, label, params) {
|
||||
|
||||
@ -206,17 +223,15 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, params) {
|
||||
# and then do stratification by factor levels.
|
||||
# - For regression, leave y numeric and do stratification by quantiles.
|
||||
if (is.character(objective)) {
|
||||
# If 'objective' provided in params, assume that y is a classification label
|
||||
# unless objective is reg:squarederror
|
||||
if (params$objective != 'reg:squarederror')
|
||||
y <- factor(y)
|
||||
y <- convert.labels(y, params$objective)
|
||||
} else {
|
||||
# If no 'objective' given in params, it means that user either wants to
|
||||
# use the default 'reg:squarederror' objective or has provided a custom
|
||||
# obj function. Here, assume classification setting when y has 5 or less
|
||||
# unique values:
|
||||
if (length(unique(y)) <= 5)
|
||||
if (length(unique(y)) <= 5) {
|
||||
y <- factor(y)
|
||||
}
|
||||
}
|
||||
folds <- xgb.createFolds(y, nfold)
|
||||
} else {
|
||||
|
||||
@ -410,3 +410,26 @@ test_that("check.deprecation works", {
|
||||
, "\'dumm\' was partially matched to \'dummy\'")
|
||||
expect_equal(res, list(a = 1, DUMMY = 22))
|
||||
})
|
||||
|
||||
test_that('convert.labels works', {
|
||||
y <- c(0, 1, 0, 0, 1)
|
||||
for (objective in c('binary:logistic', 'binary:logitraw', 'binary:hinge')) {
|
||||
res <- xgboost:::convert.labels(y, objective_name = objective)
|
||||
expect_s3_class(res, 'factor')
|
||||
expect_equal(res, factor(res))
|
||||
}
|
||||
y <- c(0, 1, 3, 2, 1, 4)
|
||||
for (objective in c('multi:softmax', 'multi:softprob', 'rank:pairwise', 'rank:ndcg',
|
||||
'rank:map')) {
|
||||
res <- xgboost:::convert.labels(y, objective_name = objective)
|
||||
expect_s3_class(res, 'factor')
|
||||
expect_equal(res, factor(res))
|
||||
}
|
||||
y <- c(1.2, 3.0, -1.0, 10.0)
|
||||
for (objective in c('reg:squarederror', 'reg:squaredlogerror', 'reg:logistic',
|
||||
'reg:pseudohubererror', 'count:poisson', 'survival:cox', 'survival:aft',
|
||||
'reg:gamma', 'reg:tweedie')) {
|
||||
res <- xgboost:::convert.labels(y, objective_name = objective)
|
||||
expect_equal(class(res), 'numeric')
|
||||
}
|
||||
})
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user