diff --git a/.github/dependabot.yml b/.github/dependabot.yml
index c03a52c60..0cc0c16fd 100644
--- a/.github/dependabot.yml
+++ b/.github/dependabot.yml
@@ -8,7 +8,7 @@ updates:
- package-ecosystem: "maven"
directory: "/jvm-packages"
schedule:
- interval: "daily"
+ interval: "monthly"
- package-ecosystem: "maven"
directory: "/jvm-packages/xgboost4j"
schedule:
@@ -16,11 +16,11 @@ updates:
- package-ecosystem: "maven"
directory: "/jvm-packages/xgboost4j-gpu"
schedule:
- interval: "daily"
+ interval: "monthly"
- package-ecosystem: "maven"
directory: "/jvm-packages/xgboost4j-example"
schedule:
- interval: "daily"
+ interval: "monthly"
- package-ecosystem: "maven"
directory: "/jvm-packages/xgboost4j-spark"
schedule:
@@ -28,4 +28,4 @@ updates:
- package-ecosystem: "maven"
directory: "/jvm-packages/xgboost4j-spark-gpu"
schedule:
- interval: "daily"
+ interval: "monthly"
diff --git a/.github/workflows/r_tests.yml b/.github/workflows/r_tests.yml
index 045dac575..7dbdf3a84 100644
--- a/.github/workflows/r_tests.yml
+++ b/.github/workflows/r_tests.yml
@@ -110,7 +110,7 @@ jobs:
name: Test R package on Debian
runs-on: ubuntu-latest
container:
- image: rhub/debian-gcc-devel
+ image: rhub/debian-gcc-release
steps:
- name: Install system dependencies
@@ -130,12 +130,12 @@ jobs:
- name: Install dependencies
shell: bash -l {0}
run: |
- /tmp/R-devel/bin/Rscript -e "source('./R-package/tests/helper_scripts/install_deps.R')"
+ Rscript -e "source('./R-package/tests/helper_scripts/install_deps.R')"
- name: Test R
shell: bash -l {0}
run: |
- python3 tests/ci_build/test_r_package.py --r=/tmp/R-devel/bin/R --build-tool=autotools --task=check
+ python3 tests/ci_build/test_r_package.py --r=/usr/bin/R --build-tool=autotools --task=check
- uses: dorny/paths-filter@v2
id: changes
@@ -147,4 +147,4 @@ jobs:
- name: Run document check
if: steps.changes.outputs.r_package == 'true'
run: |
- python3 tests/ci_build/test_r_package.py --r=/tmp/R-devel/bin/R --task=doc
+ python3 tests/ci_build/test_r_package.py --r=/usr/bin/R --task=doc
diff --git a/.github/workflows/update_rapids.yml b/.github/workflows/update_rapids.yml
index 395a42148..22a395799 100644
--- a/.github/workflows/update_rapids.yml
+++ b/.github/workflows/update_rapids.yml
@@ -3,7 +3,7 @@ name: update-rapids
on:
workflow_dispatch:
schedule:
- - cron: "0 20 * * *" # Run once daily
+ - cron: "0 20 * * 1" # Run once weekly
permissions:
pull-requests: write
@@ -32,7 +32,7 @@ jobs:
run: |
bash tests/buildkite/update-rapids.sh
- name: Create Pull Request
- uses: peter-evans/create-pull-request@v5
+ uses: peter-evans/create-pull-request@v6
if: github.ref == 'refs/heads/master'
with:
add-paths: |
diff --git a/NEWS.md b/NEWS.md
index 43019d877..b067c8e3c 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -2101,7 +2101,7 @@ This release marks a major milestone for the XGBoost project.
## v0.90 (2019.05.18)
### XGBoost Python package drops Python 2.x (#4379, #4381)
-Python 2.x is reaching its end-of-life at the end of this year. [Many scientific Python packages are now moving to drop Python 2.x](https://python3statement.org/).
+Python 2.x is reaching its end-of-life at the end of this year. [Many scientific Python packages are now moving to drop Python 2.x](https://python3statement.github.io/).
### XGBoost4J-Spark now requires Spark 2.4.x (#4377)
* Spark 2.3 is reaching its end-of-life soon. See discussion at #4389.
diff --git a/R-package/R/utils.R b/R-package/R/utils.R
index 01c282a96..7b6a20f70 100644
--- a/R-package/R/utils.R
+++ b/R-package/R/utils.R
@@ -26,6 +26,11 @@ NVL <- function(x, val) {
'multi:softprob', 'rank:pairwise', 'rank:ndcg', 'rank:map'))
}
+.RANKING_OBJECTIVES <- function() {
+  return(c('rank:pairwise', 'rank:ndcg', 'rank:map'))
+}
+
#
# Low-level functions for boosting --------------------------------------------
@@ -235,33 +240,43 @@ convert.labels <- function(labels, objective_name) {
}
# Generates random (stratified if needed) CV folds
-generate.cv.folds <- function(nfold, nrows, stratified, label, params) {
+generate.cv.folds <- function(nfold, nrows, stratified, label, group, params) {
+ if (NROW(group)) {
+ if (stratified) {
+ warning(
+ paste0(
+ "Stratified splitting is not supported when using 'group' attribute.",
+ " Will use unstratified splitting."
+ )
+ )
+ }
+ return(generate.group.folds(nfold, group))
+ }
+ objective <- params$objective
+ if (!is.character(objective)) {
+ warning("Will use unstratified splitting (custom objective used)")
+ stratified <- FALSE
+ }
+ # cannot stratify if label is NULL
+ if (stratified && is.null(label)) {
+ warning("Will use unstratified splitting (no 'labels' available)")
+ stratified <- FALSE
+ }
# cannot do it for rank
- objective <- params$objective
if (is.character(objective) && strtrim(objective, 5) == 'rank:') {
- stop("\n\tAutomatic generation of CV-folds is not implemented for ranking!\n",
+ stop("\n\tAutomatic generation of CV-folds is not implemented for ranking without 'group' field!\n",
"\tConsider providing pre-computed CV-folds through the 'folds=' parameter.\n")
}
# shuffle
rnd_idx <- sample.int(nrows)
- if (stratified &&
- length(label) == length(rnd_idx)) {
+ if (stratified && length(label) == length(rnd_idx)) {
y <- label[rnd_idx]
- # WARNING: some heuristic logic is employed to identify classification setting!
# - For classification, need to convert y labels to factor before making the folds,
# and then do stratification by factor levels.
# - For regression, leave y numeric and do stratification by quantiles.
if (is.character(objective)) {
- y <- convert.labels(y, params$objective)
- } else {
- # If no 'objective' given in params, it means that user either wants to
- # use the default 'reg:squarederror' objective or has provided a custom
- # obj function. Here, assume classification setting when y has 5 or less
- # unique values:
- if (length(unique(y)) <= 5) {
- y <- factor(y)
- }
+ y <- convert.labels(y, objective)
}
folds <- xgb.createFolds(y = y, k = nfold)
} else {
@@ -277,6 +292,29 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, params) {
return(folds)
}
+generate.group.folds <- function(nfold, group) {
+ ngroups <- length(group) - 1
+ if (ngroups < nfold) {
+ stop("DMatrix has fewer groups than folds.")
+ }
+ seq_groups <- seq_len(ngroups)
+ indices <- lapply(seq_groups, function(gr) seq(group[gr] + 1, group[gr + 1]))
+ assignments <- base::split(seq_groups, as.integer(seq_groups %% nfold))
+ assignments <- unname(assignments)
+
+ out <- vector("list", nfold)
+ randomized_groups <- sample(ngroups)
+ for (idx in seq_len(nfold)) {
+ groups_idx_test <- randomized_groups[assignments[[idx]]]
+ groups_test <- indices[groups_idx_test]
+ idx_test <- unlist(groups_test)
+ attributes(idx_test)$group_test <- lengths(groups_test)
+ attributes(idx_test)$group_train <- lengths(indices[-groups_idx_test])
+ out[[idx]] <- idx_test
+ }
+ return(out)
+}
+
# Creates CV folds stratified by the values of y.
# It was borrowed from caret::createFolds and simplified
# by always returning an unnamed list of fold indices.
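The round-robin logic in `generate.group.folds` can be traced by hand: groups are shuffled once, assigned to folds through the modulo split, and each fold's test-index vector carries the per-group sizes as `group_test`/`group_train` attributes. A minimal standalone sketch of the same assignment, using a made-up cumulative `group` boundary vector:

    # 4 query groups of sizes 3, 2, 4, 1, stored as cumulative boundaries
    group <- c(0, 3, 5, 9, 10)
    nfold <- 2
    ngroups <- length(group) - 1
    # base-1 row indices belonging to each group, as in the patch
    indices <- lapply(seq_len(ngroups), function(gr) seq(group[gr] + 1, group[gr + 1]))
    # round-robin assignment of group positions to folds
    assignments <- unname(base::split(seq_len(ngroups), as.integer(seq_len(ngroups) %% nfold)))
    set.seed(1)
    randomized_groups <- sample(ngroups)
    # test rows per fold; lengths() of the kept groups would become 'group_test'
    folds <- lapply(assignments, function(a) unlist(indices[randomized_groups[a]]))

Since folds are built from whole groups, their sizes can differ, which is what the updated `stratified` documentation warns about.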
diff --git a/R-package/R/xgb.DMatrix.R b/R-package/R/xgb.DMatrix.R
index edbc267c1..15f6faed0 100644
--- a/R-package/R/xgb.DMatrix.R
+++ b/R-package/R/xgb.DMatrix.R
@@ -1259,8 +1259,11 @@ xgb.get.DMatrix.data <- function(dmat) {
#' Get a new DMatrix containing the specified rows of
#' original xgb.DMatrix object
#'
-#' @param object Object of class "xgb.DMatrix"
-#' @param idxset a integer vector of indices of rows needed
+#' @param object Object of class "xgb.DMatrix".
+#' @param idxset An integer vector of indices of rows needed (base-1 indexing).
+#' @param allow_groups Whether to allow slicing an `xgb.DMatrix` with `group` (or
+#' equivalently `qid`) field. Note that in that case, the result will not have
+#' the groups anymore - they need to be set manually through `setinfo`.
#' @param colset currently not used (columns subsetting is not available)
#'
#' @examples
@@ -1275,11 +1278,11 @@ xgb.get.DMatrix.data <- function(dmat) {
#'
#' @rdname xgb.slice.DMatrix
#' @export
-xgb.slice.DMatrix <- function(object, idxset) {
+xgb.slice.DMatrix <- function(object, idxset, allow_groups = FALSE) {
if (!inherits(object, "xgb.DMatrix")) {
stop("object must be xgb.DMatrix")
}
- ret <- .Call(XGDMatrixSliceDMatrix_R, object, idxset)
+ ret <- .Call(XGDMatrixSliceDMatrix_R, object, idxset, allow_groups)
attr_list <- attributes(object)
nr <- nrow(object)
@@ -1296,7 +1299,15 @@ xgb.slice.DMatrix <- function(object, idxset) {
}
}
}
- return(structure(ret, class = "xgb.DMatrix"))
+
+ out <- structure(ret, class = "xgb.DMatrix")
+ parent_fields <- as.list(attributes(object)$fields)
+ if (NROW(parent_fields)) {
+ child_fields <- parent_fields[!(names(parent_fields) %in% c("group", "qid"))]
+ child_fields <- as.environment(child_fields)
+ attributes(out)$fields <- child_fields
+ }
+ return(out)
}
#' @rdname xgb.slice.DMatrix
@@ -1340,11 +1351,11 @@ print.xgb.DMatrix <- function(x, verbose = FALSE, ...) {
}
cat(class_print, ' dim:', nrow(x), 'x', ncol(x), ' info: ')
- infos <- character(0)
- if (xgb.DMatrix.hasinfo(x, 'label')) infos <- 'label'
- if (xgb.DMatrix.hasinfo(x, 'weight')) infos <- c(infos, 'weight')
- if (xgb.DMatrix.hasinfo(x, 'base_margin')) infos <- c(infos, 'base_margin')
- if (length(infos) == 0) infos <- 'NA'
+ infos <- names(attributes(x)$fields)
+ infos <- infos[infos != "feature_name"]
+ if (!NROW(infos)) infos <- "NA"
+ infos <- infos[order(infos)]
+ infos <- paste(infos, collapse = ", ")
cat(infos)
cnames <- colnames(x)
cat(' colnames:')
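Because slicing with `allow_groups = TRUE` intentionally drops the `group`/`qid` field, callers need to restore it through `setinfo` afterwards; the `fields` bookkeeping added above is what keeps the printed info list accurate. A minimal sketch of that round trip (the data and group sizes here are made up for illustration):

    library(xgboost)
    x <- matrix(rnorm(60), nrow = 20)
    y <- sample(0:1, 20, replace = TRUE)
    dm <- xgb.DMatrix(x, label = y, group = c(10, 10), nthread = 1)

    # the slice keeps 'label' but no longer carries 'group'
    dslice <- xgb.slice.DMatrix(dm, 1:10, allow_groups = TRUE)
    setinfo(dslice, "group", 10)  # re-attach the group sizes manually

This is the same pattern `xgb.cv` uses when it slices a ranking DMatrix and then assigns the `group_test`/`group_train` attributes through `setinfo`.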
diff --git a/R-package/R/xgb.cv.R b/R-package/R/xgb.cv.R
index 1cafd7be7..880fd5697 100644
--- a/R-package/R/xgb.cv.R
+++ b/R-package/R/xgb.cv.R
@@ -1,6 +1,6 @@
#' Cross Validation
#'
-#' The cross validation function of xgboost
+#' The cross validation function of xgboost.
#'
#' @param params the list of parameters. The complete list of parameters is
#' available in the \href{http://xgboost.readthedocs.io/en/latest/parameter.html}{online documentation}. Below
@@ -19,13 +19,17 @@
#'
#' See \code{\link{xgb.train}} for further details.
#' See also demo/ for walkthrough example in R.
-#' @param data takes an \code{xgb.DMatrix}, \code{matrix}, or \code{dgCMatrix} as the input.
+#'
+#' Note that, while `params` accepts a `seed` entry and will use that parameter for model training if
+#' supplied, this seed is not used when creating the train-test splits, which instead rely on R's own RNG
+#' system - thus, for reproducible results, one needs to call the `set.seed` function beforehand.
+#' @param data An `xgb.DMatrix` object, with corresponding fields like `label` or bounds as required
+#' for model training by the objective.
+#'
+#' Note that only the basic `xgb.DMatrix` class is supported - variants such as `xgb.QuantileDMatrix`
+#' or `xgb.ExternalDMatrix` are not supported here.
#' @param nrounds the max number of iterations
#' @param nfold the original dataset is randomly partitioned into \code{nfold} equal size subsamples.
-#' @param label vector of response values. Should be provided only when data is an R-matrix.
-#' @param missing is only used when input is a dense matrix. By default is set to NA, which means
-#' that NA values should be considered as 'missing' by the algorithm.
-#' Sometimes, 0 or other extreme value might be used to represent missing values.
#' @param prediction A logical value indicating whether to return the test fold predictions
#' from each CV model. This parameter engages the \code{\link{xgb.cb.cv.predict}} callback.
#' @param showsd \code{boolean}, whether to show standard deviation of cross validation
@@ -47,13 +51,30 @@
#' @param feval customized evaluation function. Returns
#' \code{list(metric='metric-name', value='metric-value')} with given
#' prediction and dtrain.
-#' @param stratified a \code{boolean} indicating whether sampling of folds should be stratified
-#' by the values of outcome labels.
+#' @param stratified A \code{boolean} indicating whether sampling of folds should be stratified
+#' by the values of outcome labels. For real-valued labels in regression objectives,
+#' stratification will be done by discretizing the labels into up to 5 buckets beforehand.
+#'
+#' If passing "auto", this will be set to `TRUE` if the objective in `params` is a classification
+#' objective (from XGBoost's built-in objectives, doesn't apply to custom ones), and to
+#' `FALSE` otherwise.
+#'
+#' This parameter is ignored when `data` has a `group` field - in that case, the splitting
+#' will be based on whole groups (note that this might make the folds have different sizes).
+#'
+#' Value `TRUE` here is \bold{not} supported for custom objectives.
#' @param folds \code{list} provides a possibility to use a list of pre-defined CV folds
#' (each element must be a vector of test fold's indices). When folds are supplied,
#' the \code{nfold} and \code{stratified} parameters are ignored.
+#'
+#' If `data` has a `group` field and the objective requires this field, each fold (list element)
+#' must additionally have two attributes (retrievable through \link{attributes}) named `group_test`
+#' and `group_train`, which should hold the `group` to assign through \link{setinfo.xgb.DMatrix} to
+#' the resulting DMatrices.
#' @param train_folds \code{list} list specifying which indices to use for training. If \code{NULL}
#' (the default) all indices not specified in \code{folds} will be used for training.
+#'
+#' This is not supported when `data` has a `group` field.
#' @param verbose \code{boolean}, print the statistics during the process
#' @param print_every_n Print each n-th iteration evaluation messages when \code{verbose>0}.
#' Default is 1 which means all messages are printed. This parameter is passed to the
@@ -118,13 +139,14 @@
#' print(cv, verbose=TRUE)
#'
#' @export
-xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing = NA,
+xgb.cv <- function(params = list(), data, nrounds, nfold,
prediction = FALSE, showsd = TRUE, metrics = list(),
- obj = NULL, feval = NULL, stratified = TRUE, folds = NULL, train_folds = NULL,
+ obj = NULL, feval = NULL, stratified = "auto", folds = NULL, train_folds = NULL,
verbose = TRUE, print_every_n = 1L,
early_stopping_rounds = NULL, maximize = NULL, callbacks = list(), ...) {
check.deprecation(...)
+ stopifnot(inherits(data, "xgb.DMatrix"))
if (inherits(data, "xgb.DMatrix") && .Call(XGCheckNullPtr_R, data)) {
stop("'data' is an invalid 'xgb.DMatrix' object. Must be constructed again.")
}
@@ -137,16 +159,22 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
check.custom.obj()
check.custom.eval()
- # Check the labels
- if ((inherits(data, 'xgb.DMatrix') && !xgb.DMatrix.hasinfo(data, 'label')) ||
- (!inherits(data, 'xgb.DMatrix') && is.null(label))) {
- stop("Labels must be provided for CV either through xgb.DMatrix, or through 'label=' when 'data' is matrix")
- } else if (inherits(data, 'xgb.DMatrix')) {
- if (!is.null(label))
- warning("xgb.cv: label will be ignored, since data is of type xgb.DMatrix")
- cv_label <- getinfo(data, 'label')
- } else {
- cv_label <- label
+ if (stratified == "auto") {
+ if (is.character(params$objective)) {
+ stratified <- (
+ (params$objective %in% .CLASSIFICATION_OBJECTIVES())
+ && !(params$objective %in% .RANKING_OBJECTIVES())
+ )
+ } else {
+ stratified <- FALSE
+ }
+ }
+
+ # Check the labels and groups
+ cv_label <- getinfo(data, "label")
+ cv_group <- getinfo(data, "group")
+ if (!is.null(train_folds) && NROW(cv_group)) {
+ stop("'train_folds' is not supported for DMatrix object with 'group' field.")
}
# CV folds
@@ -157,7 +185,7 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
} else {
if (nfold <= 1)
stop("'nfold' must be > 1")
- folds <- generate.cv.folds(nfold, nrow(data), stratified, cv_label, params)
+ folds <- generate.cv.folds(nfold, nrow(data), stratified, cv_label, cv_group, params)
}
# Callbacks
@@ -195,20 +223,18 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
# create the booster-folds
# train_folds
- dall <- xgb.get.DMatrix(
- data = data,
- label = label,
- missing = missing,
- weight = NULL,
- nthread = params$nthread
- )
+ dall <- data
bst_folds <- lapply(seq_along(folds), function(k) {
- dtest <- xgb.slice.DMatrix(dall, folds[[k]])
+ dtest <- xgb.slice.DMatrix(dall, folds[[k]], allow_groups = TRUE)
# code originally contributed by @RolandASc on stackoverflow
if (is.null(train_folds))
- dtrain <- xgb.slice.DMatrix(dall, unlist(folds[-k]))
+ dtrain <- xgb.slice.DMatrix(dall, unlist(folds[-k]), allow_groups = TRUE)
else
- dtrain <- xgb.slice.DMatrix(dall, train_folds[[k]])
+ dtrain <- xgb.slice.DMatrix(dall, train_folds[[k]], allow_groups = TRUE)
+ if (!is.null(attributes(folds[[k]])$group_test)) {
+ setinfo(dtest, "group", attributes(folds[[k]])$group_test)
+ setinfo(dtrain, "group", attributes(folds[[k]])$group_train)
+ }
bst <- xgb.Booster(
params = params,
cachelist = list(dtrain, dtest),
@@ -312,8 +338,8 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
#' @examples
#' data(agaricus.train, package='xgboost')
#' train <- agaricus.train
-#' cv <- xgb.cv(data = train$data, label = train$label, nfold = 5, max_depth = 2,
-#' eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic")
+#' cv <- xgb.cv(data = xgb.DMatrix(train$data, label = train$label), nfold = 5, max_depth = 2,
+#' eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic")
#' print(cv)
#' print(cv, verbose=TRUE)
#'
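As the new `data` and `params` notes spell out, fold creation depends on R's RNG rather than any `seed` entry in `params`, so reproducible runs need `set.seed` up front. A minimal sketch with the updated interface (parameter values are illustrative only):

    library(xgboost)
    data(agaricus.train, package = "xgboost")
    dtrain <- xgb.DMatrix(agaricus.train$data, label = agaricus.train$label)

    set.seed(123)  # controls the train-test splits; params$seed does not
    cv <- xgb.cv(
      data = dtrain, nrounds = 2, nfold = 5,
      params = list(objective = "binary:logistic", max_depth = 2, eta = 1, nthread = 2),
      stratified = "auto",  # resolves to TRUE for this classification objective
      verbose = FALSE
    )

With a `rank:*` objective and a `group` field on the DMatrix, the same call would instead split by whole query groups and ignore `stratified`.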
diff --git a/R-package/man/print.xgb.cv.Rd b/R-package/man/print.xgb.cv.Rd
index 05ad61eed..74fc15d01 100644
--- a/R-package/man/print.xgb.cv.Rd
+++ b/R-package/man/print.xgb.cv.Rd
@@ -23,8 +23,8 @@ including the best iteration (when available).
\examples{
data(agaricus.train, package='xgboost')
train <- agaricus.train
-cv <- xgb.cv(data = train$data, label = train$label, nfold = 5, max_depth = 2,
- eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic")
+cv <- xgb.cv(data = xgb.DMatrix(train$data, label = train$label), nfold = 5, max_depth = 2,
+ eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic")
print(cv)
print(cv, verbose=TRUE)
diff --git a/R-package/man/xgb.cv.Rd b/R-package/man/xgb.cv.Rd
index 778b4540a..cede67570 100644
--- a/R-package/man/xgb.cv.Rd
+++ b/R-package/man/xgb.cv.Rd
@@ -9,14 +9,12 @@ xgb.cv(
data,
nrounds,
nfold,
- label = NULL,
- missing = NA,
prediction = FALSE,
showsd = TRUE,
metrics = list(),
obj = NULL,
feval = NULL,
- stratified = TRUE,
+ stratified = "auto",
folds = NULL,
train_folds = NULL,
verbose = TRUE,
@@ -44,20 +42,23 @@ is a shorter summary:
}
See \code{\link{xgb.train}} for further details.
-See also demo/ for walkthrough example in R.}
+See also demo/ for walkthrough example in R.
-\item{data}{takes an \code{xgb.DMatrix}, \code{matrix}, or \code{dgCMatrix} as the input.}
+Note that, while \code{params} accepts a \code{seed} entry and will use that parameter for model training if
+supplied, this seed is not used when creating the train-test splits, which instead rely on R's own RNG
+system - thus, for reproducible results, one needs to call the \code{set.seed} function beforehand.}
+
+\item{data}{An \code{xgb.DMatrix} object, with corresponding fields like \code{label} or bounds as required
+for model training by the objective.
+
+\if{html}{\out{<div class="sourceCode">}}\preformatted{ Note that only the basic `xgb.DMatrix` class is supported - variants such as `xgb.QuantileDMatrix`
+ or `xgb.ExternalDMatrix` are not supported here.
+}\if{html}{\out{</div>}}}
\item{nrounds}{the max number of iterations}
\item{nfold}{the original dataset is randomly partitioned into \code{nfold} equal size subsamples.}
-\item{label}{vector of response values. Should be provided only when data is an R-matrix.}
-
-\item{missing}{is only used when input is a dense matrix. By default is set to NA, which means
-that NA values should be considered as 'missing' by the algorithm.
-Sometimes, 0 or other extreme value might be used to represent missing values.}
-
\item{prediction}{A logical value indicating whether to return the test fold predictions
from each CV model. This parameter engages the \code{\link{xgb.cb.cv.predict}} callback.}
@@ -84,15 +85,35 @@ gradient with given prediction and dtrain.}
\code{list(metric='metric-name', value='metric-value')} with given
prediction and dtrain.}
-\item{stratified}{a \code{boolean} indicating whether sampling of folds should be stratified
-by the values of outcome labels.}
+\item{stratified}{A \code{boolean} indicating whether sampling of folds should be stratified
+by the values of outcome labels. For real-valued labels in regression objectives,
+stratification will be done by discretizing the labels into up to 5 buckets beforehand.
+
+\if{html}{\out{<div class="sourceCode">}}\preformatted{ If passing "auto", this will be set to `TRUE` if the objective in `params` is a classification
+ objective (from XGBoost's built-in objectives, doesn't apply to custom ones), and to
+ `FALSE` otherwise.
+
+ This parameter is ignored when `data` has a `group` field - in that case, the splitting
+ will be based on whole groups (note that this might make the folds have different sizes).
+
+ Value `TRUE` here is \\bold\{not\} supported for custom objectives.
+}\if{html}{\out{</div>}}}
\item{folds}{\code{list} provides a possibility to use a list of pre-defined CV folds
(each element must be a vector of test fold's indices). When folds are supplied,
-the \code{nfold} and \code{stratified} parameters are ignored.}
+the \code{nfold} and \code{stratified} parameters are ignored.
+
+\if{html}{\out{<div class="sourceCode">}}\preformatted{ If `data` has a `group` field and the objective requires this field, each fold (list element)
+ must additionally have two attributes (retrievable through \link{attributes}) named `group_test`
+ and `group_train`, which should hold the `group` to assign through \link{setinfo.xgb.DMatrix} to
+ the resulting DMatrices.
+}\if{html}{\out{</div>}}}
\item{train_folds}{\code{list} list specifying which indices to use for training. If \code{NULL}
-(the default) all indices not specified in \code{folds} will be used for training.}
+(the default) all indices not specified in \code{folds} will be used for training.
+
+\if{html}{\out{<div class="sourceCode">}}\preformatted{ This is not supported when `data` has a `group` field.
+}\if{html}{\out{</div>}}}
\item{verbose}{\code{boolean}, print the statistics during the process}
@@ -142,7 +163,7 @@ such as saving also the models created during cross validation); or a list \code
will contain elements such as \code{best_iteration} when using the early stopping callback (\link{xgb.cb.early.stop}).
}
\description{
-The cross validation function of xgboost
+The cross validation function of xgboost.
}
\details{
The original sample is randomly partitioned into \code{nfold} equal size subsamples.
diff --git a/R-package/man/xgb.slice.DMatrix.Rd b/R-package/man/xgb.slice.DMatrix.Rd
index c9695996b..c4f776594 100644
--- a/R-package/man/xgb.slice.DMatrix.Rd
+++ b/R-package/man/xgb.slice.DMatrix.Rd
@@ -6,14 +6,18 @@
\title{Get a new DMatrix containing the specified rows of
original xgb.DMatrix object}
\usage{
-xgb.slice.DMatrix(object, idxset)
+xgb.slice.DMatrix(object, idxset, allow_groups = FALSE)
\method{[}{xgb.DMatrix}(object, idxset, colset = NULL)
}
\arguments{
-\item{object}{Object of class "xgb.DMatrix"}
+\item{object}{Object of class "xgb.DMatrix".}
-\item{idxset}{a integer vector of indices of rows needed}
+\item{idxset}{An integer vector of indices of rows needed (base-1 indexing).}
+
+\item{allow_groups}{Whether to allow slicing an \code{xgb.DMatrix} with \code{group} (or
+equivalently \code{qid}) field. Note that in that case, the result will not have
+the groups anymore - they need to be set manually through \code{setinfo}.}
\item{colset}{currently not used (columns subsetting is not available)}
}
diff --git a/R-package/src/Makevars.in b/R-package/src/Makevars.in
index 0f4b3ac6f..69cdd09a3 100644
--- a/R-package/src/Makevars.in
+++ b/R-package/src/Makevars.in
@@ -99,10 +99,12 @@ OBJECTS= \
$(PKGROOT)/src/context.o \
$(PKGROOT)/src/logging.o \
$(PKGROOT)/src/global_config.o \
+ $(PKGROOT)/src/collective/result.o \
$(PKGROOT)/src/collective/allgather.o \
$(PKGROOT)/src/collective/allreduce.o \
$(PKGROOT)/src/collective/broadcast.o \
$(PKGROOT)/src/collective/comm.o \
+ $(PKGROOT)/src/collective/comm_group.o \
$(PKGROOT)/src/collective/coll.o \
$(PKGROOT)/src/collective/communicator-inl.o \
$(PKGROOT)/src/collective/tracker.o \
diff --git a/R-package/src/Makevars.win b/R-package/src/Makevars.win
index 0c2084de9..b34d8c649 100644
--- a/R-package/src/Makevars.win
+++ b/R-package/src/Makevars.win
@@ -99,10 +99,12 @@ OBJECTS= \
$(PKGROOT)/src/context.o \
$(PKGROOT)/src/logging.o \
$(PKGROOT)/src/global_config.o \
+ $(PKGROOT)/src/collective/result.o \
$(PKGROOT)/src/collective/allgather.o \
$(PKGROOT)/src/collective/allreduce.o \
$(PKGROOT)/src/collective/broadcast.o \
$(PKGROOT)/src/collective/comm.o \
+ $(PKGROOT)/src/collective/comm_group.o \
$(PKGROOT)/src/collective/coll.o \
$(PKGROOT)/src/collective/communicator-inl.o \
$(PKGROOT)/src/collective/tracker.o \
diff --git a/R-package/src/init.c b/R-package/src/init.c
index c869871c6..5db3218b4 100644
--- a/R-package/src/init.c
+++ b/R-package/src/init.c
@@ -71,7 +71,7 @@ extern SEXP XGDMatrixGetDataAsCSR_R(SEXP);
extern SEXP XGDMatrixSaveBinary_R(SEXP, SEXP, SEXP);
extern SEXP XGDMatrixSetInfo_R(SEXP, SEXP, SEXP);
extern SEXP XGDMatrixSetStrFeatureInfo_R(SEXP, SEXP, SEXP);
-extern SEXP XGDMatrixSliceDMatrix_R(SEXP, SEXP);
+extern SEXP XGDMatrixSliceDMatrix_R(SEXP, SEXP, SEXP);
extern SEXP XGBSetGlobalConfig_R(SEXP);
extern SEXP XGBGetGlobalConfig_R(void);
extern SEXP XGBoosterFeatureScore_R(SEXP, SEXP);
@@ -134,7 +134,7 @@ static const R_CallMethodDef CallEntries[] = {
{"XGDMatrixSaveBinary_R", (DL_FUNC) &XGDMatrixSaveBinary_R, 3},
{"XGDMatrixSetInfo_R", (DL_FUNC) &XGDMatrixSetInfo_R, 3},
{"XGDMatrixSetStrFeatureInfo_R", (DL_FUNC) &XGDMatrixSetStrFeatureInfo_R, 3},
- {"XGDMatrixSliceDMatrix_R", (DL_FUNC) &XGDMatrixSliceDMatrix_R, 2},
+ {"XGDMatrixSliceDMatrix_R", (DL_FUNC) &XGDMatrixSliceDMatrix_R, 3},
{"XGBSetGlobalConfig_R", (DL_FUNC) &XGBSetGlobalConfig_R, 1},
{"XGBGetGlobalConfig_R", (DL_FUNC) &XGBGetGlobalConfig_R, 0},
{"XGBoosterFeatureScore_R", (DL_FUNC) &XGBoosterFeatureScore_R, 2},
diff --git a/R-package/src/xgboost_R.cc b/R-package/src/xgboost_R.cc
index 2228932bd..cdb9ba65c 100644
--- a/R-package/src/xgboost_R.cc
+++ b/R-package/src/xgboost_R.cc
@@ -512,7 +512,7 @@ XGB_DLL SEXP XGDMatrixCreateFromCSR_R(SEXP indptr, SEXP indices, SEXP data, SEXP
return ret;
}
-XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset) {
+XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset, SEXP allow_groups) {
SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
R_API_BEGIN();
R_xlen_t len = Rf_xlength(idxset);
@@ -531,7 +531,7 @@ XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset) {
res_code = XGDMatrixSliceDMatrixEx(R_ExternalPtrAddr(handle),
BeginPtr(idxvec), len,
&res,
- 0);
+ Rf_asLogical(allow_groups));
}
CHECK_CALL(res_code);
R_SetExternalPtrAddr(ret, res);
diff --git a/R-package/src/xgboost_R.h b/R-package/src/xgboost_R.h
index cea50c146..62be5022a 100644
--- a/R-package/src/xgboost_R.h
+++ b/R-package/src/xgboost_R.h
@@ -112,9 +112,10 @@ XGB_DLL SEXP XGDMatrixCreateFromCSR_R(SEXP indptr, SEXP indices, SEXP data, SEXP
* \brief create a new dmatrix from sliced content of existing matrix
* \param handle instance of data matrix to be sliced
* \param idxset index set
+ * \param allow_groups Whether to allow slicing the DMatrix if it has a 'group' field
* \return a sliced new matrix
*/
-XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset);
+XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset, SEXP allow_groups);
/*!
* \brief load a data matrix into binary file
diff --git a/R-package/tests/testthat/test_basic.R b/R-package/tests/testthat/test_basic.R
index 18a3b99e6..bbb8fb323 100644
--- a/R-package/tests/testthat/test_basic.R
+++ b/R-package/tests/testthat/test_basic.R
@@ -334,7 +334,7 @@ test_that("xgb.cv works", {
set.seed(11)
expect_output(
cv <- xgb.cv(
- data = train$data, label = train$label, max_depth = 2, nfold = 5,
+ data = xgb.DMatrix(train$data, label = train$label), max_depth = 2, nfold = 5,
eta = 1., nthread = n_threads, nrounds = 2, objective = "binary:logistic",
eval_metric = "error", verbose = TRUE
),
@@ -357,13 +357,13 @@ test_that("xgb.cv works with stratified folds", {
cv <- xgb.cv(
data = dtrain, max_depth = 2, nfold = 5,
eta = 1., nthread = n_threads, nrounds = 2, objective = "binary:logistic",
- verbose = TRUE, stratified = FALSE
+ verbose = FALSE, stratified = FALSE
)
set.seed(314159)
cv2 <- xgb.cv(
data = dtrain, max_depth = 2, nfold = 5,
eta = 1., nthread = n_threads, nrounds = 2, objective = "binary:logistic",
- verbose = TRUE, stratified = TRUE
+ verbose = FALSE, stratified = TRUE
)
# Stratified folds should result in a different evaluation logs
expect_true(all(cv$evaluation_log[, test_logloss_mean] != cv2$evaluation_log[, test_logloss_mean]))
@@ -885,3 +885,57 @@ test_that("Seed in params override PRNG from R", {
)
)
})
+
+test_that("xgb.cv works for AFT", {
+ X <- matrix(c(1, -1, -1, 1, 0, 1, 1, 0), nrow = 4, byrow = TRUE) # 4x2 matrix
+ dtrain <- xgb.DMatrix(X, nthread = n_threads)
+
+ params <- list(objective = 'survival:aft', learning_rate = 0.2, max_depth = 2L)
+
+ # data must have bounds
+ expect_error(
+ xgb.cv(
+ params = params,
+ data = dtrain,
+ nround = 5L,
+ nfold = 4L,
+ nthread = n_threads
+ )
+ )
+
+ setinfo(dtrain, 'label_lower_bound', c(2, 3, 0, 4))
+ setinfo(dtrain, 'label_upper_bound', c(2, Inf, 4, 5))
+
+ # automatic stratified splitting is turned off
+ expect_warning(
+ xgb.cv(
+ params = params, data = dtrain, nround = 5L, nfold = 4L,
+ nthread = n_threads, stratified = TRUE, verbose = FALSE
+ )
+ )
+
+ # this works without any issue
+ expect_no_warning(
+ xgb.cv(params = params, data = dtrain, nround = 5L, nfold = 4L, verbose = FALSE)
+ )
+})
+
+test_that("xgb.cv works for ranking", {
+ data(iris)
+ x <- iris[, -(4:5)]
+ y <- as.integer(iris$Petal.Width)
+ group <- rep(50, 3)
+ dm <- xgb.DMatrix(x, label = y, group = group)
+ res <- xgb.cv(
+ data = dm,
+ params = list(
+ objective = "rank:pairwise",
+ max_depth = 3
+ ),
+ nrounds = 3,
+ nfold = 2,
+ verbose = FALSE,
+ stratified = FALSE
+ )
+ expect_equal(length(res$folds), 2L)
+})
diff --git a/R-package/tests/testthat/test_callbacks.R b/R-package/tests/testthat/test_callbacks.R
index 913791de4..bf95a170d 100644
--- a/R-package/tests/testthat/test_callbacks.R
+++ b/R-package/tests/testthat/test_callbacks.R
@@ -367,7 +367,7 @@ test_that("prediction in early-stopping xgb.cv works", {
expect_output(
cv <- xgb.cv(param, dtrain, nfold = 5, eta = 0.1, nrounds = 20,
early_stopping_rounds = 5, maximize = FALSE, stratified = FALSE,
- prediction = TRUE, base_score = 0.5)
+ prediction = TRUE, base_score = 0.5, verbose = TRUE)
, "Stopping. Best iteration")
expect_false(is.null(cv$early_stop$best_iteration))
@@ -387,7 +387,7 @@ test_that("prediction in xgb.cv for softprob works", {
lb <- as.numeric(iris$Species) - 1
set.seed(11)
expect_warning(
- cv <- xgb.cv(data = as.matrix(iris[, -5]), label = lb, nfold = 4,
+ cv <- xgb.cv(data = xgb.DMatrix(as.matrix(iris[, -5]), label = lb), nfold = 4,
eta = 0.5, nrounds = 5, max_depth = 3, nthread = n_threads,
subsample = 0.8, gamma = 2, verbose = 0,
prediction = TRUE, objective = "multi:softprob", num_class = 3)
diff --git a/R-package/tests/testthat/test_dmatrix.R b/R-package/tests/testthat/test_dmatrix.R
index 44d1566c6..548afece3 100644
--- a/R-package/tests/testthat/test_dmatrix.R
+++ b/R-package/tests/testthat/test_dmatrix.R
@@ -243,7 +243,7 @@ test_that("xgb.DMatrix: print", {
txt <- capture.output({
print(dtrain)
})
- expect_equal(txt, "xgb.DMatrix dim: 6513 x 126 info: label weight base_margin colnames: yes")
+ expect_equal(txt, "xgb.DMatrix dim: 6513 x 126 info: base_margin, label, weight colnames: yes")
# DMatrix with just features
dtrain <- xgb.DMatrix(
@@ -724,6 +724,44 @@ test_that("xgb.DMatrix: quantile cuts look correct", {
)
})
+test_that("xgb.DMatrix: slicing keeps field indicators", {
+ data(mtcars)
+ x <- as.matrix(mtcars[, -1])
+ y <- mtcars[, 1]
+ dm <- xgb.DMatrix(
+ data = x,
+ label_lower_bound = -y,
+ label_upper_bound = y,
+ nthread = 1
+ )
+ idx_take <- seq(1, 5)
+ dm_slice <- xgb.slice.DMatrix(dm, idx_take)
+
+ expect_true(xgb.DMatrix.hasinfo(dm_slice, "label_lower_bound"))
+ expect_true(xgb.DMatrix.hasinfo(dm_slice, "label_upper_bound"))
+ expect_false(xgb.DMatrix.hasinfo(dm_slice, "label"))
+
+ expect_equal(getinfo(dm_slice, "label_lower_bound"), -y[idx_take], tolerance = 1e-6)
+ expect_equal(getinfo(dm_slice, "label_upper_bound"), y[idx_take], tolerance = 1e-6)
+})
+
+test_that("xgb.DMatrix: can slice with groups", {
+ data(iris)
+ x <- as.matrix(iris[, -5])
+ set.seed(123)
+ y <- sample(3, size = nrow(x), replace = TRUE)
+ group <- c(50, 50, 50)
+ dm <- xgb.DMatrix(x, label = y, group = group, nthread = 1)
+ idx_take <- seq(1, 50)
+ dm_slice <- xgb.slice.DMatrix(dm, idx_take, allow_groups = TRUE)
+
+ expect_true(xgb.DMatrix.hasinfo(dm_slice, "label"))
+ expect_false(xgb.DMatrix.hasinfo(dm_slice, "group"))
+ expect_false(xgb.DMatrix.hasinfo(dm_slice, "qid"))
+ expect_null(getinfo(dm_slice, "group"))
+ expect_equal(getinfo(dm_slice, "label"), y[idx_take], tolerance = 1e-6)
+})
+
test_that("xgb.DMatrix: can read CSV", {
txt <- paste(
"1,2,3",
diff --git a/demo/dask/cpu_training.py b/demo/dask/cpu_training.py
index 2bee444f7..7117eddd9 100644
--- a/demo/dask/cpu_training.py
+++ b/demo/dask/cpu_training.py
@@ -40,7 +40,7 @@ def main(client):
# you can pass output directly into `predict` too.
prediction = dxgb.predict(client, bst, dtrain)
print("Evaluation history:", history)
- return prediction
+ print("Error:", da.sqrt((prediction - y) ** 2).mean().compute())
if __name__ == "__main__":
diff --git a/doc/contrib/unit_tests.rst b/doc/contrib/unit_tests.rst
index 662a632e2..908e5ed99 100644
--- a/doc/contrib/unit_tests.rst
+++ b/doc/contrib/unit_tests.rst
@@ -144,6 +144,14 @@ which provides higher flexibility. For example:
ctest --verbose
+If you need to debug errors on Windows with the Visual Studio debugger, you can set the GTest flags in `test_main.cc`:
+
+.. code-block::
+
+ ::testing::GTEST_FLAG(filter) = "Suite.Test";
+ ::testing::GTEST_FLAG(repeat) = 10;
+
+
***********************************************
Sanitizers: Detect memory errors and data races
***********************************************
diff --git a/doc/index.rst b/doc/index.rst
index a2ae9bbd3..7b241c0a1 100644
--- a/doc/index.rst
+++ b/doc/index.rst
@@ -28,7 +28,7 @@ Contents
Python Package
R Package
JVM Package
- Ruby Package
+ Ruby Package
Swift Package
Julia Package
C Package
diff --git a/doc/tutorials/learning_to_rank.rst b/doc/tutorials/learning_to_rank.rst
index bfc727ed7..15a611bd0 100644
--- a/doc/tutorials/learning_to_rank.rst
+++ b/doc/tutorials/learning_to_rank.rst
@@ -52,7 +52,7 @@ Notice that the samples are sorted based on their query index in a non-decreasin
X, y = make_classification(random_state=seed)
rng = np.random.default_rng(seed)
n_query_groups = 3
- qid = rng.integers(0, 3, size=X.shape[0])
+ qid = rng.integers(0, n_query_groups, size=X.shape[0])
# Sort the inputs based on query index
sorted_idx = np.argsort(qid)
@@ -65,14 +65,14 @@ The simplest way to train a ranking model is by using the scikit-learn estimator
.. code-block:: python
ranker = xgb.XGBRanker(tree_method="hist", lambdarank_num_pair_per_sample=8, objective="rank:ndcg", lambdarank_pair_method="topk")
- ranker.fit(X, y, qid=qid)
+ ranker.fit(X, y, qid=qid[sorted_idx])
Please note that, as of writing, there's no learning-to-rank interface in scikit-learn. As a result, the :py:class:`xgboost.XGBRanker` class does not fully conform to the scikit-learn estimator guideline and cannot be directly used with some of its utility functions. For instance, the ``auc_score`` and ``ndcg_score`` in scikit-learn don't consider query group information or the pairwise loss. Most of the metrics are implemented as part of XGBoost, but to use scikit-learn utilities like :py:func:`sklearn.model_selection.cross_validation`, we need to make some adjustments in order to pass the ``qid`` as an additional parameter for :py:meth:`xgboost.XGBRanker.score`. Given a data frame ``X`` (either pandas or cuDF), add the column ``qid`` as follows:
.. code-block:: python
df = pd.DataFrame(X, columns=[str(i) for i in range(X.shape[1])])
- df["qid"] = qid
+ df["qid"] = qid[sorted_idx]
ranker.fit(df, y) # No need to pass qid as a separate argument
from sklearn.model_selection import StratifiedGroupKFold, cross_val_score
diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h
index 795c78946..19b93c644 100644
--- a/include/xgboost/c_api.h
+++ b/include/xgboost/c_api.h
@@ -1,5 +1,5 @@
/**
- * Copyright 2015~2023 by XGBoost Contributors
+ * Copyright 2015-2024, XGBoost Contributors
* \file c_api.h
* \author Tianqi Chen
* \brief C API of XGBoost, used for interfacing to other languages.
@@ -639,21 +639,14 @@ XGB_DLL int XGDMatrixSetInfoFromInterface(DMatrixHandle handle,
* \param len length of array
* \return 0 when success, -1 when failure happens
*/
-XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle,
- const char *field,
- const float *array,
+XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle, const char *field, const float *array,
bst_ulong len);
-/*!
- * \brief set uint32 vector to a content in info
- * \param handle a instance of data matrix
- * \param field field name
- * \param array pointer to unsigned int vector
- * \param len length of array
- * \return 0 when success, -1 when failure happens
+/**
+ * @deprecated since 2.1.0
+ *
+ * Use @ref XGDMatrixSetInfoFromInterface instead.
*/
-XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle,
- const char *field,
- const unsigned *array,
+XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle, const char *field, const unsigned *array,
bst_ulong len);
/*!
@@ -725,42 +718,13 @@ XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field,
bst_ulong *size,
const char ***out_features);
-/*!
- * \brief Set meta info from dense matrix. Valid field names are:
+/**
+ * @deprecated since 2.1.0
*
- * - label
- * - weight
- * - base_margin
- * - group
- * - label_lower_bound
- * - label_upper_bound
- * - feature_weights
- *
- * \param handle An instance of data matrix
- * \param field Field name
- * \param data Pointer to consecutive memory storing data.
- * \param size Size of the data, this is relative to size of type. (Meaning NOT number
- * of bytes.)
- * \param type Indicator of data type. This is defined in xgboost::DataType enum class.
- * - float = 1
- * - double = 2
- * - uint32_t = 3
- * - uint64_t = 4
- * \return 0 when success, -1 when failure happens
+ * Use @ref XGDMatrixSetInfoFromInterface instead.
*/
-XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field,
- void const *data, bst_ulong size, int type);
-
-/*!
- * \brief (deprecated) Use XGDMatrixSetUIntInfo instead. Set group of the training matrix
- * \param handle a instance of data matrix
- * \param group pointer to group size
- * \param len length of array
- * \return 0 when success, -1 when failure happens
- */
-XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle,
- const unsigned *group,
- bst_ulong len);
+XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field, void const *data,
+ bst_ulong size, int type);
/*!
* \brief get float info vector from matrix.
@@ -1591,7 +1555,7 @@ XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle);
/**
* @brief Get the arguments needed for running workers. This should be called after
- * XGTrackerRun() and XGTrackerWait()
+ * XGTrackerRun().
*
* @param handle The handle to the tracker.
* @param args The arguments returned as a JSON document.
@@ -1601,16 +1565,19 @@ XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle);
XGB_DLL int XGTrackerWorkerArgs(TrackerHandle handle, char const **args);
/**
- * @brief Run the tracker.
+ * @brief Start the tracker. The tracker runs in the background and this function returns
+ * once the tracker is started.
*
* @param handle The handle to the tracker.
+ * @param config Unused at the moment, preserved for the future.
*
* @return 0 for success, -1 for failure.
*/
-XGB_DLL int XGTrackerRun(TrackerHandle handle);
+XGB_DLL int XGTrackerRun(TrackerHandle handle, char const *config);
/**
- * @brief Wait for the tracker to finish, should be called after XGTrackerRun().
+ * @brief Wait for the tracker to finish, should be called after XGTrackerRun(). This
+ * function will block until the tracker task is finished or the timeout is reached.
*
* @param handle The handle to the tracker.
* @param config JSON encoded configuration. No argument is required yet, preserved for
@@ -1618,11 +1585,12 @@ XGB_DLL int XGTrackerRun(TrackerHandle handle);
*
* @return 0 for success, -1 for failure.
*/
-XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config);
+XGB_DLL int XGTrackerWaitFor(TrackerHandle handle, char const *config);
/**
- * @brief Free a tracker instance. XGTrackerWait() is called internally. If the tracker
- * cannot close properly, manual interruption is required.
+ * @brief Free a tracker instance. This should be called after XGTrackerWaitFor(). If the
+ * tracker has not been properly waited for, this function will shut down all connections
+ * with the tracker, potentially leading to undefined behavior.
*
* @param handle The handle to the tracker.
*
diff --git a/include/xgboost/collective/result.h b/include/xgboost/collective/result.h
index 919d3a902..23e70a8e6 100644
--- a/include/xgboost/collective/result.h
+++ b/include/xgboost/collective/result.h
@@ -3,13 +3,11 @@
*/
#pragma once
-#include <xgboost/logging.h>
-
-#include <memory>  // for unique_ptr
-#include <sstream>  // for stringstream
-#include <stack>  // for stack
-#include <string>  // for string
-#include <utility>  // for move
+#include <cstdint>  // for int32_t
+#include <memory>  // for unique_ptr
+#include <string>  // for string
+#include <system_error>  // for error_code
+#include <utility>  // for move
namespace xgboost::collective {
namespace detail {
@@ -48,48 +46,19 @@ struct ResultImpl {
return cur_eq;
}
- [[nodiscard]] std::string Report() {
- std::stringstream ss;
- ss << "\n- " << this->message;
- if (this->errc != std::error_code{}) {
- ss << " system error:" << this->errc.message();
- }
+ [[nodiscard]] std::string Report() const;
+ [[nodiscard]] std::error_code Code() const;
- auto ptr = prev.get();
- while (ptr) {
- ss << "\n- ";
- ss << ptr->message;
-
- if (ptr->errc != std::error_code{}) {
- ss << " " << ptr->errc.message();
- }
- ptr = ptr->prev.get();
- }
-
- return ss.str();
- }
- [[nodiscard]] auto Code() const {
- // Find the root error.
- std::stack<ResultImpl const *> stack;
- auto ptr = this;
- while (ptr) {
- stack.push(ptr);
- if (ptr->prev) {
- ptr = ptr->prev.get();
- } else {
- break;
- }
- }
- while (!stack.empty()) {
- auto frame = stack.top();
- stack.pop();
- if (frame->errc != std::error_code{}) {
- return frame->errc;
- }
- }
- return std::error_code{};
- }
+ void Concat(std::unique_ptr<ResultImpl> rhs);
};
+
+#if (!defined(__GNUC__) && !defined(__clang__)) || defined(__MINGW32__)
+#define __builtin_FILE() nullptr
+#define __builtin_LINE() (-1)
+std::string MakeMsg(std::string&& msg, char const*, std::int32_t);
+#else
+std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line);
+#endif
} // namespace detail
/**
@@ -131,8 +100,21 @@ struct Result {
}
return *impl_ == *that.impl_;
}
+
+ friend Result operator+(Result&& lhs, Result&& rhs);
};
+[[nodiscard]] inline Result operator+(Result&& lhs, Result&& rhs) {
+ if (lhs.OK()) {
+ return std::forward<Result>(rhs);
+ }
+ if (rhs.OK()) {
+ return std::forward<Result>(lhs);
+ }
+ lhs.impl_->Concat(std::move(rhs.impl_));
+ return std::forward<Result>(lhs);
+}
+
/**
* @brief Return success.
*/
@@ -140,38 +122,43 @@ struct Result {
/**
* @brief Return failure.
*/
-[[nodiscard]] inline auto Fail(std::string msg) { return Result{std::move(msg)}; }
+[[nodiscard]] inline auto Fail(std::string msg, char const* file = __builtin_FILE(),
+ std::int32_t line = __builtin_LINE()) {
+ return Result{detail::MakeMsg(std::move(msg), file, line)};
+}
/**
* @brief Return failure with `errno`.
*/
-[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc) {
- return Result{std::move(msg), std::move(errc)};
+[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc,
+ char const* file = __builtin_FILE(),
+ std::int32_t line = __builtin_LINE()) {
+ return Result{detail::MakeMsg(std::move(msg), file, line), std::move(errc)};
}
/**
* @brief Return failure with a previous error.
*/
-[[nodiscard]] inline auto Fail(std::string msg, Result&& prev) {
- return Result{std::move(msg), std::forward<Result>(prev)};
+[[nodiscard]] inline auto Fail(std::string msg, Result&& prev, char const* file = __builtin_FILE(),
+ std::int32_t line = __builtin_LINE()) {
+ return Result{detail::MakeMsg(std::move(msg), file, line), std::forward<Result>(prev)};
}
/**
* @brief Return failure with a previous error and a new `errno`.
*/
-[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc, Result&& prev) {
- return Result{std::move(msg), std::move(errc), std::forward<Result>(prev)};
+[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc, Result&& prev,
+ char const* file = __builtin_FILE(),
+ std::int32_t line = __builtin_LINE()) {
+ return Result{detail::MakeMsg(std::move(msg), file, line), std::move(errc),
+ std::forward<Result>(prev)};
}
// We don't have monad, a simple helper would do.
template <typename Fn>
-[[nodiscard]] Result operator<<(Result&& r, Fn&& fn) {
+[[nodiscard]] std::enable_if_t<std::is_invocable_v<Fn>, Result> operator<<(Result&& r, Fn&& fn) {
if (!r.OK()) {
return std::forward<Result>(r);
}
return fn();
}
-inline void SafeColl(Result const& rc) {
- if (!rc.OK()) {
- LOG(FATAL) << rc.Report();
- }
-}
+void SafeColl(Result const& rc);
} // namespace xgboost::collective
diff --git a/include/xgboost/collective/socket.h b/include/xgboost/collective/socket.h
index 3bc3b389c..0e098052c 100644
--- a/include/xgboost/collective/socket.h
+++ b/include/xgboost/collective/socket.h
@@ -1,5 +1,5 @@
/**
- * Copyright (c) 2022-2023, XGBoost Contributors
+ * Copyright (c) 2022-2024, XGBoost Contributors
*/
#pragma once
@@ -12,7 +12,6 @@
#include <cstddef>  // std::size_t
#include <cstdint>  // std::int32_t, std::uint16_t
#include <cstring>  // memset
-#include <limits>  // std::numeric_limits
#include <string>  // std::string
#include <system_error>  // std::error_code, std::system_category
#include <utility>  // std::swap
@@ -125,6 +124,21 @@ inline std::int32_t CloseSocket(SocketT fd) {
#endif
}
+inline std::int32_t ShutdownSocket(SocketT fd) {
+#if defined(_WIN32)
+ auto rc = shutdown(fd, SD_BOTH);
+ if (rc != 0 && LastError() == WSANOTINITIALISED) {
+ return 0;
+ }
+#else
+ auto rc = shutdown(fd, SHUT_RDWR);
+ if (rc != 0 && LastError() == ENOTCONN) {
+ return 0;
+ }
+#endif
+ return rc;
+}
+
inline bool ErrorWouldBlock(std::int32_t errsv) noexcept(true) {
#ifdef _WIN32
return errsv == WSAEWOULDBLOCK;
@@ -468,19 +482,30 @@ class TCPSocket {
*addr = SockAddress{SockAddrV6{caddr}};
*out = TCPSocket{newfd};
}
+ // On MacOS, this is automatically set to an async socket if the parent socket is async.
+ // We make sure all sockets are blocking by default.
+ //
+ // On Windows, a closed socket is returned during shutdown. We guard against it when
+ // setting non-blocking.
+ if (!out->IsClosed()) {
+ return out->NonBlocking(false);
+ }
return Success();
}
~TCPSocket() {
if (!IsClosed()) {
- Close();
+ auto rc = this->Close();
+ if (!rc.OK()) {
+ LOG(WARNING) << rc.Report();
+ }
}
}
TCPSocket(TCPSocket const &that) = delete;
TCPSocket(TCPSocket &&that) noexcept(true) { std::swap(this->handle_, that.handle_); }
TCPSocket &operator=(TCPSocket const &that) = delete;
- TCPSocket &operator=(TCPSocket &&that) {
+ TCPSocket &operator=(TCPSocket &&that) noexcept(true) {
std::swap(this->handle_, that.handle_);
return *this;
}
@@ -489,36 +514,49 @@ class TCPSocket {
*/
[[nodiscard]] HandleT const &Handle() const { return handle_; }
/**
- * \brief Listen to incoming requests. Should be called after bind.
+ * @brief Listen to incoming requests. Should be called after bind.
*/
- void Listen(std::int32_t backlog = 16) { xgboost_CHECK_SYS_CALL(listen(handle_, backlog), 0); }
+ [[nodiscard]] Result Listen(std::int32_t backlog = 16) {
+ if (listen(handle_, backlog) != 0) {
+ return system::FailWithCode("Failed to listen.");
+ }
+ return Success();
+ }
/**
- * \brief Bind socket to INADDR_ANY, return the port selected by the OS.
+ * @brief Bind socket to INADDR_ANY, return the port selected by the OS.
*/
- [[nodiscard]] in_port_t BindHost() {
+ [[nodiscard]] Result BindHost(std::int32_t* p_out) {
+ // Use int32 instead of in_port_t for consistency. We take the port as a parameter from
+ // users of other languages, where the port is usually stored and passed around as an int.
if (Domain() == SockDomain::kV6) {
auto addr = SockAddrV6::InaddrAny();
auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
- xgboost_CHECK_SYS_CALL(
- bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)), 0);
+ if (bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)) != 0) {
+ return system::FailWithCode("bind failed.");
+ }
sockaddr_in6 res_addr;
socklen_t addrlen = sizeof(res_addr);
- xgboost_CHECK_SYS_CALL(
- getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen), 0);
- return ntohs(res_addr.sin6_port);
+ if (getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen) != 0) {
+ return system::FailWithCode("getsockname failed.");
+ }
+ *p_out = ntohs(res_addr.sin6_port);
} else {
auto addr = SockAddrV4::InaddrAny();
auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
- xgboost_CHECK_SYS_CALL(
- bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)), 0);
+ if (bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)) != 0) {
+ return system::FailWithCode("bind failed.");
+ }
sockaddr_in res_addr;
socklen_t addrlen = sizeof(res_addr);
- xgboost_CHECK_SYS_CALL(
- getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen), 0);
- return ntohs(res_addr.sin_port);
+ if (getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen) != 0) {
+ return system::FailWithCode("getsockname failed.");
+ }
+ *p_out = ntohs(res_addr.sin_port);
}
+
+ return Success();
}
[[nodiscard]] auto Port() const {
@@ -631,26 +669,49 @@ class TCPSocket {
*/
std::size_t Send(StringView str);
/**
- * \brief Receive string, format is matched with the Python socket wrapper in RABIT.
+ * @brief Receive string, format is matched with the Python socket wrapper in RABIT.
*/
- std::size_t Recv(std::string *p_str);
+ [[nodiscard]] Result Recv(std::string *p_str);
/**
- * \brief Close the socket, called automatically in destructor if the socket is not closed.
+ * @brief Close the socket, called automatically in destructor if the socket is not closed.
*/
- void Close() {
+ [[nodiscard]] Result Close() {
if (InvalidSocket() != handle_) {
-#if defined(_WIN32)
auto rc = system::CloseSocket(handle_);
+#if defined(_WIN32)
// it's possible that we close TCP sockets after finalizing WSA due to detached thread.
if (rc != 0 && system::LastError() != WSANOTINITIALISED) {
- system::ThrowAtError("close", rc);
+ return system::FailWithCode("Failed to close the socket.");
}
#else
- xgboost_CHECK_SYS_CALL(system::CloseSocket(handle_), 0);
+ if (rc != 0) {
+ return system::FailWithCode("Failed to close the socket.");
+ }
#endif
handle_ = InvalidSocket();
}
+ return Success();
}
+ /**
+ * @brief Call shutdown on the socket.
+ */
+ [[nodiscard]] Result Shutdown() {
+ if (this->IsClosed()) {
+ return Success();
+ }
+ auto rc = system::ShutdownSocket(this->Handle());
+#if defined(_WIN32)
+ // Windows cannot shutdown a socket if it's not connected.
+ if (rc == -1 && system::LastError() == WSAENOTCONN) {
+ return Success();
+ }
+#endif
+ if (rc != 0) {
+ return system::FailWithCode("Failed to shutdown socket.");
+ }
+ return Success();
+ }
+
/**
* \brief Create a TCP socket on specified domain.
*/
diff --git a/include/xgboost/data.h b/include/xgboost/data.h
index 2bdf3713d..ec06a9c86 100644
--- a/include/xgboost/data.h
+++ b/include/xgboost/data.h
@@ -19,7 +19,6 @@
#include
#include
#include
-#include
#include
#include
#include
@@ -137,14 +136,6 @@ class MetaInfo {
* \param fo The output stream.
*/
void SaveBinary(dmlc::Stream* fo) const;
- /*!
- * \brief Set information in the meta info.
- * \param key The key of the information.
- * \param dptr The data pointer of the source array.
- * \param dtype The type of the source data.
- * \param num Number of elements in the source array.
- */
- void SetInfo(Context const& ctx, const char* key, const void* dptr, DataType dtype, size_t num);
/*!
* \brief Set information in the meta info with array interface.
* \param key The key of the information.
@@ -517,10 +508,6 @@ class DMatrix {
DMatrix() = default;
/*! \brief meta information of the dataset */
virtual MetaInfo& Info() = 0;
- virtual void SetInfo(const char* key, const void* dptr, DataType dtype, size_t num) {
- auto const& ctx = *this->Ctx();
- this->Info().SetInfo(ctx, key, dptr, dtype, num);
- }
virtual void SetInfo(const char* key, std::string const& interface_str) {
auto const& ctx = *this->Ctx();
this->Info().SetInfo(ctx, key, StringView{interface_str});
diff --git a/include/xgboost/linalg.h b/include/xgboost/linalg.h
index 8f9cd3eaa..79810d4d0 100644
--- a/include/xgboost/linalg.h
+++ b/include/xgboost/linalg.h
@@ -190,13 +190,14 @@ constexpr auto ArrToTuple(T (&arr)[N]) {
// uint division optimization inspired by the CIndexer in cupy. Division operation is
// slow on both CPU and GPU, especially 64 bit integer. So here we first try to avoid 64
// bit when the index is smaller, then try to avoid division when it's exp of 2.
-template <int32_t D, typename I>
+template <std::int32_t D, typename I>
LINALG_HD auto UnravelImpl(I idx, common::Span<std::size_t const, D> shape) {
- size_t index[D]{0};
+ std::size_t index[D]{0};
static_assert(std::is_signed<decltype(D)>::value,
"Don't change the type without changing the for loop.");
+ auto const sptr = shape.data();
for (int32_t dim = D; --dim > 0;) {
- auto s = static_cast<std::remove_const_t<std::remove_reference_t<decltype(shape[dim])>>>(shape[dim]);
+ auto s = static_cast<std::remove_const_t<std::remove_reference_t<decltype(sptr[dim])>>>(sptr[dim]);
if (s & (s - 1)) {
auto t = idx / s;
index[dim] = idx - t * s;
auto ArrayInterfaceStr(TensorView<T const, D> const &t) {
return str;
}
+template <typename T>
+auto Make1dInterface(T const *vec, std::size_t len) {
+ Context ctx;
+ auto t = linalg::MakeTensorView(&ctx, common::Span{vec, len}, len);
+ auto str = linalg::ArrayInterfaceStr(t);
+ return str;
+}
+
/**
* \brief A tensor storage. To use it for other functionality like slicing one needs to
* obtain a view first. This way we can use it on both host and device.
diff --git a/include/xgboost/span.h b/include/xgboost/span.h
index 3e1325ceb..468f5ff50 100644
--- a/include/xgboost/span.h
+++ b/include/xgboost/span.h
@@ -30,9 +30,8 @@
#define XGBOOST_SPAN_H_
#include
-#include
-#include <cstddef> // size_t
+#include <cstddef>  // size_t
#include
#include
#include <limits>  // numeric_limits
@@ -75,8 +74,7 @@
#endif // defined(_MSC_VER) && _MSC_VER < 1910
-namespace xgboost {
-namespace common {
+namespace xgboost::common {
#if defined(__CUDA_ARCH__)
// Usual logging facility is not available inside device code.
@@ -744,8 +742,8 @@ class IterSpan {
return it_ + size();
}
};
-} // namespace common
-} // namespace xgboost
+} // namespace xgboost::common
+
#if defined(_MSC_VER) && _MSC_VER < 1910
#undef constexpr
diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml
index 5b6f82b6a..70e054d3a 100644
--- a/jvm-packages/pom.xml
+++ b/jvm-packages/pom.xml
@@ -33,7 +33,7 @@
UTF-81.81.8
- 1.18.0
+ 1.19.04.13.23.4.13.4.1
@@ -46,9 +46,9 @@
23.12.123.12.1cuda12
+ 3.2.18
+ 2.12.0OFF
- 3.2.17
- 2.11.0
@@ -124,7 +124,7 @@
org.apache.maven.pluginsmaven-jar-plugin
- 3.3.0
+ 3.4.0empty-javadoc-jar
@@ -153,7 +153,7 @@
org.apache.maven.pluginsmaven-gpg-plugin
- 3.1.0
+ 3.2.3sign-artifacts
@@ -167,7 +167,7 @@
org.apache.maven.pluginsmaven-source-plugin
- 3.3.0
+ 3.3.1attach-sources
@@ -205,7 +205,7 @@
org.apache.maven.pluginsmaven-assembly-plugin
- 3.6.0
+ 3.7.1jar-with-dependencies
@@ -446,7 +446,7 @@
org.apache.maven.pluginsmaven-surefire-plugin
- 3.2.2
+ 3.2.5falsefalse
@@ -488,12 +488,12 @@
com.esotericsoftwarekryo
- 5.5.0
+ 5.6.0commons-loggingcommons-logging
- 1.3.0
+ 1.3.1org.scalatest
diff --git a/jvm-packages/xgboost4j-gpu/pom.xml b/jvm-packages/xgboost4j-gpu/pom.xml
index 2dc36d52d..26ad9cafd 100644
--- a/jvm-packages/xgboost4j-gpu/pom.xml
+++ b/jvm-packages/xgboost4j-gpu/pom.xml
@@ -72,7 +72,7 @@
org.apache.maven.pluginsmaven-javadoc-plugin
- 3.6.2
+ 3.6.3protectedtrue
@@ -88,7 +88,7 @@
exec-maven-pluginorg.codehaus.mojo
- 3.1.0
+ 3.2.0native
@@ -115,7 +115,7 @@
org.apache.maven.pluginsmaven-jar-plugin
- 3.3.0
+ 3.4.0
diff --git a/jvm-packages/xgboost4j-tester/generate_pom.py b/jvm-packages/xgboost4j-tester/generate_pom.py
index b9c274c28..eb7cf94b3 100644
--- a/jvm-packages/xgboost4j-tester/generate_pom.py
+++ b/jvm-packages/xgboost4j-tester/generate_pom.py
@@ -22,7 +22,7 @@ pom_template = """
    <scala.version>{scala_version}</scala.version>
    <scalatest.version>3.2.15</scalatest.version>
    <scala.binary.version>{scala_binary_version}</scala.binary.version>
-    <kryo.version>5.5.0</kryo.version>
+    <kryo.version>5.6.0</kryo.version>
diff --git a/jvm-packages/xgboost4j/pom.xml b/jvm-packages/xgboost4j/pom.xml
index 7eb186919..5012eaf14 100644
--- a/jvm-packages/xgboost4j/pom.xml
+++ b/jvm-packages/xgboost4j/pom.xml
@@ -60,7 +60,7 @@
        <groupId>org.apache.maven.plugins</groupId>
        <artifactId>maven-javadoc-plugin</artifactId>
-        <version>3.6.2</version>
+        <version>3.6.3</version>
        <configuration>
          <show>protected</show>
          <nohelp>true</nohelp>
@@ -76,7 +76,7 @@
        <artifactId>exec-maven-plugin</artifactId>
        <groupId>org.codehaus.mojo</groupId>
-        <version>3.1.0</version>
+        <version>3.2.0</version>
        <executions>
          <execution>
            <id>native</id>
@@ -99,7 +99,7 @@
        <groupId>org.apache.maven.plugins</groupId>
        <artifactId>maven-jar-plugin</artifactId>
-        <version>3.3.0</version>
+        <version>3.4.0</version>
diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp
index 332b1a127..9ba944d5a 100644
--- a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp
+++ b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp
@@ -408,7 +408,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetFloatI
jfloat* array = jenv->GetFloatArrayElements(jarray, NULL);
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray);
- int ret = XGDMatrixSetFloatInfo(handle, field, (float const *)array, len);
+ auto str = xgboost::linalg::Make1dInterface(array, len);
+ int ret = XGDMatrixSetInfoFromInterface(handle, field, str.c_str());
JVM_CHECK_CALL(ret);
//release
if (field) jenv->ReleaseStringUTFChars(jfield, field);
@@ -427,7 +428,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetUIntIn
const char* field = jenv->GetStringUTFChars(jfield, 0);
jint* array = jenv->GetIntArrayElements(jarray, NULL);
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray);
- int ret = XGDMatrixSetUIntInfo(handle, (char const *)field, (unsigned int const *)array, len);
+ auto str = xgboost::linalg::Make1dInterface(array, len);
+ int ret = XGDMatrixSetInfoFromInterface(handle, field, str.c_str());
JVM_CHECK_CALL(ret);
//release
if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field);
@@ -730,8 +732,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredictFr
if (jmargin) {
margin = jenv->GetFloatArrayElements(jmargin, nullptr);
JVM_CHECK_CALL(XGProxyDMatrixCreate(&proxy));
- JVM_CHECK_CALL(
- XGDMatrixSetFloatInfo(proxy, "base_margin", margin, jenv->GetArrayLength(jmargin)));
+ auto str = xgboost::linalg::Make1dInterface(margin, jenv->GetArrayLength(jmargin));
+ JVM_CHECK_CALL(XGDMatrixSetInfoFromInterface(proxy, "base_margin", str.c_str()));
}
bst_ulong const *out_shape;
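
With these changes the JNI layer stops calling the typed `XGDMatrix*Info` setters and instead describes the buffer with an array-interface JSON string. Roughly, `Make1dInterface` yields a string of this shape; the sketch below is hand-rolled for illustration only (the exact fields and `typestr` are produced by `linalg::ArrayInterfaceStr` and may differ, and `MakeInterfaceSketch` is a hypothetical name):

```cpp
#include <cstddef>
#include <cstdint>
#include <sstream>
#include <string>

// Sketch: a 1-d __array_interface__ string for a little-endian float32 buffer.
std::string MakeInterfaceSketch(float const *ptr, std::size_t len) {
  std::ostringstream os;
  os << R"({"data": [)" << reinterpret_cast<std::uintptr_t>(ptr)
     << R"(, true], "shape": [)" << len << R"(], "typestr": "<f4", "version": 3})";
  return os.str();
}
```
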
diff --git a/plugin/federated/federated_coll.cc b/plugin/federated/federated_coll.cc
index b3dc23dba..34670715a 100644
--- a/plugin/federated/federated_coll.cc
+++ b/plugin/federated/federated_coll.cc
@@ -89,19 +89,15 @@ Coll *FederatedColl::MakeCUDAVar() {
[[nodiscard]] Result FederatedColl::Broadcast(Comm const &comm, common::Span<std::int8_t> data,
                                              std::int32_t root) {
- if (comm.Rank() == root) {
- return BroadcastImpl(comm, &this->sequence_number_, data, root);
- } else {
- return BroadcastImpl(comm, &this->sequence_number_, data, root);
- }
+ return BroadcastImpl(comm, &this->sequence_number_, data, root);
}
-[[nodiscard]] Result FederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data,
-                                              std::int64_t size) {
+[[nodiscard]] Result FederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data) {
using namespace federated; // NOLINT
  auto fed = dynamic_cast<FederatedComm const *>(&comm);
CHECK(fed);
auto stub = fed->Handle();
+ auto size = data.size_bytes() / comm.World();
auto offset = comm.Rank() * size;
auto segment = data.subspan(offset, size);
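
The refactored `Allgather` no longer takes an explicit per-rank size; it derives it from the total buffer and the world size. The arithmetic in isolation, as a sketch with hypothetical names:

```cpp
#include <cassert>
#include <cstddef>

struct Segment { std::size_t offset, size; };

// Each rank owns an equal slice of the gather buffer.
Segment OwnSegment(std::size_t total_bytes, int world, int rank) {
  assert(world > 0 && total_bytes % world == 0);  // assumes an evenly divisible buffer
  std::size_t size = total_bytes / world;
  return Segment{static_cast<std::size_t>(rank) * size, size};
}
// OwnSegment(1024, 4, 2) yields {offset = 512, size = 256}
```
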
diff --git a/plugin/federated/federated_coll.cu b/plugin/federated/federated_coll.cu
index a922e1c11..3f604c50d 100644
--- a/plugin/federated/federated_coll.cu
+++ b/plugin/federated/federated_coll.cu
@@ -53,8 +53,7 @@ Coll *FederatedColl::MakeCUDAVar() {
};
}
-[[nodiscard]] Result CUDAFederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data,
-                                                  std::int64_t size) {
+[[nodiscard]] Result CUDAFederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data) {
  auto cufed = dynamic_cast<CUDAFederatedComm const *>(&comm);
  CHECK(cufed);
  std::vector<std::int8_t> h_data(data.size());
@@ -63,7 +62,7 @@ Coll *FederatedColl::MakeCUDAVar() {
return GetCUDAResult(
cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost));
} << [&] {
- return p_impl_->Allgather(comm, common::Span{h_data.data(), h_data.size()}, size);
+ return p_impl_->Allgather(comm, common::Span{h_data.data(), h_data.size()});
} << [&] {
return GetCUDAResult(cudaMemcpyAsync(data.data(), h_data.data(), data.size(),
cudaMemcpyHostToDevice, cufed->Stream()));
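
The CUDA variant keeps the same staging pattern after the signature change: copy device memory to a host buffer, run the host-side collective, then copy back asynchronously on the communicator's stream. Schematically (a sketch with error handling elided; `HostAllgather` is a stand-in for the host implementation, not a real API):

```cpp
#include <cuda_runtime.h>
#include <cstddef>
#include <cstdint>
#include <vector>

void StagedAllgather(std::int8_t *d_data, std::size_t n, cudaStream_t stream) {
  std::vector<std::int8_t> h_data(n);
  cudaMemcpy(h_data.data(), d_data, n, cudaMemcpyDeviceToHost);  // stage out
  // HostAllgather(h_data);                                      // host collective (stand-in)
  cudaMemcpyAsync(d_data, h_data.data(), n,
                  cudaMemcpyHostToDevice, stream);               // stage back in
}
```
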
diff --git a/plugin/federated/federated_coll.cuh b/plugin/federated/federated_coll.cuh
index a1121d88f..6a690a33d 100644
--- a/plugin/federated/federated_coll.cuh
+++ b/plugin/federated/federated_coll.cuh
@@ -1,5 +1,5 @@
/**
- * Copyright 2023, XGBoost contributors
+ * Copyright 2023-2024, XGBoost contributors
*/
#include "../../src/collective/comm.h" // for Comm, Coll
#include "federated_coll.h" // for FederatedColl
@@ -16,8 +16,7 @@ class CUDAFederatedColl : public Coll {
ArrayInterfaceHandler::Type type, Op op) override;
  [[nodiscard]] Result Broadcast(Comm const &comm, common::Span<std::int8_t> data,
                                 std::int32_t root) override;
-  [[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data,
-                                 std::int64_t size) override;
+  [[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data) override;
  [[nodiscard]] Result AllgatherV(Comm const &comm, common::Span<std::int8_t const> data,
                                  common::Span<std::int64_t const> sizes,
                                  common::Span<std::int64_t> recv_segments,
diff --git a/plugin/federated/federated_coll.h b/plugin/federated/federated_coll.h
index c261b01e1..12443a3e1 100644
--- a/plugin/federated/federated_coll.h
+++ b/plugin/federated/federated_coll.h
@@ -1,12 +1,9 @@
/**
- * Copyright 2023, XGBoost contributors
+ * Copyright 2023-2024, XGBoost contributors
*/
#pragma once
#include "../../src/collective/coll.h" // for Coll
#include "../../src/collective/comm.h" // for Comm
-#include "../../src/common/io.h" // for ReadAll
-#include "../../src/common/json_utils.h" // for OptionalArg
-#include "xgboost/json.h" // for Json
namespace xgboost::collective {
class FederatedColl : public Coll {
@@ -20,8 +17,7 @@ class FederatedColl : public Coll {
ArrayInterfaceHandler::Type type, Op op) override;
  [[nodiscard]] Result Broadcast(Comm const &comm, common::Span<std::int8_t> data,
                                 std::int32_t root) override;
-  [[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data,
-                                 std::int64_t) override;
+  [[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data) override;
  [[nodiscard]] Result AllgatherV(Comm const &comm, common::Span<std::int8_t const> data,
                                  common::Span<std::int64_t const> sizes,
                                  common::Span<std::int64_t> recv_segments,
diff --git a/plugin/federated/federated_comm.cuh b/plugin/federated/federated_comm.cuh
index 58c52f67e..85cecb3eb 100644
--- a/plugin/federated/federated_comm.cuh
+++ b/plugin/federated/federated_comm.cuh
@@ -1,5 +1,5 @@
/**
- * Copyright 2023, XGBoost Contributors
+ * Copyright 2023-2024, XGBoost Contributors
*/
#pragma once
@@ -9,7 +9,6 @@
#include "../../src/common/device_helpers.cuh" // for CUDAStreamView
#include "federated_comm.h" // for FederatedComm
#include "xgboost/context.h" // for Context
-#include "xgboost/logging.h"
namespace xgboost::collective {
class CUDAFederatedComm : public FederatedComm {
diff --git a/plugin/federated/federated_comm.h b/plugin/federated/federated_comm.h
index 750d94abd..b39e1878a 100644
--- a/plugin/federated/federated_comm.h
+++ b/plugin/federated/federated_comm.h
@@ -1,5 +1,5 @@
/**
- * Copyright 2023, XGBoost contributors
+ * Copyright 2023-2024, XGBoost contributors
*/
#pragma once
@@ -11,7 +11,6 @@
#include <string>  // for string
#include "../../src/collective/comm.h" // for HostComm
-#include "../../src/common/json_utils.h" // for OptionalArg
#include "xgboost/json.h"
namespace xgboost::collective {
@@ -51,6 +50,10 @@ class FederatedComm : public HostComm {
std::int32_t rank) {
this->Init(host, port, world, rank, {}, {}, {});
}
+ [[nodiscard]] Result Shutdown() final {
+ this->ResetState();
+ return Success();
+ }
~FederatedComm() override { stub_.reset(); }
  [[nodiscard]] std::shared_ptr<Channel> Chan(std::int32_t) const override {
@@ -65,5 +68,13 @@ class FederatedComm : public HostComm {
[[nodiscard]] federated::Federated::Stub* Handle() const { return stub_.get(); }
  [[nodiscard]] Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const override;
+ /**
+ * @brief Get a string ID for the current process.
+ */
+ [[nodiscard]] Result ProcessorName(std::string* out) const final {
+ auto rank = this->Rank();
+ *out = "rank:" + std::to_string(rank);
+ return Success();
+ };
};
} // namespace xgboost::collective
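
Usage sketch for the new `ProcessorName` override: federated workers report a synthetic `rank:N` identifier rather than a hostname. Assuming `comm` is a constructed `FederatedComm` and using the `SafeColl` helper seen elsewhere in this diff:

```cpp
std::string name;
SafeColl(comm.ProcessorName(&name));
// name == "rank:3" on the worker with rank 3
```
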
diff --git a/plugin/federated/federated_server.h b/plugin/federated/federated_server.h
index de760d9d8..4692ad6c2 100644
--- a/plugin/federated/federated_server.h
+++ b/plugin/federated/federated_server.h
@@ -1,22 +1,18 @@
/**
- * Copyright 2022-2023, XGBoost contributors
+ * Copyright 2022-2024, XGBoost contributors
*/
#pragma once
#include <federated.grpc.pb.h>
#include <cstdint>  // for int32_t
-#include <future>   // for future
#include "../../src/collective/in_memory_handler.h"
-#include "../../src/collective/tracker.h" // for Tracker
-#include "xgboost/collective/result.h" // for Result
namespace xgboost::federated {
class FederatedService final : public Federated::Service {
public:
-  explicit FederatedService(std::int32_t world_size)
-      : handler_{static_cast<std::size_t>(world_size)} {}
+  explicit FederatedService(std::int32_t world_size) : handler_{world_size} {}
grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
AllgatherReply* reply) override;
diff --git a/plugin/federated/federated_tracker.cc b/plugin/federated/federated_tracker.cc
index 37b6c3639..5051d43cb 100644
--- a/plugin/federated/federated_tracker.cc
+++ b/plugin/federated/federated_tracker.cc
@@ -125,14 +125,14 @@ Result FederatedTracker::Shutdown() {
[[nodiscard]] Json FederatedTracker::WorkerArgs() const {
auto rc = this->WaitUntilReady();
- CHECK(rc.OK()) << rc.Report();
+ SafeColl(rc);
std::string host;
rc = GetHostAddress(&host);
CHECK(rc.OK());
Json args{Object{}};
- args["DMLC_TRACKER_URI"] = String{host};
- args["DMLC_TRACKER_PORT"] = this->Port();
+ args["dmlc_tracker_uri"] = String{host};
+ args["dmlc_tracker_port"] = this->Port();
return args;
}
} // namespace xgboost::collective
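
Consumers of `WorkerArgs` now read lowercase keys. A sketch of the lookup side, assuming XGBoost's `Json` accessors (`get<String const>` / `get<Integer const>`) and a `tracker` instance:

```cpp
Json args = tracker.WorkerArgs();
auto const &uri = get<String const>(args["dmlc_tracker_uri"]);
auto port = get<Integer const>(args["dmlc_tracker_port"]);
```
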
diff --git a/plugin/federated/federated_tracker.h b/plugin/federated/federated_tracker.h
index 33592fefe..ac46b6eaa 100644
--- a/plugin/federated/federated_tracker.h
+++ b/plugin/federated/federated_tracker.h
@@ -17,8 +17,7 @@ namespace xgboost::collective {
namespace federated {
class FederatedService final : public Federated::Service {
public:
-  explicit FederatedService(std::int32_t world_size)
-      : handler_{static_cast<std::size_t>(world_size)} {}
+  explicit FederatedService(std::int32_t world_size) : handler_{world_size} {}
grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
AllgatherReply* reply) override;
diff --git a/plugin/sycl/common/hist_util.cc b/plugin/sycl/common/hist_util.cc
new file mode 100644
index 000000000..fd813a92c
--- /dev/null
+++ b/plugin/sycl/common/hist_util.cc
@@ -0,0 +1,334 @@
+/*!
+ * Copyright 2017-2023 by Contributors
+ * \file hist_util.cc
+ */
+#include <vector>
+#include <limits>
+#include <algorithm>
+
+#include "../data/gradient_index.h"
+#include "hist_util.h"
+
+#include <CL/sycl.hpp>
+
+namespace xgboost {
+namespace sycl {
+namespace common {
+
+/*!
+ * \brief Fill histogram with zeroes
+ */
+template<typename GradientSumT, MemoryType memory_type>
+void InitHist(::sycl::queue qu, GHistRow<GradientSumT, memory_type>* hist,
+              size_t size, ::sycl::event* event) {
+  *event = qu.fill(hist->Begin(),
+                   xgboost::detail::GradientPairInternal<GradientSumT>(), size, *event);
+}
+template void InitHist(::sycl::queue qu,
+                       GHistRow<float, MemoryType::on_device>* hist,
+                       size_t size, ::sycl::event* event);
+template void InitHist(::sycl::queue qu,
+                       GHistRow<double, MemoryType::on_device>* hist,
+                       size_t size, ::sycl::event* event);
+
+/*!
+ * \brief Compute Subtraction: dst = src1 - src2
+ */
+template<typename GradientSumT>
+::sycl::event SubtractionHist(::sycl::queue qu,
+                              GHistRow<GradientSumT, MemoryType::on_device>* dst,
+                              const GHistRow<GradientSumT, MemoryType::on_device>& src1,
+                              const GHistRow<GradientSumT, MemoryType::on_device>& src2,
+                              size_t size, ::sycl::event event_priv) {
+  GradientSumT* pdst = reinterpret_cast<GradientSumT*>(dst->Data());
+  const GradientSumT* psrc1 = reinterpret_cast<const GradientSumT*>(src1.DataConst());
+  const GradientSumT* psrc2 = reinterpret_cast<const GradientSumT*>(src2.DataConst());
+
+  auto event_final = qu.submit([&](::sycl::handler& cgh) {
+    cgh.depends_on(event_priv);
+    cgh.parallel_for<>(::sycl::range<1>(2 * size), [pdst, psrc1, psrc2](::sycl::item<1> pid) {
+      const size_t i = pid.get_id(0);
+      pdst[i] = psrc1[i] - psrc2[i];
+    });
+  });
+  return event_final;
+}
+template ::sycl::event SubtractionHist(::sycl::queue qu,
+                                       GHistRow<float, MemoryType::on_device>* dst,
+                                       const GHistRow<float, MemoryType::on_device>& src1,
+                                       const GHistRow<float, MemoryType::on_device>& src2,
+                                       size_t size, ::sycl::event event_priv);
+template ::sycl::event SubtractionHist(::sycl::queue qu,
+                                       GHistRow<double, MemoryType::on_device>* dst,
+                                       const GHistRow<double, MemoryType::on_device>& src1,
+                                       const GHistRow<double, MemoryType::on_device>& src2,
+                                       size_t size, ::sycl::event event_priv);
+
+// Kernel with buffer using
+template<typename FPType, bool isDense, typename BinIdxType>
+::sycl::event BuildHistKernel(::sycl::queue qu,
+                              const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
+                              const RowSetCollection::Elem& row_indices,
+                              const GHistIndexMatrix& gmat,
+                              GHistRow<FPType, MemoryType::on_device>* hist,
+                              GHistRow<FPType, MemoryType::on_device>* hist_buffer,
+                              ::sycl::event event_priv) {
+  const size_t size = row_indices.Size();
+  const size_t* rid = row_indices.begin;
+  const size_t n_columns = isDense ? gmat.nfeatures : gmat.row_stride;
+  const GradientPair::ValueT* pgh =
+      reinterpret_cast<const GradientPair::ValueT*>(gpair_device.DataConst());
+  const BinIdxType* gradient_index = gmat.index.data<BinIdxType>();
+  const uint32_t* offsets = gmat.index.Offset();
+  FPType* hist_data = reinterpret_cast<FPType*>(hist->Data());
+  const size_t nbins = gmat.nbins;
+
+ const size_t max_work_group_size =
+ qu.get_device().get_info<::sycl::info::device::max_work_group_size>();
+ const size_t work_group_size = n_columns < max_work_group_size ? n_columns : max_work_group_size;
+
+ const size_t max_nblocks = hist_buffer->Size() / (nbins * 2);
+ const size_t min_block_size = 128;
+ size_t nblocks = std::min(max_nblocks, size / min_block_size + !!(size % min_block_size));
+ const size_t block_size = size / nblocks + !!(size % nblocks);
+ FPType* hist_buffer_data = reinterpret_cast(hist_buffer->Data());
+
+ auto event_fill = qu.fill(hist_buffer_data, FPType(0), nblocks * nbins * 2, event_priv);
+ auto event_main = qu.submit([&](::sycl::handler& cgh) {
+ cgh.depends_on(event_fill);
+ cgh.parallel_for<>(::sycl::nd_range<2>(::sycl::range<2>(nblocks, work_group_size),
+ ::sycl::range<2>(1, work_group_size)),
+ [=](::sycl::nd_item<2> pid) {
+ size_t block = pid.get_global_id(0);
+ size_t feat = pid.get_global_id(1);
+
+ FPType* hist_local = hist_buffer_data + block * nbins * 2;
+ for (size_t idx = 0; idx < block_size; ++idx) {
+ size_t i = block * block_size + idx;
+ if (i < size) {
+ const size_t icol_start = n_columns * rid[i];
+ const size_t idx_gh = rid[i];
+
+ pid.barrier(::sycl::access::fence_space::local_space);
+ const BinIdxType* gr_index_local = gradient_index + icol_start;
+
+ for (size_t j = feat; j < n_columns; j += work_group_size) {
+              uint32_t idx_bin = static_cast<uint32_t>(gr_index_local[j]);
+ if constexpr (isDense) {
+ idx_bin += offsets[j];
+ }
+ if (idx_bin < nbins) {
+ hist_local[2 * idx_bin] += pgh[2 * idx_gh];
+ hist_local[2 * idx_bin+1] += pgh[2 * idx_gh+1];
+ }
+ }
+ }
+ }
+ });
+ });
+
+ auto event_save = qu.submit([&](::sycl::handler& cgh) {
+ cgh.depends_on(event_main);
+ cgh.parallel_for<>(::sycl::range<1>(nbins), [=](::sycl::item<1> pid) {
+ size_t idx_bin = pid.get_id(0);
+
+ FPType gsum = 0.0f;
+ FPType hsum = 0.0f;
+
+ for (size_t j = 0; j < nblocks; ++j) {
+ gsum += hist_buffer_data[j * nbins * 2 + 2 * idx_bin];
+ hsum += hist_buffer_data[j * nbins * 2 + 2 * idx_bin + 1];
+ }
+
+ hist_data[2 * idx_bin] = gsum;
+ hist_data[2 * idx_bin + 1] = hsum;
+ });
+ });
+ return event_save;
+}
+
+// Kernel with atomic using
+template<typename FPType, bool isDense, typename BinIdxType>
+::sycl::event BuildHistKernel(::sycl::queue qu,
+                              const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
+                              const RowSetCollection::Elem& row_indices,
+                              const GHistIndexMatrix& gmat,
+                              GHistRow<FPType, MemoryType::on_device>* hist,
+                              ::sycl::event event_priv) {
+  const size_t size = row_indices.Size();
+  const size_t* rid = row_indices.begin;
+  const size_t n_columns = isDense ? gmat.nfeatures : gmat.row_stride;
+  const GradientPair::ValueT* pgh =
+      reinterpret_cast<const GradientPair::ValueT*>(gpair_device.DataConst());
+  const BinIdxType* gradient_index = gmat.index.data<BinIdxType>();
+  const uint32_t* offsets = gmat.index.Offset();
+  FPType* hist_data = reinterpret_cast<FPType*>(hist->Data());
+  const size_t nbins = gmat.nbins;
+
+ const size_t max_work_group_size =
+ qu.get_device().get_info<::sycl::info::device::max_work_group_size>();
+ const size_t feat_local = n_columns < max_work_group_size ? n_columns : max_work_group_size;
+
+ auto event_fill = qu.fill(hist_data, FPType(0), nbins * 2, event_priv);
+ auto event_main = qu.submit([&](::sycl::handler& cgh) {
+ cgh.depends_on(event_fill);
+ cgh.parallel_for<>(::sycl::range<2>(size, feat_local),
+ [=](::sycl::item<2> pid) {
+ size_t i = pid.get_id(0);
+ size_t feat = pid.get_id(1);
+
+ const size_t icol_start = n_columns * rid[i];
+ const size_t idx_gh = rid[i];
+
+ const BinIdxType* gr_index_local = gradient_index + icol_start;
+
+ for (size_t j = feat; j < n_columns; j += feat_local) {
+            uint32_t idx_bin = static_cast<uint32_t>(gr_index_local[j]);
+ if constexpr (isDense) {
+ idx_bin += offsets[j];
+ }
+ if (idx_bin < nbins) {
+              AtomicRef<FPType> gsum(hist_data[2 * idx_bin]);
+              AtomicRef<FPType> hsum(hist_data[2 * idx_bin + 1]);
+ gsum.fetch_add(pgh[2 * idx_gh]);
+ hsum.fetch_add(pgh[2 * idx_gh + 1]);
+ }
+ }
+ });
+ });
+ return event_main;
+}
+
+template<typename FPType, typename BinIdxType>
+::sycl::event BuildHistDispatchKernel(
+                ::sycl::queue qu,
+                const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
+                const RowSetCollection::Elem& row_indices,
+                const GHistIndexMatrix& gmat,
+                GHistRow<FPType, MemoryType::on_device>* hist,
+                bool isDense,
+                GHistRow<FPType, MemoryType::on_device>* hist_buffer,
+                ::sycl::event events_priv,
+                bool force_atomic_use) {
+ const size_t size = row_indices.Size();
+ const size_t n_columns = isDense ? gmat.nfeatures : gmat.row_stride;
+ const size_t nbins = gmat.nbins;
+
+ // max cycle size, while atomics are still effective
+ const size_t max_cycle_size_atomics = nbins;
+ const size_t cycle_size = size;
+
+  // TODO(razdoburdin): replace the ad-hoc dispatching criteria with a more suitable one
+ bool use_atomic = (size < nbins) || (gmat.max_num_bins == gmat.nbins / n_columns);
+
+ // force_atomic_use flag is used only for testing
+ use_atomic = use_atomic || force_atomic_use;
+  if (!use_atomic) {
+    if (isDense) {
+      return BuildHistKernel<FPType, true, BinIdxType>(qu, gpair_device, row_indices,
+                                                       gmat, hist, hist_buffer,
+                                                       events_priv);
+    } else {
+      return BuildHistKernel<FPType, false, BinIdxType>(qu, gpair_device, row_indices,
+                                                        gmat, hist, hist_buffer,
+                                                        events_priv);
+    }
+  } else {
+    if (isDense) {
+      return BuildHistKernel<FPType, true, BinIdxType>(qu, gpair_device, row_indices,
+                                                       gmat, hist, events_priv);
+    } else {
+      return BuildHistKernel<FPType, false, BinIdxType>(qu, gpair_device, row_indices,
+                                                        gmat, hist, events_priv);
+    }
+  }
+}
+
+template<typename FPType>
+::sycl::event BuildHistKernel(::sycl::queue qu,
+                              const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
+                              const RowSetCollection::Elem& row_indices,
+                              const GHistIndexMatrix& gmat, const bool isDense,
+                              GHistRow<FPType, MemoryType::on_device>* hist,
+                              GHistRow<FPType, MemoryType::on_device>* hist_buffer,
+                              ::sycl::event event_priv,
+                              bool force_atomic_use) {
+ const bool is_dense = isDense;
+ switch (gmat.index.GetBinTypeSize()) {
+ case BinTypeSize::kUint8BinsTypeSize:
+      return BuildHistDispatchKernel<FPType, uint8_t>(qu, gpair_device, row_indices,
+                                                      gmat, hist, is_dense, hist_buffer,
+                                                      event_priv, force_atomic_use);
+      break;
+    case BinTypeSize::kUint16BinsTypeSize:
+      return BuildHistDispatchKernel<FPType, uint16_t>(qu, gpair_device, row_indices,
+                                                       gmat, hist, is_dense, hist_buffer,
+                                                       event_priv, force_atomic_use);
+      break;
+    case BinTypeSize::kUint32BinsTypeSize:
+      return BuildHistDispatchKernel<FPType, uint32_t>(qu, gpair_device, row_indices,
+                                                       gmat, hist, is_dense, hist_buffer,
+                                                       event_priv, force_atomic_use);
+ break;
+ default:
+ CHECK(false); // no default behavior
+ }
+}
+
+template<typename GradientSumT>
+::sycl::event GHistBuilder<GradientSumT>::BuildHist(
+                const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
+                const RowSetCollection::Elem& row_indices,
+                const GHistIndexMatrix &gmat,
+                GHistRowT<MemoryType::on_device>* hist,
+                bool isDense,
+                GHistRowT<MemoryType::on_device>* hist_buffer,
+                ::sycl::event event_priv,
+                bool force_atomic_use) {
+  return BuildHistKernel<GradientSumT>(qu_, gpair_device, row_indices, gmat,
+                                       isDense, hist, hist_buffer, event_priv,
+                                       force_atomic_use);
+}
+
+template
+::sycl::event GHistBuilder<float>::BuildHist(
+                const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
+                const RowSetCollection::Elem& row_indices,
+                const GHistIndexMatrix& gmat,
+                GHistRow<float, MemoryType::on_device>* hist,
+                bool isDense,
+                GHistRow<float, MemoryType::on_device>* hist_buffer,
+                ::sycl::event event_priv,
+                bool force_atomic_use);
+template
+::sycl::event GHistBuilder<double>::BuildHist(
+                const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
+                const RowSetCollection::Elem& row_indices,
+                const GHistIndexMatrix& gmat,
+                GHistRow<double, MemoryType::on_device>* hist,
+                bool isDense,
+                GHistRow<double, MemoryType::on_device>* hist_buffer,
+                ::sycl::event event_priv,
+                bool force_atomic_use);
+
+template<typename GradientSumT>
+void GHistBuilder<GradientSumT>::SubtractionTrick(GHistRowT<MemoryType::on_device>* self,
+                                                  const GHistRowT<MemoryType::on_device>& sibling,
+                                                  const GHistRowT<MemoryType::on_device>& parent) {
+  const size_t size = self->Size();
+  CHECK_EQ(sibling.Size(), size);
+  CHECK_EQ(parent.Size(), size);
+
+  SubtractionHist(qu_, self, parent, sibling, size, ::sycl::event());
+}
+template
+void GHistBuilder<float>::SubtractionTrick(GHistRow<float, MemoryType::on_device>* self,
+                                           const GHistRow<float, MemoryType::on_device>& sibling,
+                                           const GHistRow<float, MemoryType::on_device>& parent);
+template
+void GHistBuilder<double>::SubtractionTrick(GHistRow<double, MemoryType::on_device>* self,
+                                            const GHistRow<double, MemoryType::on_device>& sibling,
+                                            const GHistRow<double, MemoryType::on_device>& parent);
+} // namespace common
+} // namespace sycl
+} // namespace xgboost
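
The choice between the two `BuildHistKernel` overloads above comes down to one predicate. Reduced to a pure function for clarity (a sketch; the real criterion lives in `BuildHistDispatchKernel` and is marked as provisional by its own TODO):

```cpp
#include <cstddef>

bool UseAtomicKernel(std::size_t n_rows, std::size_t nbins,
                     std::size_t max_num_bins, std::size_t n_columns,
                     bool force_atomic) {
  // Few rows, or each feature already saturates its bin budget: contention
  // is low enough that atomics beat the buffered per-block reduction.
  bool use_atomic = (n_rows < nbins) || (max_num_bins == nbins / n_columns);
  return use_atomic || force_atomic;
}
```
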
diff --git a/plugin/sycl/common/hist_util.h b/plugin/sycl/common/hist_util.h
new file mode 100644
index 000000000..7c7af71ae
--- /dev/null
+++ b/plugin/sycl/common/hist_util.h
@@ -0,0 +1,89 @@
+/*!
+ * Copyright 2017-2023 by Contributors
+ * \file hist_util.h
+ */
+#ifndef PLUGIN_SYCL_COMMON_HIST_UTIL_H_
+#define PLUGIN_SYCL_COMMON_HIST_UTIL_H_
+
+#include <vector>
+#include <unordered_map>
+#include <memory>
+
+#include "../data.h"
+#include "row_set.h"
+
+#include "../../src/common/hist_util.h"
+#include "../data/gradient_index.h"
+
+#include <CL/sycl.hpp>
+
+namespace xgboost {
+namespace sycl {
+namespace common {
+
+template<typename GradientSumT, MemoryType memory_type = MemoryType::shared>
+using GHistRow = USMVector<xgboost::detail::GradientPairInternal<GradientSumT>, memory_type>;
+
+using BinTypeSize = ::xgboost::common::BinTypeSize;
+
+class ColumnMatrix;
+
+/*!
+ * \brief Fill histogram with zeroes
+ */
+template<typename GradientSumT, MemoryType memory_type>
+void InitHist(::sycl::queue qu,
+              GHistRow<GradientSumT, memory_type>* hist,
+              size_t size, ::sycl::event* event);
+
+/*!
+ * \brief Compute subtraction: dst = src1 - src2
+ */
+template<typename GradientSumT>
+::sycl::event SubtractionHist(::sycl::queue qu,
+                              GHistRow<GradientSumT, MemoryType::on_device>* dst,
+                              const GHistRow<GradientSumT, MemoryType::on_device>& src1,
+                              const GHistRow<GradientSumT, MemoryType::on_device>& src2,
+                              size_t size, ::sycl::event event_priv);
+
+/*!
+ * \brief Builder for histograms of gradient statistics
+ */
+template<typename GradientSumT>
+class GHistBuilder {
+ public:
+  template<MemoryType memory_type>
+  using GHistRowT = GHistRow<GradientSumT, memory_type>;
+
+  GHistBuilder() = default;
+  GHistBuilder(::sycl::queue qu, uint32_t nbins) : qu_{qu}, nbins_{nbins} {}
+
+  // Construct a histogram via histogram aggregation
+  ::sycl::event BuildHist(const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
+                          const RowSetCollection::Elem& row_indices,
+                          const GHistIndexMatrix& gmat,
+                          GHistRowT<MemoryType::on_device>* HistCollection,
+                          bool isDense,
+                          GHistRowT<MemoryType::on_device>* hist_buffer,
+                          ::sycl::event event,
+                          bool force_atomic_use = false);
+
+  // Construct a histogram via subtraction trick
+  void SubtractionTrick(GHistRowT<MemoryType::on_device>* self,
+                        const GHistRowT<MemoryType::on_device>& sibling,
+                        const GHistRowT<MemoryType::on_device>& parent);
+
+ uint32_t GetNumBins() const {
+ return nbins_;
+ }
+
+ private:
+ /*! \brief Number of all bins over all features */
+ uint32_t nbins_ { 0 };
+
+ ::sycl::queue qu_;
+};
+} // namespace common
+} // namespace sycl
+} // namespace xgboost
+#endif // PLUGIN_SYCL_COMMON_HIST_UTIL_H_
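
Why `SubtractionTrick` works: a parent node's histogram is the elementwise sum of its two children's histograms, so one child can be recovered as `parent - sibling` without revisiting the raw data. A scalar sketch over plain arrays (the plugin performs the same update on-device over gradient/hessian pairs):

```cpp
#include <cstddef>

void SubtractionTrickSketch(double *self, const double *sibling,
                            const double *parent, std::size_t size) {
  for (std::size_t i = 0; i < size; ++i) {
    self[i] = parent[i] - sibling[i];  // recover the other child's bins
  }
}
```
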
diff --git a/plugin/sycl/tree/updater_quantile_hist.cc b/plugin/sycl/tree/updater_quantile_hist.cc
new file mode 100644
index 000000000..98a42c3c8
--- /dev/null
+++ b/plugin/sycl/tree/updater_quantile_hist.cc
@@ -0,0 +1,55 @@
+/*!
+ * Copyright 2017-2024 by Contributors
+ * \file updater_quantile_hist.cc
+ */
+#include <vector>
+
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
+#pragma GCC diagnostic ignored "-W#pragma-messages"
+#include "xgboost/tree_updater.h"
+#pragma GCC diagnostic pop
+
+#include "xgboost/logging.h"
+
+#include "updater_quantile_hist.h"
+#include "../data.h"
+
+namespace xgboost {
+namespace sycl {
+namespace tree {
+
+DMLC_REGISTRY_FILE_TAG(updater_quantile_hist_sycl);
+
+DMLC_REGISTER_PARAMETER(HistMakerTrainParam);
+
+void QuantileHistMaker::Configure(const Args& args) {
+ const DeviceOrd device_spec = ctx_->Device();
+ qu_ = device_manager.GetQueue(device_spec);
+
+ param_.UpdateAllowUnknown(args);
+ hist_maker_param_.UpdateAllowUnknown(args);
+}
+
+void QuantileHistMaker::Update(xgboost::tree::TrainParam const *param,
+                               linalg::Matrix<GradientPair>* gpair,
+                               DMatrix *dmat,
+                               xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
+                               const std::vector<RegTree*> &trees) {
+  LOG(FATAL) << "Not Implemented yet";
+}
+
+bool QuantileHistMaker::UpdatePredictionCache(const DMatrix* data,
+                                              linalg::MatrixView<float> out_preds) {
+  LOG(FATAL) << "Not Implemented yet";
+  return false;
+}
+
+XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker_sycl")
+.describe("Grow tree using quantized histogram with SYCL.")
+.set_body(
+ [](Context const* ctx, ObjInfo const * task) {
+ return new QuantileHistMaker(ctx, task);
+ });
+} // namespace tree
+} // namespace sycl
+} // namespace xgboost
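
The registration macro at the end of the file is what makes the updater reachable by name. A usage sketch, assuming the `TreeUpdater::Create` registry factory used by the core library with this signature:

```cpp
std::unique_ptr<TreeUpdater> up{
    TreeUpdater::Create("grow_quantile_histmaker_sycl", &ctx, &task)};
up->Configure(Args{});
```
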
diff --git a/plugin/sycl/tree/updater_quantile_hist.h b/plugin/sycl/tree/updater_quantile_hist.h
new file mode 100644
index 000000000..93a50de3e
--- /dev/null
+++ b/plugin/sycl/tree/updater_quantile_hist.h
@@ -0,0 +1,91 @@
+/*!
+ * Copyright 2017-2024 by Contributors
+ * \file updater_quantile_hist.h
+ */
+#ifndef PLUGIN_SYCL_TREE_UPDATER_QUANTILE_HIST_H_
+#define PLUGIN_SYCL_TREE_UPDATER_QUANTILE_HIST_H_
+
+#include <dmlc/parameter.h>
+#include <xgboost/tree_updater.h>
+
+#include <vector>
+
+#include "../data/gradient_index.h"
+#include "../common/hist_util.h"
+#include "../common/row_set.h"
+#include "../common/partition_builder.h"
+#include "split_evaluator.h"
+#include "../device_manager.h"
+
+#include "xgboost/data.h"
+#include "xgboost/json.h"
+#include "../../src/tree/constraints.h"
+#include "../../src/common/random.h"
+
+namespace xgboost {
+namespace sycl {
+namespace tree {
+
+// training parameters specific to this algorithm
+struct HistMakerTrainParam
+    : public XGBoostParameter<HistMakerTrainParam> {
+ bool single_precision_histogram = false;
+ // declare parameters
+ DMLC_DECLARE_PARAMETER(HistMakerTrainParam) {
+ DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe(
+ "Use single precision to build histograms.");
+ }
+};
+
+/*! \brief construct a tree using quantized feature values with SYCL backend*/
+class QuantileHistMaker: public TreeUpdater {
+ public:
+ QuantileHistMaker(Context const* ctx, ObjInfo const * task) :
+ TreeUpdater(ctx), task_{task} {
+ updater_monitor_.Init("SYCLQuantileHistMaker");
+ }
+ void Configure(const Args& args) override;
+
+  void Update(xgboost::tree::TrainParam const *param,
+              linalg::Matrix<GradientPair>* gpair,
+              DMatrix* dmat,
+              xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
+              const std::vector<RegTree*>& trees) override;
+
+  bool UpdatePredictionCache(const DMatrix* data,
+                             linalg::MatrixView<float> out_preds) override;
+
+ void LoadConfig(Json const& in) override {
+ auto const& config = get