[R] Use built-in label when xgb.DMatrix is given to xgb.cv() (#4631)
* Use built-in label when xgb.DMatrix is given to xgb.cv() * Add a test * Fix test * Bump version number
This commit is contained in:
committed by
GitHub
parent
986fee6022
commit
4e9fad74eb
@@ -133,8 +133,15 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
|
||||
|
||||
# Check the labels
|
||||
if ( (inherits(data, 'xgb.DMatrix') && is.null(getinfo(data, 'label'))) ||
|
||||
(!inherits(data, 'xgb.DMatrix') && is.null(label)))
|
||||
(!inherits(data, 'xgb.DMatrix') && is.null(label))) {
|
||||
stop("Labels must be provided for CV either through xgb.DMatrix, or through 'label=' when 'data' is matrix")
|
||||
} else if (inherits(data, 'xgb.DMatrix')) {
|
||||
if (!is.null(label))
|
||||
warning("xgb.cv: label will be ignored, since data is of type xgb.DMatrix")
|
||||
cv_label = getinfo(data, 'label')
|
||||
} else {
|
||||
cv_label = label
|
||||
}
|
||||
|
||||
# CV folds
|
||||
if(!is.null(folds)) {
|
||||
@@ -144,7 +151,7 @@ xgb.cv <- function(params=list(), data, nrounds, nfold, label = NULL, missing =
|
||||
} else {
|
||||
if (nfold <= 1)
|
||||
stop("'nfold' must be > 1")
|
||||
folds <- generate.cv.folds(nfold, nrow(data), stratified, label, params)
|
||||
folds <- generate.cv.folds(nfold, nrow(data), stratified, cv_label, params)
|
||||
}
|
||||
|
||||
# Potential TODO: sequential CV
|
||||
|
||||
Reference in New Issue
Block a user