[R] Use built-in label when xgb.DMatrix is given to xgb.cv() (#4631)
* Use built-in label when xgb.DMatrix is given to xgb.cv() * Add a test * Fix test * Bump version number
This commit is contained in:
parent
986fee6022
commit
4e9fad74eb
@ -1,8 +1,8 @@
|
|||||||
Package: xgboost
|
Package: xgboost
|
||||||
Type: Package
|
Type: Package
|
||||||
Title: Extreme Gradient Boosting
|
Title: Extreme Gradient Boosting
|
||||||
Version: 0.82.0.1
|
Version: 0.90.0.1
|
||||||
Date: 2019-03-11
|
Date: 2019-05-18
|
||||||
Authors@R: c(
|
Authors@R: c(
|
||||||
person("Tianqi", "Chen", role = c("aut"),
|
person("Tianqi", "Chen", role = c("aut"),
|
||||||
email = "tianqi.tchen@gmail.com"),
|
email = "tianqi.tchen@gmail.com"),
|
||||||
|
|||||||
@ -133,8 +133,15 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
|
|||||||
|
|
||||||
# Check the labels
|
# Check the labels
|
||||||
if ( (inherits(data, 'xgb.DMatrix') && is.null(getinfo(data, 'label'))) ||
|
if ( (inherits(data, 'xgb.DMatrix') && is.null(getinfo(data, 'label'))) ||
|
||||||
(!inherits(data, 'xgb.DMatrix') && is.null(label)))
|
(!inherits(data, 'xgb.DMatrix') && is.null(label))) {
|
||||||
stop("Labels must be provided for CV either through xgb.DMatrix, or through 'label=' when 'data' is matrix")
|
stop("Labels must be provided for CV either through xgb.DMatrix, or through 'label=' when 'data' is matrix")
|
||||||
|
} else if (inherits(data, 'xgb.DMatrix')) {
|
||||||
|
if (!is.null(label))
|
||||||
|
warning("xgb.cv: label will be ignored, since data is of type xgb.DMatrix")
|
||||||
|
cv_label = getinfo(data, 'label')
|
||||||
|
} else {
|
||||||
|
cv_label = label
|
||||||
|
}
|
||||||
|
|
||||||
# CV folds
|
# CV folds
|
||||||
if(!is.null(folds)) {
|
if(!is.null(folds)) {
|
||||||
@ -144,7 +151,7 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
|
|||||||
} else {
|
} else {
|
||||||
if (nfold <= 1)
|
if (nfold <= 1)
|
||||||
stop("'nfold' must be > 1")
|
stop("'nfold' must be > 1")
|
||||||
folds <- generate.cv.folds(nfold, nrow(data), stratified, label, params)
|
folds <- generate.cv.folds(nfold, nrow(data), stratified, cv_label, params)
|
||||||
}
|
}
|
||||||
|
|
||||||
# Potential TODO: sequential CV
|
# Potential TODO: sequential CV
|
||||||
|
|||||||
@ -191,6 +191,20 @@ test_that("xgb.cv works", {
|
|||||||
expect_false(is.null(cv$call))
|
expect_false(is.null(cv$call))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
test_that("xgb.cv works with stratified folds", {
|
||||||
|
dtrain <- xgb.DMatrix(train$data, label = train$label)
|
||||||
|
set.seed(314159)
|
||||||
|
cv <- xgb.cv(data = dtrain, max_depth = 2, nfold = 5,
|
||||||
|
eta = 1., nthread = 2, nrounds = 2, objective = "binary:logistic",
|
||||||
|
verbose=TRUE, stratified = FALSE)
|
||||||
|
set.seed(314159)
|
||||||
|
cv2 <- xgb.cv(data = dtrain, max_depth = 2, nfold = 5,
|
||||||
|
eta = 1., nthread = 2, nrounds = 2, objective = "binary:logistic",
|
||||||
|
verbose=TRUE, stratified = TRUE)
|
||||||
|
# Stratified folds should result in a different evaluation logs
|
||||||
|
expect_true(all(cv$evaluation_log[, test_error_mean] != cv2$evaluation_log[, test_error_mean]))
|
||||||
|
})
|
||||||
|
|
||||||
test_that("train and predict with non-strict classes", {
|
test_that("train and predict with non-strict classes", {
|
||||||
# standard dense matrix input
|
# standard dense matrix input
|
||||||
train_dense <- as.matrix(train$data)
|
train_dense <- as.matrix(train$data)
|
||||||
|
|||||||
@ -285,7 +285,8 @@ test_that("prediction in early-stopping xgb.cv works", {
|
|||||||
set.seed(11)
|
set.seed(11)
|
||||||
expect_output(
|
expect_output(
|
||||||
cv <- xgb.cv(param, dtrain, nfold = 5, eta = 0.1, nrounds = 20,
|
cv <- xgb.cv(param, dtrain, nfold = 5, eta = 0.1, nrounds = 20,
|
||||||
early_stopping_rounds = 5, maximize = FALSE, prediction = TRUE)
|
early_stopping_rounds = 5, maximize = FALSE, stratified = FALSE,
|
||||||
|
prediction = TRUE)
|
||||||
, "Stopping. Best iteration")
|
, "Stopping. Best iteration")
|
||||||
|
|
||||||
expect_false(is.null(cv$best_iteration))
|
expect_false(is.null(cv$best_iteration))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user