diff --git a/R-package/R/utils.R b/R-package/R/utils.R
index a677fb197..c468c331b 100644
--- a/R-package/R/utils.R
+++ b/R-package/R/utils.R
@@ -20,6 +20,12 @@ NVL <- function(x, val) {
   stop("typeof(x) == ", typeof(x), " is not supported by NVL")
 }
 
+# Names of the objectives that are treated as classification or ranking tasks
+.CLASSIFICATION_OBJECTIVES <- function() {
+  return(c('binary:logistic', 'binary:logitraw', 'binary:hinge', 'multi:softmax',
+           'multi:softprob', 'rank:pairwise', 'rank:ndcg', 'rank:map'))
+}
+
 
 #
 # Low-level functions for boosting --------------------------------------------
@@ -187,6 +193,17 @@ xgb.iter.eval <- function(booster_handle, watchlist, iter, feval = NULL) {
 # Helper functions for cross validation ---------------------------------------
 #
 
+# Convert the labels into a factor when `objective_name` is one of .CLASSIFICATION_OBJECTIVES()
+# (a classification or ranking objective), so that stratification can be done by factor levels.
+# For any other objective the labels are returned unchanged.
+convert.labels <- function(labels, objective_name) {
+  if (objective_name %in% .CLASSIFICATION_OBJECTIVES()) {
+    return(as.factor(labels))
+  } else {
+    return(labels)
+  }
+}
+
 # Generates random (stratified if needed) CV folds
 generate.cv.folds <- function(nfold, nrows, stratified, label, params) {
 
@@ -206,17 +223,15 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, params) {
     #    and then do stratification by factor levels.
     #  - For regression, leave y numeric and do stratification by quantiles.
     if (is.character(objective)) {
-      # If 'objective' provided in params, assume that y is a classification label
-      # unless objective is reg:squarederror
-      if (params$objective != 'reg:squarederror')
-        y <- factor(y)
+      y <- convert.labels(y, params$objective)
     } else {
       # If no 'objective' given in params, it means that user either wants to
       # use the default 'reg:squarederror' objective or has provided a custom
       # obj function.  Here, assume classification setting when y has 5 or less
       # unique values:
-      if (length(unique(y)) <= 5)
+      if (length(unique(y)) <= 5) {
         y <- factor(y)
+      }
     }
     folds <- xgb.createFolds(y, nfold)
   } else {
diff --git a/R-package/tests/testthat/test_helpers.R b/R-package/tests/testthat/test_helpers.R
index b0a85a9fe..5638f70cb 100644
--- a/R-package/tests/testthat/test_helpers.R
+++ b/R-package/tests/testthat/test_helpers.R
@@ -410,3 +410,26 @@ test_that("check.deprecation works", {
   , "\'dumm\' was partially matched to \'dummy\'")
   expect_equal(res, list(a = 1, DUMMY = 22))
 })
+
+test_that('convert.labels works', {
+  y <- c(0, 1, 0, 0, 1)
+  for (objective in c('binary:logistic', 'binary:logitraw', 'binary:hinge')) {
+    res <- xgboost:::convert.labels(y, objective_name = objective)
+    expect_s3_class(res, 'factor')
+    expect_equal(res, factor(y))
+  }
+  y <- c(0, 1, 3, 2, 1, 4)
+  for (objective in c('multi:softmax', 'multi:softprob', 'rank:pairwise', 'rank:ndcg',
+                      'rank:map')) {
+    res <- xgboost:::convert.labels(y, objective_name = objective)
+    expect_s3_class(res, 'factor')
+    expect_equal(res, factor(y))
+  }
+  y <- c(1.2, 3.0, -1.0, 10.0)
+  for (objective in c('reg:squarederror', 'reg:squaredlogerror', 'reg:logistic',
+                      'reg:pseudohubererror', 'count:poisson', 'survival:cox', 'survival:aft',
+                      'reg:gamma', 'reg:tweedie')) {
+    res <- xgboost:::convert.labels(y, objective_name = objective)
+    expect_identical(res, y)
+  }
+})