From 4e9fad74eb4df6a04758a6b919826388c00530f3 Mon Sep 17 00:00:00 2001 From: Philip Hyunsu Cho Date: Wed, 3 Jul 2019 01:32:40 -0700 Subject: [PATCH] [R] Use built-in label when xgb.DMatrix is given to xgb.cv() (#4631) * Use built-in label when xgb.DMatrix is given to xgb.cv() * Add a test * Fix test * Bump version number --- R-package/DESCRIPTION | 4 ++-- R-package/R/xgb.cv.R | 11 +++++++++-- R-package/tests/testthat/test_basic.R | 14 ++++++++++++++ R-package/tests/testthat/test_callbacks.R | 3 ++- 4 files changed, 27 insertions(+), 5 deletions(-) diff --git a/R-package/DESCRIPTION b/R-package/DESCRIPTION index b3c340b02..aec055704 100644 --- a/R-package/DESCRIPTION +++ b/R-package/DESCRIPTION @@ -1,8 +1,8 @@ Package: xgboost Type: Package Title: Extreme Gradient Boosting -Version: 0.82.0.1 -Date: 2019-03-11 +Version: 0.90.0.1 +Date: 2019-05-18 Authors@R: c( person("Tianqi", "Chen", role = c("aut"), email = "tianqi.tchen@gmail.com"), diff --git a/R-package/R/xgb.cv.R b/R-package/R/xgb.cv.R index 64f5259f8..eaf341ed1 100644 --- a/R-package/R/xgb.cv.R +++ b/R-package/R/xgb.cv.R @@ -133,8 +133,15 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing = # Check the labels if ( (inherits(data, 'xgb.DMatrix') && is.null(getinfo(data, 'label'))) || - (!inherits(data, 'xgb.DMatrix') && is.null(label))) + (!inherits(data, 'xgb.DMatrix') && is.null(label))) { stop("Labels must be provided for CV either through xgb.DMatrix, or through 'label=' when 'data' is matrix") + } else if (inherits(data, 'xgb.DMatrix')) { + if (!is.null(label)) + warning("xgb.cv: label will be ignored, since data is of type xgb.DMatrix") + cv_label = getinfo(data, 'label') + } else { + cv_label = label + } # CV folds if(!is.null(folds)) { @@ -144,7 +151,7 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing = } else { if (nfold <= 1) stop("'nfold' must be > 1") - folds <- generate.cv.folds(nfold, nrow(data), stratified, label, params) + folds <- generate.cv.folds(nfold, nrow(data), stratified, cv_label, params) } # Potential TODO: sequential CV diff --git a/R-package/tests/testthat/test_basic.R b/R-package/tests/testthat/test_basic.R index e5eee48d5..36c148a99 100644 --- a/R-package/tests/testthat/test_basic.R +++ b/R-package/tests/testthat/test_basic.R @@ -191,6 +191,20 @@ test_that("xgb.cv works", { expect_false(is.null(cv$call)) }) +test_that("xgb.cv works with stratified folds", { + dtrain <- xgb.DMatrix(train$data, label = train$label) + set.seed(314159) + cv <- xgb.cv(data = dtrain, max_depth = 2, nfold = 5, + eta = 1., nthread = 2, nrounds = 2, objective = "binary:logistic", + verbose=TRUE, stratified = FALSE) + set.seed(314159) + cv2 <- xgb.cv(data = dtrain, max_depth = 2, nfold = 5, + eta = 1., nthread = 2, nrounds = 2, objective = "binary:logistic", + verbose=TRUE, stratified = TRUE) + # Stratified folds should result in a different evaluation logs + expect_true(all(cv$evaluation_log[, test_error_mean] != cv2$evaluation_log[, test_error_mean])) +}) + test_that("train and predict with non-strict classes", { # standard dense matrix input train_dense <- as.matrix(train$data) diff --git a/R-package/tests/testthat/test_callbacks.R b/R-package/tests/testthat/test_callbacks.R index a8d726b13..11275d80a 100644 --- a/R-package/tests/testthat/test_callbacks.R +++ b/R-package/tests/testthat/test_callbacks.R @@ -285,7 +285,8 @@ test_that("prediction in early-stopping xgb.cv works", { set.seed(11) expect_output( cv <- xgb.cv(param, dtrain, nfold = 5, eta = 0.1, nrounds = 20, - early_stopping_rounds = 5, maximize = FALSE, prediction = TRUE) + early_stopping_rounds = 5, maximize = FALSE, stratified = FALSE, + prediction = TRUE) , "Stopping. Best iteration") expect_false(is.null(cv$best_iteration))