[R] Use built-in label when xgb.DMatrix is given to xgb.cv() (#4631)

* Use built-in label when xgb.DMatrix is given to xgb.cv()

* Add a test

* Fix test

* Bump version number
This commit is contained in:
Philip Hyunsu Cho 2019-07-03 01:32:40 -07:00 committed by GitHub
parent 986fee6022
commit 4e9fad74eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 27 additions and 5 deletions

View File

@ -1,8 +1,8 @@
Package: xgboost Package: xgboost
Type: Package Type: Package
Title: Extreme Gradient Boosting Title: Extreme Gradient Boosting
Version: 0.82.0.1 Version: 0.90.0.1
Date: 2019-03-11 Date: 2019-05-18
Authors@R: c( Authors@R: c(
person("Tianqi", "Chen", role = c("aut"), person("Tianqi", "Chen", role = c("aut"),
email = "tianqi.tchen@gmail.com"), email = "tianqi.tchen@gmail.com"),

View File

@ -133,8 +133,15 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
# Check the labels # Check the labels
if ( (inherits(data, 'xgb.DMatrix') && is.null(getinfo(data, 'label'))) || if ( (inherits(data, 'xgb.DMatrix') && is.null(getinfo(data, 'label'))) ||
(!inherits(data, 'xgb.DMatrix') && is.null(label))) (!inherits(data, 'xgb.DMatrix') && is.null(label))) {
stop("Labels must be provided for CV either through xgb.DMatrix, or through 'label=' when 'data' is matrix") stop("Labels must be provided for CV either through xgb.DMatrix, or through 'label=' when 'data' is matrix")
} else if (inherits(data, 'xgb.DMatrix')) {
if (!is.null(label))
warning("xgb.cv: label will be ignored, since data is of type xgb.DMatrix")
cv_label = getinfo(data, 'label')
} else {
cv_label = label
}
# CV folds # CV folds
if(!is.null(folds)) { if(!is.null(folds)) {
@ -144,7 +151,7 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
} else { } else {
if (nfold <= 1) if (nfold <= 1)
stop("'nfold' must be > 1") stop("'nfold' must be > 1")
folds <- generate.cv.folds(nfold, nrow(data), stratified, label, params) folds <- generate.cv.folds(nfold, nrow(data), stratified, cv_label, params)
} }
# Potential TODO: sequential CV # Potential TODO: sequential CV

View File

@ -191,6 +191,20 @@ test_that("xgb.cv works", {
expect_false(is.null(cv$call)) expect_false(is.null(cv$call))
}) })
test_that("xgb.cv works with stratified folds", {
dtrain <- xgb.DMatrix(train$data, label = train$label)
set.seed(314159)
cv <- xgb.cv(data = dtrain, max_depth = 2, nfold = 5,
eta = 1., nthread = 2, nrounds = 2, objective = "binary:logistic",
verbose=TRUE, stratified = FALSE)
set.seed(314159)
cv2 <- xgb.cv(data = dtrain, max_depth = 2, nfold = 5,
eta = 1., nthread = 2, nrounds = 2, objective = "binary:logistic",
verbose=TRUE, stratified = TRUE)
# Stratified folds should result in a different evaluation logs
expect_true(all(cv$evaluation_log[, test_error_mean] != cv2$evaluation_log[, test_error_mean]))
})
test_that("train and predict with non-strict classes", { test_that("train and predict with non-strict classes", {
# standard dense matrix input # standard dense matrix input
train_dense <- as.matrix(train$data) train_dense <- as.matrix(train$data)

View File

@ -285,7 +285,8 @@ test_that("prediction in early-stopping xgb.cv works", {
set.seed(11) set.seed(11)
expect_output( expect_output(
cv <- xgb.cv(param, dtrain, nfold = 5, eta = 0.1, nrounds = 20, cv <- xgb.cv(param, dtrain, nfold = 5, eta = 0.1, nrounds = 20,
early_stopping_rounds = 5, maximize = FALSE, prediction = TRUE) early_stopping_rounds = 5, maximize = FALSE, stratified = FALSE,
prediction = TRUE)
, "Stopping. Best iteration") , "Stopping. Best iteration")
expect_false(is.null(cv$best_iteration)) expect_false(is.null(cv$best_iteration))