[R] Make xgb.cv work with xgb.DMatrix only, adding support for survival and ranking fields (#10031)
--------- Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
parent
8bad677c2f
commit
bc9ea62ec0
@ -26,6 +26,11 @@ NVL <- function(x, val) {
|
|||||||
'multi:softprob', 'rank:pairwise', 'rank:ndcg', 'rank:map'))
|
'multi:softprob', 'rank:pairwise', 'rank:ndcg', 'rank:map'))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Names of the built-in learning-to-rank objectives.
#
# These are the 'rank:' entries that also appear in the list returned by
# .CLASSIFICATION_OBJECTIVES(). Callers subtract this set from the
# classification set to decide whether stratified fold sampling applies,
# so this list must contain only ranking objectives - returning the
# classification names here would disable stratification for every
# classifier and enable it for ranking.
.RANKING_OBJECTIVES <- function() {
  return(c('rank:pairwise', 'rank:ndcg', 'rank:map'))
}
|
||||||
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# Low-level functions for boosting --------------------------------------------
|
# Low-level functions for boosting --------------------------------------------
|
||||||
@ -235,33 +240,43 @@ convert.labels <- function(labels, objective_name) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Generates random (stratified if needed) CV folds
|
# Generates random (stratified if needed) CV folds
|
||||||
generate.cv.folds <- function(nfold, nrows, stratified, label, params) {
|
generate.cv.folds <- function(nfold, nrows, stratified, label, group, params) {
|
||||||
|
if (NROW(group)) {
|
||||||
|
if (stratified) {
|
||||||
|
warning(
|
||||||
|
paste0(
|
||||||
|
"Stratified splitting is not supported when using 'group' attribute.",
|
||||||
|
" Will use unstratified splitting."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return(generate.group.folds(nfold, group))
|
||||||
|
}
|
||||||
|
objective <- params$objective
|
||||||
|
if (!is.character(objective)) {
|
||||||
|
warning("Will use unstratified splitting (custom objective used)")
|
||||||
|
stratified <- FALSE
|
||||||
|
}
|
||||||
|
# cannot stratify if label is NULL
|
||||||
|
if (stratified && is.null(label)) {
|
||||||
|
warning("Will use unstratified splitting (no 'labels' available)")
|
||||||
|
stratified <- FALSE
|
||||||
|
}
|
||||||
|
|
||||||
# cannot do it for rank
|
# cannot do it for rank
|
||||||
objective <- params$objective
|
|
||||||
if (is.character(objective) && strtrim(objective, 5) == 'rank:') {
|
if (is.character(objective) && strtrim(objective, 5) == 'rank:') {
|
||||||
stop("\n\tAutomatic generation of CV-folds is not implemented for ranking!\n",
|
stop("\n\tAutomatic generation of CV-folds is not implemented for ranking without 'group' field!\n",
|
||||||
"\tConsider providing pre-computed CV-folds through the 'folds=' parameter.\n")
|
"\tConsider providing pre-computed CV-folds through the 'folds=' parameter.\n")
|
||||||
}
|
}
|
||||||
# shuffle
|
# shuffle
|
||||||
rnd_idx <- sample.int(nrows)
|
rnd_idx <- sample.int(nrows)
|
||||||
if (stratified &&
|
if (stratified && length(label) == length(rnd_idx)) {
|
||||||
length(label) == length(rnd_idx)) {
|
|
||||||
y <- label[rnd_idx]
|
y <- label[rnd_idx]
|
||||||
# WARNING: some heuristic logic is employed to identify classification setting!
|
|
||||||
# - For classification, need to convert y labels to factor before making the folds,
|
# - For classification, need to convert y labels to factor before making the folds,
|
||||||
# and then do stratification by factor levels.
|
# and then do stratification by factor levels.
|
||||||
# - For regression, leave y numeric and do stratification by quantiles.
|
# - For regression, leave y numeric and do stratification by quantiles.
|
||||||
if (is.character(objective)) {
|
if (is.character(objective)) {
|
||||||
y <- convert.labels(y, params$objective)
|
y <- convert.labels(y, objective)
|
||||||
} else {
|
|
||||||
# If no 'objective' given in params, it means that user either wants to
|
|
||||||
# use the default 'reg:squarederror' objective or has provided a custom
|
|
||||||
# obj function. Here, assume classification setting when y has 5 or less
|
|
||||||
# unique values:
|
|
||||||
if (length(unique(y)) <= 5) {
|
|
||||||
y <- factor(y)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
folds <- xgb.createFolds(y = y, k = nfold)
|
folds <- xgb.createFolds(y = y, k = nfold)
|
||||||
} else {
|
} else {
|
||||||
@ -277,6 +292,29 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, params) {
|
|||||||
return(folds)
|
return(folds)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Generates random CV folds made of whole groups (queries).
#
# Arguments:
#   nfold  Number of folds to create.
#   group  Vector of group boundaries: group 'i' covers the (1-based) rows
#          'group[i] + 1' through 'group[i + 1]', so there are
#          'length(group) - 1' groups.
#
# Returns a list of length 'nfold'. Each element holds the row indices of the
# test fold and carries two attributes:
#   group_test   sizes of the groups in the test fold, in the order in which
#                their rows appear in the element itself.
#   group_train  sizes of the groups in the complementary training set, in the
#                order in which their rows appear when the remaining folds are
#                concatenated (i.e. 'unlist(folds[-k])').
#
# Stops with an error if there are fewer groups than folds.
generate.group.folds <- function(nfold, group) {
  ngroups <- length(group) - 1
  if (ngroups < nfold) {
    stop("DMatrix has fewer groups than folds.")
  }
  seq_groups <- seq_len(ngroups)
  # Row indices belonging to each group
  indices <- lapply(seq_groups, function(gr) seq(group[gr] + 1, group[gr + 1]))
  # Round-robin split of positions 1..ngroups into 'nfold' buckets
  assignments <- base::split(seq_groups, as.integer(seq_groups %% nfold))
  assignments <- unname(assignments)

  # Shuffle the groups, then map every bucket of positions to actual group ids
  randomized_groups <- sample(ngroups)
  fold_groups <- lapply(assignments, function(pos) randomized_groups[pos])

  out <- vector("list", nfold)
  for (idx in seq_len(nfold)) {
    groups_test <- indices[fold_groups[[idx]]]
    idx_test <- unlist(groups_test)
    attributes(idx_test)$group_test <- lengths(groups_test)
    # The training rows are the concatenation of all other folds, so the group
    # sizes must follow that same fold-by-fold order. Listing them in original
    # group order would attach wrong sizes whenever groups differ in length.
    groups_idx_train <- unlist(fold_groups[-idx])
    attributes(idx_test)$group_train <- lengths(indices[groups_idx_train])
    out[[idx]] <- idx_test
  }
  return(out)
}
|
||||||
|
|
||||||
# Creates CV folds stratified by the values of y.
|
# Creates CV folds stratified by the values of y.
|
||||||
# It was borrowed from caret::createFolds and simplified
|
# It was borrowed from caret::createFolds and simplified
|
||||||
# by always returning an unnamed list of fold indices.
|
# by always returning an unnamed list of fold indices.
|
||||||
|
|||||||
@ -1259,8 +1259,11 @@ xgb.get.DMatrix.data <- function(dmat) {
|
|||||||
#' Get a new DMatrix containing the specified rows of
|
#' Get a new DMatrix containing the specified rows of
|
||||||
#' original xgb.DMatrix object
|
#' original xgb.DMatrix object
|
||||||
#'
|
#'
|
||||||
#' @param object Object of class "xgb.DMatrix"
|
#' @param object Object of class "xgb.DMatrix".
|
||||||
#' @param idxset a integer vector of indices of rows needed
|
#' @param idxset An integer vector of indices of rows needed (base-1 indexing).
|
||||||
|
#' @param allow_groups Whether to allow slicing an `xgb.DMatrix` with `group` (or
|
||||||
|
#' equivalently `qid`) field. Note that in such case, the result will not have
|
||||||
|
#' the groups anymore - they need to be set manually through `setinfo`.
|
||||||
#' @param colset currently not used (columns subsetting is not available)
|
#' @param colset currently not used (columns subsetting is not available)
|
||||||
#'
|
#'
|
||||||
#' @examples
|
#' @examples
|
||||||
@ -1275,11 +1278,11 @@ xgb.get.DMatrix.data <- function(dmat) {
|
|||||||
#'
|
#'
|
||||||
#' @rdname xgb.slice.DMatrix
|
#' @rdname xgb.slice.DMatrix
|
||||||
#' @export
|
#' @export
|
||||||
xgb.slice.DMatrix <- function(object, idxset) {
|
xgb.slice.DMatrix <- function(object, idxset, allow_groups = FALSE) {
|
||||||
if (!inherits(object, "xgb.DMatrix")) {
|
if (!inherits(object, "xgb.DMatrix")) {
|
||||||
stop("object must be xgb.DMatrix")
|
stop("object must be xgb.DMatrix")
|
||||||
}
|
}
|
||||||
ret <- .Call(XGDMatrixSliceDMatrix_R, object, idxset)
|
ret <- .Call(XGDMatrixSliceDMatrix_R, object, idxset, allow_groups)
|
||||||
|
|
||||||
attr_list <- attributes(object)
|
attr_list <- attributes(object)
|
||||||
nr <- nrow(object)
|
nr <- nrow(object)
|
||||||
@ -1296,7 +1299,15 @@ xgb.slice.DMatrix <- function(object, idxset) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return(structure(ret, class = "xgb.DMatrix"))
|
|
||||||
|
out <- structure(ret, class = "xgb.DMatrix")
|
||||||
|
parent_fields <- as.list(attributes(object)$fields)
|
||||||
|
if (NROW(parent_fields)) {
|
||||||
|
child_fields <- parent_fields[!(names(parent_fields) %in% c("group", "qid"))]
|
||||||
|
child_fields <- as.environment(child_fields)
|
||||||
|
attributes(out)$fields <- child_fields
|
||||||
|
}
|
||||||
|
return(out)
|
||||||
}
|
}
|
||||||
|
|
||||||
#' @rdname xgb.slice.DMatrix
|
#' @rdname xgb.slice.DMatrix
|
||||||
@ -1340,11 +1351,11 @@ print.xgb.DMatrix <- function(x, verbose = FALSE, ...) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
cat(class_print, ' dim:', nrow(x), 'x', ncol(x), ' info: ')
|
cat(class_print, ' dim:', nrow(x), 'x', ncol(x), ' info: ')
|
||||||
infos <- character(0)
|
infos <- names(attributes(x)$fields)
|
||||||
if (xgb.DMatrix.hasinfo(x, 'label')) infos <- 'label'
|
infos <- infos[infos != "feature_name"]
|
||||||
if (xgb.DMatrix.hasinfo(x, 'weight')) infos <- c(infos, 'weight')
|
if (!NROW(infos)) infos <- "NA"
|
||||||
if (xgb.DMatrix.hasinfo(x, 'base_margin')) infos <- c(infos, 'base_margin')
|
infos <- infos[order(infos)]
|
||||||
if (length(infos) == 0) infos <- 'NA'
|
infos <- paste(infos, collapse = ", ")
|
||||||
cat(infos)
|
cat(infos)
|
||||||
cnames <- colnames(x)
|
cnames <- colnames(x)
|
||||||
cat(' colnames:')
|
cat(' colnames:')
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
#' Cross Validation
|
#' Cross Validation
|
||||||
#'
|
#'
|
||||||
#' The cross validation function of xgboost
|
#' The cross validation function of xgboost.
|
||||||
#'
|
#'
|
||||||
#' @param params the list of parameters. The complete list of parameters is
|
#' @param params the list of parameters. The complete list of parameters is
|
||||||
#' available in the \href{http://xgboost.readthedocs.io/en/latest/parameter.html}{online documentation}. Below
|
#' available in the \href{http://xgboost.readthedocs.io/en/latest/parameter.html}{online documentation}. Below
|
||||||
@ -19,13 +19,17 @@
|
|||||||
#'
|
#'
|
||||||
#' See \code{\link{xgb.train}} for further details.
|
#' See \code{\link{xgb.train}} for further details.
|
||||||
#' See also demo/ for walkthrough example in R.
|
#' See also demo/ for walkthrough example in R.
|
||||||
#' @param data takes an \code{xgb.DMatrix}, \code{matrix}, or \code{dgCMatrix} as the input.
|
#'
|
||||||
|
#' Note that, while `params` accepts a `seed` entry and will use such parameter for model training if
|
||||||
|
#' supplied, this seed is not used for creation of train-test splits, which instead rely on R's own RNG
|
||||||
|
#' system - thus, for reproducible results, one needs to call the `set.seed` function beforehand.
|
||||||
|
#' @param data An `xgb.DMatrix` object, with corresponding fields like `label` or bounds as required
|
||||||
|
#' for model training by the objective.
|
||||||
|
#'
|
||||||
|
#' Note that only the basic `xgb.DMatrix` class is supported - variants such as `xgb.QuantileDMatrix`
|
||||||
|
#' or `xgb.ExternalDMatrix` are not supported here.
|
||||||
#' @param nrounds the max number of iterations
|
#' @param nrounds the max number of iterations
|
||||||
#' @param nfold the original dataset is randomly partitioned into \code{nfold} equal size subsamples.
|
#' @param nfold the original dataset is randomly partitioned into \code{nfold} equal size subsamples.
|
||||||
#' @param label vector of response values. Should be provided only when data is an R-matrix.
|
|
||||||
#' @param missing is only used when input is a dense matrix. By default is set to NA, which means
|
|
||||||
#' that NA values should be considered as 'missing' by the algorithm.
|
|
||||||
#' Sometimes, 0 or other extreme value might be used to represent missing values.
|
|
||||||
#' @param prediction A logical value indicating whether to return the test fold predictions
|
#' @param prediction A logical value indicating whether to return the test fold predictions
|
||||||
#' from each CV model. This parameter engages the \code{\link{xgb.cb.cv.predict}} callback.
|
#' from each CV model. This parameter engages the \code{\link{xgb.cb.cv.predict}} callback.
|
||||||
#' @param showsd \code{boolean}, whether to show standard deviation of cross validation
|
#' @param showsd \code{boolean}, whether to show standard deviation of cross validation
|
||||||
@ -47,13 +51,30 @@
|
|||||||
#' @param feval customized evaluation function. Returns
|
#' @param feval customized evaluation function. Returns
|
||||||
#' \code{list(metric='metric-name', value='metric-value')} with given
|
#' \code{list(metric='metric-name', value='metric-value')} with given
|
||||||
#' prediction and dtrain.
|
#' prediction and dtrain.
|
||||||
#' @param stratified a \code{boolean} indicating whether sampling of folds should be stratified
|
#' @param stratified A \code{boolean} indicating whether sampling of folds should be stratified
|
||||||
#' by the values of outcome labels.
|
#' by the values of outcome labels. For real-valued labels in regression objectives,
|
||||||
|
#' stratification will be done by discretizing the labels into up to 5 buckets beforehand.
|
||||||
|
#'
|
||||||
|
#' If passing "auto", will be set to `TRUE` if the objective in `params` is a classification
|
||||||
|
#' objective (from XGBoost's built-in objectives, doesn't apply to custom ones), and to
|
||||||
|
#' `FALSE` otherwise.
|
||||||
|
#'
|
||||||
|
#' This parameter is ignored when `data` has a `group` field - in such case, the splitting
|
||||||
|
#' will be based on whole groups (note that this might make the folds have different sizes).
|
||||||
|
#'
|
||||||
|
#' Value `TRUE` here is \bold{not} supported for custom objectives.
|
||||||
#' @param folds \code{list} provides a possibility to use a list of pre-defined CV folds
|
#' @param folds \code{list} provides a possibility to use a list of pre-defined CV folds
|
||||||
#' (each element must be a vector of test fold's indices). When folds are supplied,
|
#' (each element must be a vector of test fold's indices). When folds are supplied,
|
||||||
#' the \code{nfold} and \code{stratified} parameters are ignored.
|
#' the \code{nfold} and \code{stratified} parameters are ignored.
|
||||||
|
#'
|
||||||
|
#' If `data` has a `group` field and the objective requires this field, each fold (list element)
|
||||||
|
#' must additionally have two attributes (retrievable through \link{attributes}) named `group_test`
|
||||||
|
#' and `group_train`, which should hold the `group` to assign through \link{setinfo.xgb.DMatrix} to
|
||||||
|
#' the resulting DMatrices.
|
||||||
#' @param train_folds \code{list} list specifying which indices to use for training. If \code{NULL}
|
#' @param train_folds \code{list} list specifying which indices to use for training. If \code{NULL}
|
||||||
#' (the default) all indices not specified in \code{folds} will be used for training.
|
#' (the default) all indices not specified in \code{folds} will be used for training.
|
||||||
|
#'
|
||||||
|
#' This is not supported when `data` has `group` field.
|
||||||
#' @param verbose \code{boolean}, print the statistics during the process
|
#' @param verbose \code{boolean}, print the statistics during the process
|
||||||
#' @param print_every_n Print each n-th iteration evaluation messages when \code{verbose>0}.
|
#' @param print_every_n Print each n-th iteration evaluation messages when \code{verbose>0}.
|
||||||
#' Default is 1 which means all messages are printed. This parameter is passed to the
|
#' Default is 1 which means all messages are printed. This parameter is passed to the
|
||||||
@ -118,13 +139,14 @@
|
|||||||
#' print(cv, verbose=TRUE)
|
#' print(cv, verbose=TRUE)
|
||||||
#'
|
#'
|
||||||
#' @export
|
#' @export
|
||||||
xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing = NA,
|
xgb.cv <- function(params = list(), data, nrounds, nfold,
|
||||||
prediction = FALSE, showsd = TRUE, metrics = list(),
|
prediction = FALSE, showsd = TRUE, metrics = list(),
|
||||||
obj = NULL, feval = NULL, stratified = TRUE, folds = NULL, train_folds = NULL,
|
obj = NULL, feval = NULL, stratified = "auto", folds = NULL, train_folds = NULL,
|
||||||
verbose = TRUE, print_every_n = 1L,
|
verbose = TRUE, print_every_n = 1L,
|
||||||
early_stopping_rounds = NULL, maximize = NULL, callbacks = list(), ...) {
|
early_stopping_rounds = NULL, maximize = NULL, callbacks = list(), ...) {
|
||||||
|
|
||||||
check.deprecation(...)
|
check.deprecation(...)
|
||||||
|
stopifnot(inherits(data, "xgb.DMatrix"))
|
||||||
if (inherits(data, "xgb.DMatrix") && .Call(XGCheckNullPtr_R, data)) {
|
if (inherits(data, "xgb.DMatrix") && .Call(XGCheckNullPtr_R, data)) {
|
||||||
stop("'data' is an invalid 'xgb.DMatrix' object. Must be constructed again.")
|
stop("'data' is an invalid 'xgb.DMatrix' object. Must be constructed again.")
|
||||||
}
|
}
|
||||||
@ -137,16 +159,22 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
|
|||||||
check.custom.obj()
|
check.custom.obj()
|
||||||
check.custom.eval()
|
check.custom.eval()
|
||||||
|
|
||||||
# Check the labels
|
if (stratified == "auto") {
|
||||||
if ((inherits(data, 'xgb.DMatrix') && !xgb.DMatrix.hasinfo(data, 'label')) ||
|
if (is.character(params$objective)) {
|
||||||
(!inherits(data, 'xgb.DMatrix') && is.null(label))) {
|
stratified <- (
|
||||||
stop("Labels must be provided for CV either through xgb.DMatrix, or through 'label=' when 'data' is matrix")
|
(params$objective %in% .CLASSIFICATION_OBJECTIVES())
|
||||||
} else if (inherits(data, 'xgb.DMatrix')) {
|
&& !(params$objective %in% .RANKING_OBJECTIVES())
|
||||||
if (!is.null(label))
|
)
|
||||||
warning("xgb.cv: label will be ignored, since data is of type xgb.DMatrix")
|
} else {
|
||||||
cv_label <- getinfo(data, 'label')
|
stratified <- FALSE
|
||||||
} else {
|
}
|
||||||
cv_label <- label
|
}
|
||||||
|
|
||||||
|
# Check the labels and groups
|
||||||
|
cv_label <- getinfo(data, "label")
|
||||||
|
cv_group <- getinfo(data, "group")
|
||||||
|
if (!is.null(train_folds) && NROW(cv_group)) {
|
||||||
|
stop("'train_folds' is not supported for DMatrix object with 'group' field.")
|
||||||
}
|
}
|
||||||
|
|
||||||
# CV folds
|
# CV folds
|
||||||
@ -157,7 +185,7 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
|
|||||||
} else {
|
} else {
|
||||||
if (nfold <= 1)
|
if (nfold <= 1)
|
||||||
stop("'nfold' must be > 1")
|
stop("'nfold' must be > 1")
|
||||||
folds <- generate.cv.folds(nfold, nrow(data), stratified, cv_label, params)
|
folds <- generate.cv.folds(nfold, nrow(data), stratified, cv_label, cv_group, params)
|
||||||
}
|
}
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
@ -195,20 +223,18 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
|
|||||||
|
|
||||||
# create the booster-folds
|
# create the booster-folds
|
||||||
# train_folds
|
# train_folds
|
||||||
dall <- xgb.get.DMatrix(
|
dall <- data
|
||||||
data = data,
|
|
||||||
label = label,
|
|
||||||
missing = missing,
|
|
||||||
weight = NULL,
|
|
||||||
nthread = params$nthread
|
|
||||||
)
|
|
||||||
bst_folds <- lapply(seq_along(folds), function(k) {
|
bst_folds <- lapply(seq_along(folds), function(k) {
|
||||||
dtest <- xgb.slice.DMatrix(dall, folds[[k]])
|
dtest <- xgb.slice.DMatrix(dall, folds[[k]], allow_groups = TRUE)
|
||||||
# code originally contributed by @RolandASc on stackoverflow
|
# code originally contributed by @RolandASc on stackoverflow
|
||||||
if (is.null(train_folds))
|
if (is.null(train_folds))
|
||||||
dtrain <- xgb.slice.DMatrix(dall, unlist(folds[-k]))
|
dtrain <- xgb.slice.DMatrix(dall, unlist(folds[-k]), allow_groups = TRUE)
|
||||||
else
|
else
|
||||||
dtrain <- xgb.slice.DMatrix(dall, train_folds[[k]])
|
dtrain <- xgb.slice.DMatrix(dall, train_folds[[k]], allow_groups = TRUE)
|
||||||
|
if (!is.null(attributes(folds[[k]])$group_test)) {
|
||||||
|
setinfo(dtest, "group", attributes(folds[[k]])$group_test)
|
||||||
|
setinfo(dtrain, "group", attributes(folds[[k]])$group_train)
|
||||||
|
}
|
||||||
bst <- xgb.Booster(
|
bst <- xgb.Booster(
|
||||||
params = params,
|
params = params,
|
||||||
cachelist = list(dtrain, dtest),
|
cachelist = list(dtrain, dtest),
|
||||||
@ -312,8 +338,8 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
|
|||||||
#' @examples
|
#' @examples
|
||||||
#' data(agaricus.train, package='xgboost')
|
#' data(agaricus.train, package='xgboost')
|
||||||
#' train <- agaricus.train
|
#' train <- agaricus.train
|
||||||
#' cv <- xgb.cv(data = train$data, label = train$label, nfold = 5, max_depth = 2,
|
#' cv <- xgb.cv(data = xgb.DMatrix(train$data, label = train$label), nfold = 5, max_depth = 2,
|
||||||
#' eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic")
|
#' eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic")
|
||||||
#' print(cv)
|
#' print(cv)
|
||||||
#' print(cv, verbose=TRUE)
|
#' print(cv, verbose=TRUE)
|
||||||
#'
|
#'
|
||||||
|
|||||||
@ -23,8 +23,8 @@ including the best iteration (when available).
|
|||||||
\examples{
|
\examples{
|
||||||
data(agaricus.train, package='xgboost')
|
data(agaricus.train, package='xgboost')
|
||||||
train <- agaricus.train
|
train <- agaricus.train
|
||||||
cv <- xgb.cv(data = train$data, label = train$label, nfold = 5, max_depth = 2,
|
cv <- xgb.cv(data = xgb.DMatrix(train$data, label = train$label), nfold = 5, max_depth = 2,
|
||||||
eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic")
|
eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic")
|
||||||
print(cv)
|
print(cv)
|
||||||
print(cv, verbose=TRUE)
|
print(cv, verbose=TRUE)
|
||||||
|
|
||||||
|
|||||||
@ -9,14 +9,12 @@ xgb.cv(
|
|||||||
data,
|
data,
|
||||||
nrounds,
|
nrounds,
|
||||||
nfold,
|
nfold,
|
||||||
label = NULL,
|
|
||||||
missing = NA,
|
|
||||||
prediction = FALSE,
|
prediction = FALSE,
|
||||||
showsd = TRUE,
|
showsd = TRUE,
|
||||||
metrics = list(),
|
metrics = list(),
|
||||||
obj = NULL,
|
obj = NULL,
|
||||||
feval = NULL,
|
feval = NULL,
|
||||||
stratified = TRUE,
|
stratified = "auto",
|
||||||
folds = NULL,
|
folds = NULL,
|
||||||
train_folds = NULL,
|
train_folds = NULL,
|
||||||
verbose = TRUE,
|
verbose = TRUE,
|
||||||
@ -44,20 +42,23 @@ is a shorter summary:
|
|||||||
}
|
}
|
||||||
|
|
||||||
See \code{\link{xgb.train}} for further details.
|
See \code{\link{xgb.train}} for further details.
|
||||||
See also demo/ for walkthrough example in R.}
|
See also demo/ for walkthrough example in R.
|
||||||
|
|
||||||
\item{data}{takes an \code{xgb.DMatrix}, \code{matrix}, or \code{dgCMatrix} as the input.}
|
Note that, while \code{params} accepts a \code{seed} entry and will use such parameter for model training if
|
||||||
|
supplied, this seed is not used for creation of train-test splits, which instead rely on R's own RNG
|
||||||
|
system - thus, for reproducible results, one needs to call the \code{set.seed} function beforehand.}
|
||||||
|
|
||||||
|
\item{data}{An \code{xgb.DMatrix} object, with corresponding fields like \code{label} or bounds as required
|
||||||
|
for model training by the objective.
|
||||||
|
|
||||||
|
\if{html}{\out{<div class="sourceCode">}}\preformatted{ Note that only the basic `xgb.DMatrix` class is supported - variants such as `xgb.QuantileDMatrix`
|
||||||
|
or `xgb.ExternalDMatrix` are not supported here.
|
||||||
|
}\if{html}{\out{</div>}}}
|
||||||
|
|
||||||
\item{nrounds}{the max number of iterations}
|
\item{nrounds}{the max number of iterations}
|
||||||
|
|
||||||
\item{nfold}{the original dataset is randomly partitioned into \code{nfold} equal size subsamples.}
|
\item{nfold}{the original dataset is randomly partitioned into \code{nfold} equal size subsamples.}
|
||||||
|
|
||||||
\item{label}{vector of response values. Should be provided only when data is an R-matrix.}
|
|
||||||
|
|
||||||
\item{missing}{is only used when input is a dense matrix. By default is set to NA, which means
|
|
||||||
that NA values should be considered as 'missing' by the algorithm.
|
|
||||||
Sometimes, 0 or other extreme value might be used to represent missing values.}
|
|
||||||
|
|
||||||
\item{prediction}{A logical value indicating whether to return the test fold predictions
|
\item{prediction}{A logical value indicating whether to return the test fold predictions
|
||||||
from each CV model. This parameter engages the \code{\link{xgb.cb.cv.predict}} callback.}
|
from each CV model. This parameter engages the \code{\link{xgb.cb.cv.predict}} callback.}
|
||||||
|
|
||||||
@ -84,15 +85,35 @@ gradient with given prediction and dtrain.}
|
|||||||
\code{list(metric='metric-name', value='metric-value')} with given
|
\code{list(metric='metric-name', value='metric-value')} with given
|
||||||
prediction and dtrain.}
|
prediction and dtrain.}
|
||||||
|
|
||||||
\item{stratified}{a \code{boolean} indicating whether sampling of folds should be stratified
|
\item{stratified}{A \code{boolean} indicating whether sampling of folds should be stratified
|
||||||
by the values of outcome labels.}
|
by the values of outcome labels. For real-valued labels in regression objectives,
|
||||||
|
stratification will be done by discretizing the labels into up to 5 buckets beforehand.
|
||||||
|
|
||||||
|
\if{html}{\out{<div class="sourceCode">}}\preformatted{ If passing "auto", will be set to `TRUE` if the objective in `params` is a classification
|
||||||
|
objective (from XGBoost's built-in objectives, doesn't apply to custom ones), and to
|
||||||
|
`FALSE` otherwise.
|
||||||
|
|
||||||
|
This parameter is ignored when `data` has a `group` field - in such case, the splitting
|
||||||
|
will be based on whole groups (note that this might make the folds have different sizes).
|
||||||
|
|
||||||
|
Value `TRUE` here is \\bold\{not\} supported for custom objectives.
|
||||||
|
}\if{html}{\out{</div>}}}
|
||||||
|
|
||||||
\item{folds}{\code{list} provides a possibility to use a list of pre-defined CV folds
|
\item{folds}{\code{list} provides a possibility to use a list of pre-defined CV folds
|
||||||
(each element must be a vector of test fold's indices). When folds are supplied,
|
(each element must be a vector of test fold's indices). When folds are supplied,
|
||||||
the \code{nfold} and \code{stratified} parameters are ignored.}
|
the \code{nfold} and \code{stratified} parameters are ignored.
|
||||||
|
|
||||||
|
\if{html}{\out{<div class="sourceCode">}}\preformatted{ If `data` has a `group` field and the objective requires this field, each fold (list element)
|
||||||
|
must additionally have two attributes (retrievable through \link{attributes}) named `group_test`
|
||||||
|
and `group_train`, which should hold the `group` to assign through \link{setinfo.xgb.DMatrix} to
|
||||||
|
the resulting DMatrices.
|
||||||
|
}\if{html}{\out{</div>}}}
|
||||||
|
|
||||||
\item{train_folds}{\code{list} list specifying which indices to use for training. If \code{NULL}
|
\item{train_folds}{\code{list} list specifying which indices to use for training. If \code{NULL}
|
||||||
(the default) all indices not specified in \code{folds} will be used for training.}
|
(the default) all indices not specified in \code{folds} will be used for training.
|
||||||
|
|
||||||
|
\if{html}{\out{<div class="sourceCode">}}\preformatted{ This is not supported when `data` has `group` field.
|
||||||
|
}\if{html}{\out{</div>}}}
|
||||||
|
|
||||||
\item{verbose}{\code{boolean}, print the statistics during the process}
|
\item{verbose}{\code{boolean}, print the statistics during the process}
|
||||||
|
|
||||||
@ -142,7 +163,7 @@ such as saving also the models created during cross validation); or a list \code
|
|||||||
will contain elements such as \code{best_iteration} when using the early stopping callback (\link{xgb.cb.early.stop}).
|
will contain elements such as \code{best_iteration} when using the early stopping callback (\link{xgb.cb.early.stop}).
|
||||||
}
|
}
|
||||||
\description{
|
\description{
|
||||||
The cross validation function of xgboost
|
The cross validation function of xgboost.
|
||||||
}
|
}
|
||||||
\details{
|
\details{
|
||||||
The original sample is randomly partitioned into \code{nfold} equal size subsamples.
|
The original sample is randomly partitioned into \code{nfold} equal size subsamples.
|
||||||
|
|||||||
@ -6,14 +6,18 @@
|
|||||||
\title{Get a new DMatrix containing the specified rows of
|
\title{Get a new DMatrix containing the specified rows of
|
||||||
original xgb.DMatrix object}
|
original xgb.DMatrix object}
|
||||||
\usage{
|
\usage{
|
||||||
xgb.slice.DMatrix(object, idxset)
|
xgb.slice.DMatrix(object, idxset, allow_groups = FALSE)
|
||||||
|
|
||||||
\method{[}{xgb.DMatrix}(object, idxset, colset = NULL)
|
\method{[}{xgb.DMatrix}(object, idxset, colset = NULL)
|
||||||
}
|
}
|
||||||
\arguments{
|
\arguments{
|
||||||
\item{object}{Object of class "xgb.DMatrix"}
|
\item{object}{Object of class "xgb.DMatrix".}
|
||||||
|
|
||||||
\item{idxset}{a integer vector of indices of rows needed}
|
\item{idxset}{An integer vector of indices of rows needed (base-1 indexing).}
|
||||||
|
|
||||||
|
\item{allow_groups}{Whether to allow slicing an \code{xgb.DMatrix} with \code{group} (or
|
||||||
|
equivalently \code{qid}) field. Note that in such case, the result will not have
|
||||||
|
the groups anymore - they need to be set manually through \code{setinfo}.}
|
||||||
|
|
||||||
\item{colset}{currently not used (columns subsetting is not available)}
|
\item{colset}{currently not used (columns subsetting is not available)}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -71,7 +71,7 @@ extern SEXP XGDMatrixGetDataAsCSR_R(SEXP);
|
|||||||
extern SEXP XGDMatrixSaveBinary_R(SEXP, SEXP, SEXP);
|
extern SEXP XGDMatrixSaveBinary_R(SEXP, SEXP, SEXP);
|
||||||
extern SEXP XGDMatrixSetInfo_R(SEXP, SEXP, SEXP);
|
extern SEXP XGDMatrixSetInfo_R(SEXP, SEXP, SEXP);
|
||||||
extern SEXP XGDMatrixSetStrFeatureInfo_R(SEXP, SEXP, SEXP);
|
extern SEXP XGDMatrixSetStrFeatureInfo_R(SEXP, SEXP, SEXP);
|
||||||
extern SEXP XGDMatrixSliceDMatrix_R(SEXP, SEXP);
|
extern SEXP XGDMatrixSliceDMatrix_R(SEXP, SEXP, SEXP);
|
||||||
extern SEXP XGBSetGlobalConfig_R(SEXP);
|
extern SEXP XGBSetGlobalConfig_R(SEXP);
|
||||||
extern SEXP XGBGetGlobalConfig_R(void);
|
extern SEXP XGBGetGlobalConfig_R(void);
|
||||||
extern SEXP XGBoosterFeatureScore_R(SEXP, SEXP);
|
extern SEXP XGBoosterFeatureScore_R(SEXP, SEXP);
|
||||||
@ -134,7 +134,7 @@ static const R_CallMethodDef CallEntries[] = {
|
|||||||
{"XGDMatrixSaveBinary_R", (DL_FUNC) &XGDMatrixSaveBinary_R, 3},
|
{"XGDMatrixSaveBinary_R", (DL_FUNC) &XGDMatrixSaveBinary_R, 3},
|
||||||
{"XGDMatrixSetInfo_R", (DL_FUNC) &XGDMatrixSetInfo_R, 3},
|
{"XGDMatrixSetInfo_R", (DL_FUNC) &XGDMatrixSetInfo_R, 3},
|
||||||
{"XGDMatrixSetStrFeatureInfo_R", (DL_FUNC) &XGDMatrixSetStrFeatureInfo_R, 3},
|
{"XGDMatrixSetStrFeatureInfo_R", (DL_FUNC) &XGDMatrixSetStrFeatureInfo_R, 3},
|
||||||
{"XGDMatrixSliceDMatrix_R", (DL_FUNC) &XGDMatrixSliceDMatrix_R, 2},
|
{"XGDMatrixSliceDMatrix_R", (DL_FUNC) &XGDMatrixSliceDMatrix_R, 3},
|
||||||
{"XGBSetGlobalConfig_R", (DL_FUNC) &XGBSetGlobalConfig_R, 1},
|
{"XGBSetGlobalConfig_R", (DL_FUNC) &XGBSetGlobalConfig_R, 1},
|
||||||
{"XGBGetGlobalConfig_R", (DL_FUNC) &XGBGetGlobalConfig_R, 0},
|
{"XGBGetGlobalConfig_R", (DL_FUNC) &XGBGetGlobalConfig_R, 0},
|
||||||
{"XGBoosterFeatureScore_R", (DL_FUNC) &XGBoosterFeatureScore_R, 2},
|
{"XGBoosterFeatureScore_R", (DL_FUNC) &XGBoosterFeatureScore_R, 2},
|
||||||
|
|||||||
@ -512,7 +512,7 @@ XGB_DLL SEXP XGDMatrixCreateFromCSR_R(SEXP indptr, SEXP indices, SEXP data, SEXP
|
|||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset) {
|
XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset, SEXP allow_groups) {
|
||||||
SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
|
SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
|
||||||
R_API_BEGIN();
|
R_API_BEGIN();
|
||||||
R_xlen_t len = Rf_xlength(idxset);
|
R_xlen_t len = Rf_xlength(idxset);
|
||||||
@ -531,7 +531,7 @@ XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset) {
|
|||||||
res_code = XGDMatrixSliceDMatrixEx(R_ExternalPtrAddr(handle),
|
res_code = XGDMatrixSliceDMatrixEx(R_ExternalPtrAddr(handle),
|
||||||
BeginPtr(idxvec), len,
|
BeginPtr(idxvec), len,
|
||||||
&res,
|
&res,
|
||||||
0);
|
Rf_asLogical(allow_groups));
|
||||||
}
|
}
|
||||||
CHECK_CALL(res_code);
|
CHECK_CALL(res_code);
|
||||||
R_SetExternalPtrAddr(ret, res);
|
R_SetExternalPtrAddr(ret, res);
|
||||||
|
|||||||
@ -112,9 +112,10 @@ XGB_DLL SEXP XGDMatrixCreateFromCSR_R(SEXP indptr, SEXP indices, SEXP data, SEXP
|
|||||||
* \brief create a new dmatrix from sliced content of existing matrix
|
* \brief create a new dmatrix from sliced content of existing matrix
|
||||||
* \param handle instance of data matrix to be sliced
|
* \param handle instance of data matrix to be sliced
|
||||||
* \param idxset index set
|
* \param idxset index set
|
||||||
|
* \param allow_groups Whether to allow slicing the DMatrix if it has a 'group' field
|
||||||
* \return a sliced new matrix
|
* \return a sliced new matrix
|
||||||
*/
|
*/
|
||||||
XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset);
|
XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset, SEXP allow_groups);
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief load a data matrix into binary file
|
* \brief load a data matrix into binary file
|
||||||
|
|||||||
@ -334,7 +334,7 @@ test_that("xgb.cv works", {
|
|||||||
set.seed(11)
|
set.seed(11)
|
||||||
expect_output(
|
expect_output(
|
||||||
cv <- xgb.cv(
|
cv <- xgb.cv(
|
||||||
data = train$data, label = train$label, max_depth = 2, nfold = 5,
|
data = xgb.DMatrix(train$data, label = train$label), max_depth = 2, nfold = 5,
|
||||||
eta = 1., nthread = n_threads, nrounds = 2, objective = "binary:logistic",
|
eta = 1., nthread = n_threads, nrounds = 2, objective = "binary:logistic",
|
||||||
eval_metric = "error", verbose = TRUE
|
eval_metric = "error", verbose = TRUE
|
||||||
),
|
),
|
||||||
@ -357,13 +357,13 @@ test_that("xgb.cv works with stratified folds", {
|
|||||||
cv <- xgb.cv(
|
cv <- xgb.cv(
|
||||||
data = dtrain, max_depth = 2, nfold = 5,
|
data = dtrain, max_depth = 2, nfold = 5,
|
||||||
eta = 1., nthread = n_threads, nrounds = 2, objective = "binary:logistic",
|
eta = 1., nthread = n_threads, nrounds = 2, objective = "binary:logistic",
|
||||||
verbose = TRUE, stratified = FALSE
|
verbose = FALSE, stratified = FALSE
|
||||||
)
|
)
|
||||||
set.seed(314159)
|
set.seed(314159)
|
||||||
cv2 <- xgb.cv(
|
cv2 <- xgb.cv(
|
||||||
data = dtrain, max_depth = 2, nfold = 5,
|
data = dtrain, max_depth = 2, nfold = 5,
|
||||||
eta = 1., nthread = n_threads, nrounds = 2, objective = "binary:logistic",
|
eta = 1., nthread = n_threads, nrounds = 2, objective = "binary:logistic",
|
||||||
verbose = TRUE, stratified = TRUE
|
verbose = FALSE, stratified = TRUE
|
||||||
)
|
)
|
||||||
# Stratified folds should result in a different evaluation logs
|
# Stratified folds should result in a different evaluation logs
|
||||||
expect_true(all(cv$evaluation_log[, test_logloss_mean] != cv2$evaluation_log[, test_logloss_mean]))
|
expect_true(all(cv$evaluation_log[, test_logloss_mean] != cv2$evaluation_log[, test_logloss_mean]))
|
||||||
@ -885,3 +885,57 @@ test_that("Seed in params override PRNG from R", {
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
test_that("xgb.cv works for AFT", {
  # Four observations with two features each. The survival targets are
  # attached further down as interval bounds, not as a 'label' field.
  X <- matrix(c(1, -1, -1, 1, 0, 1, 1, 0), nrow = 4, byrow = TRUE) # 4x2 matrix
  dtrain <- xgb.DMatrix(X, nthread = n_threads)

  params <- list(objective = 'survival:aft', learning_rate = 0.2, max_depth = 2L)

  # data must have bounds: cross-validating before any bound is set must fail
  expect_error(
    xgb.cv(
      params = params,
      data = dtrain,
      nrounds = 5L,
      nfold = 4L,
      nthread = n_threads
    )
  )

  setinfo(dtrain, 'label_lower_bound', c(2, 3, 0, 4))
  setinfo(dtrain, 'label_upper_bound', c(2, Inf, 4, 5))

  # Explicitly requesting stratified folds must raise a warning here: the
  # DMatrix carries only label bounds and no 'label' to stratify on.
  # NOTE(review): presumably this is the "Will use unstratified splitting"
  # fallback in generate.cv.folds -- confirm before pinning a message regexp.
  expect_warning(
    xgb.cv(
      params = params, data = dtrain, nrounds = 5L, nfold = 4L,
      nthread = n_threads, stratified = TRUE, verbose = FALSE
    )
  )

  # With 'stratified' left at its default, the same call is warning-free
  expect_no_warning(
    xgb.cv(params = params, data = dtrain, nrounds = 5L, nfold = 4L, verbose = FALSE)
  )
})
|
||||||
|
|
||||||
|
test_that("xgb.cv works for ranking", {
  data(iris)
  # Features: the first three iris columns (Petal.Width and Species dropped).
  x <- iris[, -(4:5)]
  # Relevance labels: Petal.Width truncated to an integer.
  y <- as.integer(iris$Petal.Width)
  # Three query groups of 50 rows each, which together cover all 150 rows.
  group <- rep(50, 3)
  # Cap the thread count with the file-wide 'n_threads', as the sibling
  # xgb.cv tests in this file do.
  dm <- xgb.DMatrix(x, label = y, group = group, nthread = n_threads)
  res <- xgb.cv(
    data = dm,
    params = list(
      objective = "rank:pairwise",
      max_depth = 3,
      nthread = n_threads
    ),
    nrounds = 3,
    nfold = 2,
    verbose = FALSE,
    stratified = FALSE
  )
  # One fold entry per requested fold.
  expect_length(res$folds, 2L)
})
|
||||||
|
|||||||
@ -367,7 +367,7 @@ test_that("prediction in early-stopping xgb.cv works", {
|
|||||||
expect_output(
|
expect_output(
|
||||||
cv <- xgb.cv(param, dtrain, nfold = 5, eta = 0.1, nrounds = 20,
|
cv <- xgb.cv(param, dtrain, nfold = 5, eta = 0.1, nrounds = 20,
|
||||||
early_stopping_rounds = 5, maximize = FALSE, stratified = FALSE,
|
early_stopping_rounds = 5, maximize = FALSE, stratified = FALSE,
|
||||||
prediction = TRUE, base_score = 0.5)
|
prediction = TRUE, base_score = 0.5, verbose = TRUE)
|
||||||
, "Stopping. Best iteration")
|
, "Stopping. Best iteration")
|
||||||
|
|
||||||
expect_false(is.null(cv$early_stop$best_iteration))
|
expect_false(is.null(cv$early_stop$best_iteration))
|
||||||
@ -387,7 +387,7 @@ test_that("prediction in xgb.cv for softprob works", {
|
|||||||
lb <- as.numeric(iris$Species) - 1
|
lb <- as.numeric(iris$Species) - 1
|
||||||
set.seed(11)
|
set.seed(11)
|
||||||
expect_warning(
|
expect_warning(
|
||||||
cv <- xgb.cv(data = as.matrix(iris[, -5]), label = lb, nfold = 4,
|
cv <- xgb.cv(data = xgb.DMatrix(as.matrix(iris[, -5]), label = lb), nfold = 4,
|
||||||
eta = 0.5, nrounds = 5, max_depth = 3, nthread = n_threads,
|
eta = 0.5, nrounds = 5, max_depth = 3, nthread = n_threads,
|
||||||
subsample = 0.8, gamma = 2, verbose = 0,
|
subsample = 0.8, gamma = 2, verbose = 0,
|
||||||
prediction = TRUE, objective = "multi:softprob", num_class = 3)
|
prediction = TRUE, objective = "multi:softprob", num_class = 3)
|
||||||
|
|||||||
@ -243,7 +243,7 @@ test_that("xgb.DMatrix: print", {
|
|||||||
txt <- capture.output({
|
txt <- capture.output({
|
||||||
print(dtrain)
|
print(dtrain)
|
||||||
})
|
})
|
||||||
expect_equal(txt, "xgb.DMatrix dim: 6513 x 126 info: label weight base_margin colnames: yes")
|
expect_equal(txt, "xgb.DMatrix dim: 6513 x 126 info: base_margin, label, weight colnames: yes")
|
||||||
|
|
||||||
# DMatrix with just features
|
# DMatrix with just features
|
||||||
dtrain <- xgb.DMatrix(
|
dtrain <- xgb.DMatrix(
|
||||||
@ -724,6 +724,44 @@ test_that("xgb.DMatrix: quantile cuts look correct", {
|
|||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
test_that("xgb.DMatrix: slicing keeps field indicators", {
  data(mtcars)
  features <- as.matrix(mtcars[, -1])
  upper <- mtcars[[1]]
  lower <- -upper
  full_dm <- xgb.DMatrix(
    data = features,
    label_lower_bound = lower,
    label_upper_bound = upper,
    nthread = 1
  )
  rows_kept <- seq_len(5)
  sliced_dm <- xgb.slice.DMatrix(full_dm, rows_kept)

  # Only the two bound fields were set on the source DMatrix, so the slice
  # must report exactly those as present and 'label' as absent.
  expected_bounds <- list(
    label_lower_bound = lower,
    label_upper_bound = upper
  )
  for (field in names(expected_bounds)) {
    expect_true(xgb.DMatrix.hasinfo(sliced_dm, field))
  }
  expect_false(xgb.DMatrix.hasinfo(sliced_dm, "label"))

  # The sliced bound values must match the selected rows of the originals.
  for (field in names(expected_bounds)) {
    expect_equal(
      getinfo(sliced_dm, field),
      expected_bounds[[field]][rows_kept],
      tolerance = 1e-6
    )
  }
})
|
||||||
|
|
||||||
|
test_that("xgb.DMatrix: can slice with groups", {
  data(iris)
  features <- as.matrix(iris[, -5])
  set.seed(123)
  labels <- sample(3, size = nrow(features), replace = TRUE)
  group_sizes <- rep(50, 3)
  full_dm <- xgb.DMatrix(
    features,
    label = labels,
    group = group_sizes,
    nthread = 1
  )
  rows_kept <- seq_len(50)
  sliced_dm <- xgb.slice.DMatrix(full_dm, rows_kept, allow_groups = TRUE)

  # The label field survives the slice ...
  expect_true(xgb.DMatrix.hasinfo(sliced_dm, "label"))
  # ... while neither grouping field is reported on the sliced DMatrix.
  for (absent_field in c("group", "qid")) {
    expect_false(xgb.DMatrix.hasinfo(sliced_dm, absent_field))
  }
  expect_null(getinfo(sliced_dm, "group"))
  expect_equal(
    getinfo(sliced_dm, "label"),
    labels[rows_kept],
    tolerance = 1e-6
  )
})
|
||||||
|
|
||||||
test_that("xgb.DMatrix: can read CSV", {
|
test_that("xgb.DMatrix: can read CSV", {
|
||||||
txt <- paste(
|
txt <- paste(
|
||||||
"1,2,3",
|
"1,2,3",
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user