[R] Do not convert continuous labels to factors (#6380)
* [R] Do not convert continuous labels to factors * Address reviewer's comment
This commit is contained in:
committed by
GitHub
parent
3cca1c5fa1
commit
e426b6e040
@@ -410,3 +410,26 @@ test_that("check.deprecation works", {
|
||||
, "\'dumm\' was partially matched to \'dummy\'")
|
||||
expect_equal(res, list(a = 1, DUMMY = 22))
|
||||
})
|
||||
|
||||
test_that('convert.labels works', {
|
||||
y <- c(0, 1, 0, 0, 1)
|
||||
for (objective in c('binary:logistic', 'binary:logitraw', 'binary:hinge')) {
|
||||
res <- xgboost:::convert.labels(y, objective_name = objective)
|
||||
expect_s3_class(res, 'factor')
|
||||
expect_equal(res, factor(res))
|
||||
}
|
||||
y <- c(0, 1, 3, 2, 1, 4)
|
||||
for (objective in c('multi:softmax', 'multi:softprob', 'rank:pairwise', 'rank:ndcg',
|
||||
'rank:map')) {
|
||||
res <- xgboost:::convert.labels(y, objective_name = objective)
|
||||
expect_s3_class(res, 'factor')
|
||||
expect_equal(res, factor(res))
|
||||
}
|
||||
y <- c(1.2, 3.0, -1.0, 10.0)
|
||||
for (objective in c('reg:squarederror', 'reg:squaredlogerror', 'reg:logistic',
|
||||
'reg:pseudohubererror', 'count:poisson', 'survival:cox', 'survival:aft',
|
||||
'reg:gamma', 'reg:tweedie')) {
|
||||
res <- xgboost:::convert.labels(y, objective_name = objective)
|
||||
expect_equal(class(res), 'numeric')
|
||||
}
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user