[R] Do not convert continuous labels to factors (#6380)
* [R] Do not convert continuous labels to factors * Address reviewer's comment
This commit is contained in:
parent
3cca1c5fa1
commit
e426b6e040
@ -20,6 +20,12 @@ NVL <- function(x, val) {
|
|||||||
stop("typeof(x) == ", typeof(x), " is not supported by NVL")
|
stop("typeof(x) == ", typeof(x), " is not supported by NVL")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# List of classification and ranking objectives
|
||||||
|
.CLASSIFICATION_OBJECTIVES <- function() {
|
||||||
|
return(c('binary:logistic', 'binary:logitraw', 'binary:hinge', 'multi:softmax',
|
||||||
|
'multi:softprob', 'rank:pairwise', 'rank:ndcg', 'rank:map'))
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# Low-level functions for boosting --------------------------------------------
|
# Low-level functions for boosting --------------------------------------------
|
||||||
@ -187,6 +193,17 @@ xgb.iter.eval <- function(booster_handle, watchlist, iter, feval = NULL) {
|
|||||||
# Helper functions for cross validation ---------------------------------------
|
# Helper functions for cross validation ---------------------------------------
|
||||||
#
|
#
|
||||||
|
|
||||||
|
# Possibly convert the labels into factors, depending on the objective.
|
||||||
|
# The labels are converted into factors only when the given objective refers to the classification
|
||||||
|
# or ranking tasks.
|
||||||
|
convert.labels <- function(labels, objective_name) {
|
||||||
|
if (objective_name %in% .CLASSIFICATION_OBJECTIVES()) {
|
||||||
|
return(as.factor(labels))
|
||||||
|
} else {
|
||||||
|
return(labels)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
# Generates random (stratified if needed) CV folds
|
# Generates random (stratified if needed) CV folds
|
||||||
generate.cv.folds <- function(nfold, nrows, stratified, label, params) {
|
generate.cv.folds <- function(nfold, nrows, stratified, label, params) {
|
||||||
|
|
||||||
@ -206,17 +223,15 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, params) {
|
|||||||
# and then do stratification by factor levels.
|
# and then do stratification by factor levels.
|
||||||
# - For regression, leave y numeric and do stratification by quantiles.
|
# - For regression, leave y numeric and do stratification by quantiles.
|
||||||
if (is.character(objective)) {
|
if (is.character(objective)) {
|
||||||
# If 'objective' provided in params, assume that y is a classification label
|
y <- convert.labels(y, params$objective)
|
||||||
# unless objective is reg:squarederror
|
|
||||||
if (params$objective != 'reg:squarederror')
|
|
||||||
y <- factor(y)
|
|
||||||
} else {
|
} else {
|
||||||
# If no 'objective' given in params, it means that user either wants to
|
# If no 'objective' given in params, it means that user either wants to
|
||||||
# use the default 'reg:squarederror' objective or has provided a custom
|
# use the default 'reg:squarederror' objective or has provided a custom
|
||||||
# obj function. Here, assume classification setting when y has 5 or less
|
# obj function. Here, assume classification setting when y has 5 or less
|
||||||
# unique values:
|
# unique values:
|
||||||
if (length(unique(y)) <= 5)
|
if (length(unique(y)) <= 5) {
|
||||||
y <- factor(y)
|
y <- factor(y)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
folds <- xgb.createFolds(y, nfold)
|
folds <- xgb.createFolds(y, nfold)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@ -410,3 +410,26 @@ test_that("check.deprecation works", {
|
|||||||
, "\'dumm\' was partially matched to \'dummy\'")
|
, "\'dumm\' was partially matched to \'dummy\'")
|
||||||
expect_equal(res, list(a = 1, DUMMY = 22))
|
expect_equal(res, list(a = 1, DUMMY = 22))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
test_that('convert.labels works', {
|
||||||
|
y <- c(0, 1, 0, 0, 1)
|
||||||
|
for (objective in c('binary:logistic', 'binary:logitraw', 'binary:hinge')) {
|
||||||
|
res <- xgboost:::convert.labels(y, objective_name = objective)
|
||||||
|
expect_s3_class(res, 'factor')
|
||||||
|
expect_equal(res, factor(res))
|
||||||
|
}
|
||||||
|
y <- c(0, 1, 3, 2, 1, 4)
|
||||||
|
for (objective in c('multi:softmax', 'multi:softprob', 'rank:pairwise', 'rank:ndcg',
|
||||||
|
'rank:map')) {
|
||||||
|
res <- xgboost:::convert.labels(y, objective_name = objective)
|
||||||
|
expect_s3_class(res, 'factor')
|
||||||
|
expect_equal(res, factor(res))
|
||||||
|
}
|
||||||
|
y <- c(1.2, 3.0, -1.0, 10.0)
|
||||||
|
for (objective in c('reg:squarederror', 'reg:squaredlogerror', 'reg:logistic',
|
||||||
|
'reg:pseudohubererror', 'count:poisson', 'survival:cox', 'survival:aft',
|
||||||
|
'reg:gamma', 'reg:tweedie')) {
|
||||||
|
res <- xgboost:::convert.labels(y, objective_name = objective)
|
||||||
|
expect_equal(class(res), 'numeric')
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user