merge latest change from upstream

This commit is contained in:
Hui Liu 2024-04-22 09:35:31 -07:00
commit 8b75204fed
146 changed files with 3111 additions and 1027 deletions

View File

@@ -8,7 +8,7 @@ updates:
   - package-ecosystem: "maven"
     directory: "/jvm-packages"
     schedule:
-      interval: "daily"
+      interval: "monthly"
   - package-ecosystem: "maven"
     directory: "/jvm-packages/xgboost4j"
     schedule:
@@ -16,11 +16,11 @@ updates:
   - package-ecosystem: "maven"
     directory: "/jvm-packages/xgboost4j-gpu"
     schedule:
-      interval: "daily"
+      interval: "monthly"
   - package-ecosystem: "maven"
     directory: "/jvm-packages/xgboost4j-example"
     schedule:
-      interval: "daily"
+      interval: "monthly"
   - package-ecosystem: "maven"
     directory: "/jvm-packages/xgboost4j-spark"
     schedule:
@@ -28,4 +28,4 @@ updates:
   - package-ecosystem: "maven"
     directory: "/jvm-packages/xgboost4j-spark-gpu"
     schedule:
-      interval: "daily"
+      interval: "monthly"

View File

@@ -110,7 +110,7 @@ jobs:
     name: Test R package on Debian
     runs-on: ubuntu-latest
     container:
-      image: rhub/debian-gcc-devel
+      image: rhub/debian-gcc-release
     steps:
       - name: Install system dependencies
@@ -130,12 +130,12 @@ jobs:
       - name: Install dependencies
         shell: bash -l {0}
         run: |
-          /tmp/R-devel/bin/Rscript -e "source('./R-package/tests/helper_scripts/install_deps.R')"
+          Rscript -e "source('./R-package/tests/helper_scripts/install_deps.R')"
       - name: Test R
         shell: bash -l {0}
         run: |
-          python3 tests/ci_build/test_r_package.py --r=/tmp/R-devel/bin/R --build-tool=autotools --task=check
+          python3 tests/ci_build/test_r_package.py --r=/usr/bin/R --build-tool=autotools --task=check
       - uses: dorny/paths-filter@v2
         id: changes
@@ -147,4 +147,4 @@ jobs:
       - name: Run document check
         if: steps.changes.outputs.r_package == 'true'
         run: |
-          python3 tests/ci_build/test_r_package.py --r=/tmp/R-devel/bin/R --task=doc
+          python3 tests/ci_build/test_r_package.py --r=/usr/bin/R --task=doc

View File

@@ -3,7 +3,7 @@ name: update-rapids
 on:
   workflow_dispatch:
   schedule:
-    - cron: "0 20 * * *"  # Run once daily
+    - cron: "0 20 * * 1"  # Run once weekly
 permissions:
   pull-requests: write
@@ -32,7 +32,7 @@ jobs:
       run: |
         bash tests/buildkite/update-rapids.sh
     - name: Create Pull Request
-      uses: peter-evans/create-pull-request@v5
+      uses: peter-evans/create-pull-request@v6
       if: github.ref == 'refs/heads/master'
       with:
         add-paths: |

View File

@@ -2101,7 +2101,7 @@ This release marks a major milestone for the XGBoost project.
 ## v0.90 (2019.05.18)
 ### XGBoost Python package drops Python 2.x (#4379, #4381)
-Python 2.x is reaching its end-of-life at the end of this year. [Many scientific Python packages are now moving to drop Python 2.x](https://python3statement.org/).
+Python 2.x is reaching its end-of-life at the end of this year. [Many scientific Python packages are now moving to drop Python 2.x](https://python3statement.github.io/).
 ### XGBoost4J-Spark now requires Spark 2.4.x (#4377)
 * Spark 2.3 is reaching its end-of-life soon. See discussion at #4389.

View File

@@ -26,6 +26,11 @@ NVL <- function(x, val) {
                   'multi:softprob', 'rank:pairwise', 'rank:ndcg', 'rank:map'))
 }

+.RANKING_OBJECTIVES <- function() {
+  return(c('rank:pairwise', 'rank:ndcg', 'rank:map'))
+}
+
 #
 # Low-level functions for boosting --------------------------------------------
@@ -235,33 +240,43 @@ convert.labels <- function(labels, objective_name) {
 }

 # Generates random (stratified if needed) CV folds
-generate.cv.folds <- function(nfold, nrows, stratified, label, params) {
+generate.cv.folds <- function(nfold, nrows, stratified, label, group, params) {
+  if (NROW(group)) {
+    if (stratified) {
+      warning(
+        paste0(
+          "Stratified splitting is not supported when using 'group' attribute.",
+          " Will use unstratified splitting."
+        )
+      )
+    }
+    return(generate.group.folds(nfold, group))
+  }
+  objective <- params$objective
+  if (!is.character(objective)) {
+    warning("Will use unstratified splitting (custom objective used)")
+    stratified <- FALSE
+  }
+  # cannot stratify if label is NULL
+  if (stratified && is.null(label)) {
+    warning("Will use unstratified splitting (no 'labels' available)")
+    stratified <- FALSE
+  }
   # cannot do it for rank
-  objective <- params$objective
   if (is.character(objective) && strtrim(objective, 5) == 'rank:') {
-    stop("\n\tAutomatic generation of CV-folds is not implemented for ranking!\n",
+    stop("\n\tAutomatic generation of CV-folds is not implemented for ranking without 'group' field!\n",
          "\tConsider providing pre-computed CV-folds through the 'folds=' parameter.\n")
   }
   # shuffle
   rnd_idx <- sample.int(nrows)
-  if (stratified &&
-      length(label) == length(rnd_idx)) {
+  if (stratified && length(label) == length(rnd_idx)) {
     y <- label[rnd_idx]
-    # WARNING: some heuristic logic is employed to identify classification setting!
     # - For classification, need to convert y labels to factor before making the folds,
     #   and then do stratification by factor levels.
     # - For regression, leave y numeric and do stratification by quantiles.
     if (is.character(objective)) {
-      y <- convert.labels(y, params$objective)
-    } else {
-      # If no 'objective' given in params, it means that user either wants to
-      # use the default 'reg:squarederror' objective or has provided a custom
-      # obj function.  Here, assume classification setting when y has 5 or less
-      # unique values:
-      if (length(unique(y)) <= 5) {
-        y <- factor(y)
-      }
+      y <- convert.labels(y, objective)
     }
     folds <- xgb.createFolds(y = y, k = nfold)
   } else {
@@ -277,6 +292,29 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, params) {
   return(folds)
 }

+generate.group.folds <- function(nfold, group) {
+  ngroups <- length(group) - 1
+  if (ngroups < nfold) {
+    stop("DMatrix has fewer groups than folds.")
+  }
+  seq_groups <- seq_len(ngroups)
+  indices <- lapply(seq_groups, function(gr) seq(group[gr] + 1, group[gr + 1]))
+  assignments <- base::split(seq_groups, as.integer(seq_groups %% nfold))
+  assignments <- unname(assignments)
+
+  out <- vector("list", nfold)
+  randomized_groups <- sample(ngroups)
+  for (idx in seq_len(nfold)) {
+    groups_idx_test <- randomized_groups[assignments[[idx]]]
+    groups_test <- indices[groups_idx_test]
+    idx_test <- unlist(groups_test)
+    attributes(idx_test)$group_test <- lengths(groups_test)
+    attributes(idx_test)$group_train <- lengths(indices[-groups_idx_test])
+    out[[idx]] <- idx_test
+  }
+  return(out)
+}
+
 # Creates CV folds stratified by the values of y.
 # It was borrowed from caret::createFolds and simplified
 # by always returning an unnamed list of fold indices.
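
For context (not part of the diff): the new `generate.group.folds` helper above assigns whole query groups to folds and records the per-fold group sizes as attributes. A minimal usage sketch of the resulting user-facing behavior, with invented toy data, assuming the updated `xgb.cv` from this commit:

    library(xgboost)
    # Toy ranking data: 6 query groups of 5 rows each (values are made up).
    set.seed(1)
    x <- matrix(rnorm(30 * 4), nrow = 30)
    y <- sample(0:3, 30, replace = TRUE)
    dm <- xgb.DMatrix(x, label = y, group = rep(5, 6))
    # Splitting happens by whole groups: every row of a query group lands in
    # the same fold, so fold sizes may differ.
    res <- xgb.cv(
      data = dm,
      params = list(objective = "rank:ndcg", max_depth = 2),
      nrounds = 2, nfold = 3, verbose = FALSE
    )
    length(res$folds)  # 3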

View File

@@ -1259,8 +1259,11 @@ xgb.get.DMatrix.data <- function(dmat) {
 #' Get a new DMatrix containing the specified rows of
 #' original xgb.DMatrix object
 #'
-#' @param object Object of class "xgb.DMatrix"
-#' @param idxset a integer vector of indices of rows needed
+#' @param object Object of class "xgb.DMatrix".
+#' @param idxset An integer vector of indices of rows needed (base-1 indexing).
+#' @param allow_groups Whether to allow slicing an `xgb.DMatrix` with `group` (or
+#'   equivalently `qid`) field. Note that in such case, the result will not have
+#'   the groups anymore - they need to be set manually through `setinfo`.
 #' @param colset currently not used (columns subsetting is not available)
 #'
 #' @examples
@@ -1275,11 +1278,11 @@ xgb.get.DMatrix.data <- function(dmat) {
 #'
 #' @rdname xgb.slice.DMatrix
 #' @export
-xgb.slice.DMatrix <- function(object, idxset) {
+xgb.slice.DMatrix <- function(object, idxset, allow_groups = FALSE) {
   if (!inherits(object, "xgb.DMatrix")) {
     stop("object must be xgb.DMatrix")
   }
-  ret <- .Call(XGDMatrixSliceDMatrix_R, object, idxset)
+  ret <- .Call(XGDMatrixSliceDMatrix_R, object, idxset, allow_groups)

   attr_list <- attributes(object)
   nr <- nrow(object)
@@ -1296,7 +1299,15 @@ xgb.slice.DMatrix <- function(object, idxset) {
       }
     }
   }
-  return(structure(ret, class = "xgb.DMatrix"))
+
+  out <- structure(ret, class = "xgb.DMatrix")
+  parent_fields <- as.list(attributes(object)$fields)
+  if (NROW(parent_fields)) {
+    child_fields <- parent_fields[!(names(parent_fields) %in% c("group", "qid"))]
+    child_fields <- as.environment(child_fields)
+    attributes(out)$fields <- child_fields
+  }
+  return(out)
 }

 #' @rdname xgb.slice.DMatrix
@@ -1340,11 +1351,11 @@ print.xgb.DMatrix <- function(x, verbose = FALSE, ...) {
   }
   cat(class_print, ' dim:', nrow(x), 'x', ncol(x), ' info: ')

-  infos <- character(0)
-  if (xgb.DMatrix.hasinfo(x, 'label')) infos <- 'label'
-  if (xgb.DMatrix.hasinfo(x, 'weight')) infos <- c(infos, 'weight')
-  if (xgb.DMatrix.hasinfo(x, 'base_margin')) infos <- c(infos, 'base_margin')
-  if (length(infos) == 0) infos <- 'NA'
+  infos <- names(attributes(x)$fields)
+  infos <- infos[infos != "feature_name"]
+  if (!NROW(infos)) infos <- "NA"
+  infos <- infos[order(infos)]
+  infos <- paste(infos, collapse = ", ")
   cat(infos)
   cnames <- colnames(x)
   cat(' colnames:')
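
For context (not part of the diff): a small sketch, with invented data, of the new `allow_groups` flag described above. Slicing drops the `group` field, so it has to be re-attached manually on the result:

    library(xgboost)
    # A ranking DMatrix with 3 query groups of 4 rows each.
    x <- matrix(rnorm(12 * 3), nrow = 12)
    dm <- xgb.DMatrix(x, label = rep(0:1, 6), group = c(4, 4, 4))
    # Slicing a grouped DMatrix must be allowed explicitly...
    sliced <- xgb.slice.DMatrix(dm, 1:8, allow_groups = TRUE)
    # ...and the result carries no groups; set them again through setinfo.
    setinfo(sliced, "group", c(4, 4))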

View File

@@ -1,6 +1,6 @@
 #' Cross Validation
 #'
-#' The cross validation function of xgboost
+#' The cross validation function of xgboost.
 #'
 #' @param params the list of parameters. The complete list of parameters is
 #'   available in the \href{http://xgboost.readthedocs.io/en/latest/parameter.html}{online documentation}. Below
@@ -19,13 +19,17 @@
 #'
 #'   See \code{\link{xgb.train}} for further details.
 #'   See also demo/ for walkthrough example in R.
-#' @param data takes an \code{xgb.DMatrix}, \code{matrix}, or \code{dgCMatrix} as the input.
+#'
+#'   Note that, while `params` accepts a `seed` entry and will use such parameter for model training if
+#'   supplied, this seed is not used for creation of train-test splits, which instead rely on R's own RNG
+#'   system - thus, for reproducible results, one needs to call the `set.seed` function beforehand.
+#' @param data An `xgb.DMatrix` object, with corresponding fields like `label` or bounds as required
+#'   for model training by the objective.
+#'
+#'   Note that only the basic `xgb.DMatrix` class is supported - variants such as `xgb.QuantileDMatrix`
+#'   or `xgb.ExternalDMatrix` are not supported here.
 #' @param nrounds the max number of iterations
 #' @param nfold the original dataset is randomly partitioned into \code{nfold} equal size subsamples.
-#' @param label vector of response values. Should be provided only when data is an R-matrix.
-#' @param missing is only used when input is a dense matrix. By default is set to NA, which means
-#'   that NA values should be considered as 'missing' by the algorithm.
-#'   Sometimes, 0 or other extreme value might be used to represent missing values.
 #' @param prediction A logical value indicating whether to return the test fold predictions
 #'   from each CV model. This parameter engages the \code{\link{xgb.cb.cv.predict}} callback.
 #' @param showsd \code{boolean}, whether to show standard deviation of cross validation
@@ -47,13 +51,30 @@
 #' @param feval customized evaluation function. Returns
 #'   \code{list(metric='metric-name', value='metric-value')} with given
 #'   prediction and dtrain.
-#' @param stratified a \code{boolean} indicating whether sampling of folds should be stratified
-#'   by the values of outcome labels.
+#' @param stratified A \code{boolean} indicating whether sampling of folds should be stratified
+#'   by the values of outcome labels. For real-valued labels in regression objectives,
+#'   stratification will be done by discretizing the labels into up to 5 buckets beforehand.
+#'
+#'   If passing "auto", will be set to `TRUE` if the objective in `params` is a classification
+#'   objective (from XGBoost's built-in objectives, doesn't apply to custom ones), and to
+#'   `FALSE` otherwise.
+#'
+#'   This parameter is ignored when `data` has a `group` field - in such case, the splitting
+#'   will be based on whole groups (note that this might make the folds have different sizes).
+#'
+#'   Value `TRUE` here is \bold{not} supported for custom objectives.
 #' @param folds \code{list} provides a possibility to use a list of pre-defined CV folds
 #'   (each element must be a vector of test fold's indices). When folds are supplied,
 #'   the \code{nfold} and \code{stratified} parameters are ignored.
+#'
+#'   If `data` has a `group` field and the objective requires this field, each fold (list element)
+#'   must additionally have two attributes (retrievable through \link{attributes}) named `group_test`
+#'   and `group_train`, which should hold the `group` to assign through \link{setinfo.xgb.DMatrix} to
+#'   the resulting DMatrices.
 #' @param train_folds \code{list} list specifying which indicies to use for training. If \code{NULL}
 #'   (the default) all indices not specified in \code{folds} will be used for training.
+#'
+#'   This is not supported when `data` has `group` field.
 #' @param verbose \code{boolean}, print the statistics during the process
 #' @param print_every_n Print each n-th iteration evaluation messages when \code{verbose>0}.
 #'   Default is 1 which means all messages are printed. This parameter is passed to the
@@ -118,13 +139,14 @@
 #' print(cv, verbose=TRUE)
 #'
 #' @export
-xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing = NA,
+xgb.cv <- function(params = list(), data, nrounds, nfold,
                    prediction = FALSE, showsd = TRUE, metrics = list(),
-                   obj = NULL, feval = NULL, stratified = TRUE, folds = NULL, train_folds = NULL,
+                   obj = NULL, feval = NULL, stratified = "auto", folds = NULL, train_folds = NULL,
                    verbose = TRUE, print_every_n = 1L,
                    early_stopping_rounds = NULL, maximize = NULL, callbacks = list(), ...) {
   check.deprecation(...)
+  stopifnot(inherits(data, "xgb.DMatrix"))
   if (inherits(data, "xgb.DMatrix") && .Call(XGCheckNullPtr_R, data)) {
     stop("'data' is an invalid 'xgb.DMatrix' object. Must be constructed again.")
   }
@@ -137,16 +159,22 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
   check.custom.obj()
   check.custom.eval()

-  # Check the labels
-  if ((inherits(data, 'xgb.DMatrix') && !xgb.DMatrix.hasinfo(data, 'label')) ||
-      (!inherits(data, 'xgb.DMatrix') && is.null(label))) {
-    stop("Labels must be provided for CV either through xgb.DMatrix, or through 'label=' when 'data' is matrix")
-  } else if (inherits(data, 'xgb.DMatrix')) {
-    if (!is.null(label))
-      warning("xgb.cv: label will be ignored, since data is of type xgb.DMatrix")
-    cv_label <- getinfo(data, 'label')
-  } else {
-    cv_label <- label
+  if (stratified == "auto") {
+    if (is.character(params$objective)) {
+      stratified <- (
+        (params$objective %in% .CLASSIFICATION_OBJECTIVES())
+        && !(params$objective %in% .RANKING_OBJECTIVES())
+      )
+    } else {
+      stratified <- FALSE
+    }
+  }
+
+  # Check the labels and groups
+  cv_label <- getinfo(data, "label")
+  cv_group <- getinfo(data, "group")
+  if (!is.null(train_folds) && NROW(cv_group)) {
+    stop("'train_folds' is not supported for DMatrix object with 'group' field.")
   }

   # CV folds
@@ -157,7 +185,7 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
   } else {
     if (nfold <= 1)
       stop("'nfold' must be > 1")
-    folds <- generate.cv.folds(nfold, nrow(data), stratified, cv_label, params)
+    folds <- generate.cv.folds(nfold, nrow(data), stratified, cv_label, cv_group, params)
   }

   # Callbacks
@@ -195,20 +223,18 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
   # create the booster-folds
   # train_folds
-  dall <- xgb.get.DMatrix(
-    data = data,
-    label = label,
-    missing = missing,
-    weight = NULL,
-    nthread = params$nthread
-  )
+  dall <- data
   bst_folds <- lapply(seq_along(folds), function(k) {
-    dtest <- xgb.slice.DMatrix(dall, folds[[k]])
+    dtest <- xgb.slice.DMatrix(dall, folds[[k]], allow_groups = TRUE)
     # code originally contributed by @RolandASc on stackoverflow
     if (is.null(train_folds))
-      dtrain <- xgb.slice.DMatrix(dall, unlist(folds[-k]))
+      dtrain <- xgb.slice.DMatrix(dall, unlist(folds[-k]), allow_groups = TRUE)
     else
-      dtrain <- xgb.slice.DMatrix(dall, train_folds[[k]])
+      dtrain <- xgb.slice.DMatrix(dall, train_folds[[k]], allow_groups = TRUE)
+    if (!is.null(attributes(folds[[k]])$group_test)) {
+      setinfo(dtest, "group", attributes(folds[[k]])$group_test)
+      setinfo(dtrain, "group", attributes(folds[[k]])$group_train)
+    }
     bst <- xgb.Booster(
       params = params,
       cachelist = list(dtrain, dtest),
@@ -312,8 +338,8 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
 #' @examples
 #' data(agaricus.train, package='xgboost')
 #' train <- agaricus.train
-#' cv <- xgb.cv(data = train$data, label = train$label, nfold = 5, max_depth = 2,
+#' cv <- xgb.cv(data = xgb.DMatrix(train$data, label = train$label), nfold = 5, max_depth = 2,
 #'              eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic")
 #' print(cv)
 #' print(cv, verbose=TRUE)
 #'
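
For context (not part of the diff): a minimal sketch of the updated calling convention. `data` must now be an `xgb.DMatrix`, `stratified` defaults to "auto", and fold creation uses R's RNG rather than `params$seed`, so reproducible splits need `set.seed()`:

    library(xgboost)
    data(agaricus.train, package = "xgboost")
    set.seed(42)  # seeds the train-test splits; params$seed does not
    cv <- xgb.cv(
      data = xgb.DMatrix(agaricus.train$data, label = agaricus.train$label),
      params = list(objective = "binary:logistic", max_depth = 2, eta = 1),
      nrounds = 2, nfold = 5, verbose = FALSE
    )
    # stratified = "auto" resolves to TRUE here, since binary:logistic is a
    # built-in classification objective.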

View File

@@ -23,8 +23,8 @@ including the best iteration (when available).
 \examples{
 data(agaricus.train, package='xgboost')
 train <- agaricus.train
-cv <- xgb.cv(data = train$data, label = train$label, nfold = 5, max_depth = 2,
+cv <- xgb.cv(data = xgb.DMatrix(train$data, label = train$label), nfold = 5, max_depth = 2,
              eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic")
 print(cv)
 print(cv, verbose=TRUE)

View File

@@ -9,14 +9,12 @@ xgb.cv(
   data,
   nrounds,
   nfold,
-  label = NULL,
-  missing = NA,
   prediction = FALSE,
   showsd = TRUE,
   metrics = list(),
   obj = NULL,
   feval = NULL,
-  stratified = TRUE,
+  stratified = "auto",
   folds = NULL,
   train_folds = NULL,
   verbose = TRUE,
@@ -44,20 +42,23 @@ is a shorter summary:
 }

 See \code{\link{xgb.train}} for further details.
-See also demo/ for walkthrough example in R.}
+See also demo/ for walkthrough example in R.
+
+Note that, while \code{params} accepts a \code{seed} entry and will use such parameter for model training if
+supplied, this seed is not used for creation of train-test splits, which instead rely on R's own RNG
+system - thus, for reproducible results, one needs to call the \code{set.seed} function beforehand.}

-\item{data}{takes an \code{xgb.DMatrix}, \code{matrix}, or \code{dgCMatrix} as the input.}
+\item{data}{An \code{xgb.DMatrix} object, with corresponding fields like \code{label} or bounds as required
+for model training by the objective.
+
+\if{html}{\out{<div class="sourceCode">}}\preformatted{Note that only the basic `xgb.DMatrix` class is supported - variants such as `xgb.QuantileDMatrix`
+or `xgb.ExternalDMatrix` are not supported here.
+}\if{html}{\out{</div>}}}

 \item{nrounds}{the max number of iterations}

 \item{nfold}{the original dataset is randomly partitioned into \code{nfold} equal size subsamples.}

-\item{label}{vector of response values. Should be provided only when data is an R-matrix.}
-
-\item{missing}{is only used when input is a dense matrix. By default is set to NA, which means
-that NA values should be considered as 'missing' by the algorithm.
-Sometimes, 0 or other extreme value might be used to represent missing values.}
-
 \item{prediction}{A logical value indicating whether to return the test fold predictions
 from each CV model. This parameter engages the \code{\link{xgb.cb.cv.predict}} callback.}
@@ -84,15 +85,35 @@ gradient with given prediction and dtrain.}
 \code{list(metric='metric-name', value='metric-value')} with given
 prediction and dtrain.}

-\item{stratified}{a \code{boolean} indicating whether sampling of folds should be stratified
-by the values of outcome labels.}
+\item{stratified}{A \code{boolean} indicating whether sampling of folds should be stratified
+by the values of outcome labels. For real-valued labels in regression objectives,
+stratification will be done by discretizing the labels into up to 5 buckets beforehand.
+
+\if{html}{\out{<div class="sourceCode">}}\preformatted{If passing "auto", will be set to `TRUE` if the objective in `params` is a classification
+objective (from XGBoost's built-in objectives, doesn't apply to custom ones), and to
+`FALSE` otherwise.
+
+This parameter is ignored when `data` has a `group` field - in such case, the splitting
+will be based on whole groups (note that this might make the folds have different sizes).
+
+Value `TRUE` here is \\bold\{not\} supported for custom objectives.
+}\if{html}{\out{</div>}}}

-\item{folds}{\code{list} provides a possibility to use a list of pre-defined CV folds
-(each element must be a vector of test fold's indices). When folds are supplied,
-the \code{nfold} and \code{stratified} parameters are ignored.}
+\item{folds}{\code{list} provides a possibility to use a list of pre-defined CV folds
+(each element must be a vector of test fold's indices). When folds are supplied,
+the \code{nfold} and \code{stratified} parameters are ignored.
+
+\if{html}{\out{<div class="sourceCode">}}\preformatted{If `data` has a `group` field and the objective requires this field, each fold (list element)
+must additionally have two attributes (retrievable through \link{attributes}) named `group_test`
+and `group_train`, which should hold the `group` to assign through \link{setinfo.xgb.DMatrix} to
+the resulting DMatrices.
+}\if{html}{\out{</div>}}}

-\item{train_folds}{\code{list} list specifying which indicies to use for training. If \code{NULL}
-(the default) all indices not specified in \code{folds} will be used for training.}
+\item{train_folds}{\code{list} list specifying which indicies to use for training. If \code{NULL}
+(the default) all indices not specified in \code{folds} will be used for training.
+
+\if{html}{\out{<div class="sourceCode">}}\preformatted{This is not supported when `data` has `group` field.
+}\if{html}{\out{</div>}}}

 \item{verbose}{\code{boolean}, print the statistics during the process}
@@ -142,7 +163,7 @@ such as saving also the models created during cross validation); or a list \code{...}
 will contain elements such as \code{best_iteration} when using the early stopping callback (\link{xgb.cb.early.stop}).
 }
 \description{
-The cross validation function of xgboost
+The cross validation function of xgboost.
 }
 \details{
 The original sample is randomly partitioned into \code{nfold} equal size subsamples.

View File

@@ -6,14 +6,18 @@
 \title{Get a new DMatrix containing the specified rows of
 original xgb.DMatrix object}
 \usage{
-xgb.slice.DMatrix(object, idxset)
+xgb.slice.DMatrix(object, idxset, allow_groups = FALSE)

 \method{[}{xgb.DMatrix}(object, idxset, colset = NULL)
 }
 \arguments{
-\item{object}{Object of class "xgb.DMatrix"}
+\item{object}{Object of class "xgb.DMatrix".}

-\item{idxset}{a integer vector of indices of rows needed}
+\item{idxset}{An integer vector of indices of rows needed (base-1 indexing).}
+
+\item{allow_groups}{Whether to allow slicing an \code{xgb.DMatrix} with \code{group} (or
+equivalently \code{qid}) field. Note that in such case, the result will not have
+the groups anymore - they need to be set manually through \code{setinfo}.}

 \item{colset}{currently not used (columns subsetting is not available)}
 }

View File

@@ -99,10 +99,12 @@ OBJECTS= \
 	$(PKGROOT)/src/context.o \
 	$(PKGROOT)/src/logging.o \
 	$(PKGROOT)/src/global_config.o \
+	$(PKGROOT)/src/collective/result.o \
 	$(PKGROOT)/src/collective/allgather.o \
 	$(PKGROOT)/src/collective/allreduce.o \
 	$(PKGROOT)/src/collective/broadcast.o \
 	$(PKGROOT)/src/collective/comm.o \
+	$(PKGROOT)/src/collective/comm_group.o \
 	$(PKGROOT)/src/collective/coll.o \
 	$(PKGROOT)/src/collective/communicator-inl.o \
 	$(PKGROOT)/src/collective/tracker.o \

View File

@@ -99,10 +99,12 @@ OBJECTS= \
 	$(PKGROOT)/src/context.o \
 	$(PKGROOT)/src/logging.o \
 	$(PKGROOT)/src/global_config.o \
+	$(PKGROOT)/src/collective/result.o \
 	$(PKGROOT)/src/collective/allgather.o \
 	$(PKGROOT)/src/collective/allreduce.o \
 	$(PKGROOT)/src/collective/broadcast.o \
 	$(PKGROOT)/src/collective/comm.o \
+	$(PKGROOT)/src/collective/comm_group.o \
 	$(PKGROOT)/src/collective/coll.o \
 	$(PKGROOT)/src/collective/communicator-inl.o \
 	$(PKGROOT)/src/collective/tracker.o \

View File

@@ -71,7 +71,7 @@ extern SEXP XGDMatrixGetDataAsCSR_R(SEXP);
 extern SEXP XGDMatrixSaveBinary_R(SEXP, SEXP, SEXP);
 extern SEXP XGDMatrixSetInfo_R(SEXP, SEXP, SEXP);
 extern SEXP XGDMatrixSetStrFeatureInfo_R(SEXP, SEXP, SEXP);
-extern SEXP XGDMatrixSliceDMatrix_R(SEXP, SEXP);
+extern SEXP XGDMatrixSliceDMatrix_R(SEXP, SEXP, SEXP);
 extern SEXP XGBSetGlobalConfig_R(SEXP);
 extern SEXP XGBGetGlobalConfig_R(void);
 extern SEXP XGBoosterFeatureScore_R(SEXP, SEXP);
@@ -134,7 +134,7 @@ static const R_CallMethodDef CallEntries[] = {
     {"XGDMatrixSaveBinary_R", (DL_FUNC) &XGDMatrixSaveBinary_R, 3},
     {"XGDMatrixSetInfo_R", (DL_FUNC) &XGDMatrixSetInfo_R, 3},
     {"XGDMatrixSetStrFeatureInfo_R", (DL_FUNC) &XGDMatrixSetStrFeatureInfo_R, 3},
-    {"XGDMatrixSliceDMatrix_R", (DL_FUNC) &XGDMatrixSliceDMatrix_R, 2},
+    {"XGDMatrixSliceDMatrix_R", (DL_FUNC) &XGDMatrixSliceDMatrix_R, 3},
     {"XGBSetGlobalConfig_R", (DL_FUNC) &XGBSetGlobalConfig_R, 1},
     {"XGBGetGlobalConfig_R", (DL_FUNC) &XGBGetGlobalConfig_R, 0},
     {"XGBoosterFeatureScore_R", (DL_FUNC) &XGBoosterFeatureScore_R, 2},

View File

@@ -512,7 +512,7 @@ XGB_DLL SEXP XGDMatrixCreateFromCSR_R(SEXP indptr, SEXP indices, SEXP data, SEXP
   return ret;
 }

-XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset) {
+XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset, SEXP allow_groups) {
   SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
   R_API_BEGIN();
   R_xlen_t len = Rf_xlength(idxset);
@@ -531,7 +531,7 @@ XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset) {
     res_code = XGDMatrixSliceDMatrixEx(R_ExternalPtrAddr(handle),
                                        BeginPtr(idxvec), len,
                                        &res,
-                                       0);
+                                       Rf_asLogical(allow_groups));
   }
   CHECK_CALL(res_code);
   R_SetExternalPtrAddr(ret, res);

View File

@@ -112,9 +112,10 @@ XGB_DLL SEXP XGDMatrixCreateFromCSR_R(SEXP indptr, SEXP indices, SEXP data, SEXP
  * \brief create a new dmatrix from sliced content of existing matrix
  * \param handle instance of data matrix to be sliced
  * \param idxset index set
+ * \param allow_groups Whether to allow slicing the DMatrix if it has a 'group' field
  * \return a sliced new matrix
  */
-XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset);
+XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset, SEXP allow_groups);

 /*!
  * \brief load a data matrix into binary file

View File

@@ -334,7 +334,7 @@ test_that("xgb.cv works", {
   set.seed(11)
   expect_output(
     cv <- xgb.cv(
-      data = train$data, label = train$label, max_depth = 2, nfold = 5,
+      data = xgb.DMatrix(train$data, label = train$label), max_depth = 2, nfold = 5,
       eta = 1., nthread = n_threads, nrounds = 2, objective = "binary:logistic",
       eval_metric = "error", verbose = TRUE
     ),
@@ -357,13 +357,13 @@ test_that("xgb.cv works with stratified folds", {
   cv <- xgb.cv(
     data = dtrain, max_depth = 2, nfold = 5,
     eta = 1., nthread = n_threads, nrounds = 2, objective = "binary:logistic",
-    verbose = TRUE, stratified = FALSE
+    verbose = FALSE, stratified = FALSE
   )
   set.seed(314159)
   cv2 <- xgb.cv(
     data = dtrain, max_depth = 2, nfold = 5,
     eta = 1., nthread = n_threads, nrounds = 2, objective = "binary:logistic",
-    verbose = TRUE, stratified = TRUE
+    verbose = FALSE, stratified = TRUE
   )
   # Stratified folds should result in a different evaluation logs
   expect_true(all(cv$evaluation_log[, test_logloss_mean] != cv2$evaluation_log[, test_logloss_mean]))
@@ -885,3 +885,57 @@ test_that("Seed in params override PRNG from R", {
     )
   )
 })
+
+test_that("xgb.cv works for AFT", {
+  X <- matrix(c(1, -1, -1, 1, 0, 1, 1, 0), nrow = 4, byrow = TRUE)  # 4x2 matrix
+  dtrain <- xgb.DMatrix(X, nthread = n_threads)
+
+  params <- list(objective = 'survival:aft', learning_rate = 0.2, max_depth = 2L)
+
+  # data must have bounds
+  expect_error(
+    xgb.cv(
+      params = params,
+      data = dtrain,
+      nround = 5L,
+      nfold = 4L,
+      nthread = n_threads
+    )
+  )
+
+  setinfo(dtrain, 'label_lower_bound', c(2, 3, 0, 4))
+  setinfo(dtrain, 'label_upper_bound', c(2, Inf, 4, 5))
+
+  # automatic stratified splitting is turned off
+  expect_warning(
+    xgb.cv(
+      params = params, data = dtrain, nround = 5L, nfold = 4L,
+      nthread = n_threads, stratified = TRUE, verbose = FALSE
+    )
+  )
+
+  # this works without any issue
+  expect_no_warning(
+    xgb.cv(params = params, data = dtrain, nround = 5L, nfold = 4L, verbose = FALSE)
+  )
+})
+
+test_that("xgb.cv works for ranking", {
+  data(iris)
+  x <- iris[, -(4:5)]
+  y <- as.integer(iris$Petal.Width)
+  group <- rep(50, 3)
+  dm <- xgb.DMatrix(x, label = y, group = group)
+  res <- xgb.cv(
+    data = dm,
+    params = list(
+      objective = "rank:pairwise",
+      max_depth = 3
+    ),
+    nrounds = 3,
+    nfold = 2,
+    verbose = FALSE,
+    stratified = FALSE
+  )
+  expect_equal(length(res$folds), 2L)
+})
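
For context (not part of the diff): the `folds` contract exercised by these tests can also be satisfied by hand. A hedged sketch for a DMatrix with three query groups of 50 rows split into two folds, where the attribute values are the group sizes later passed to `setinfo()`:

    fold1 <- 1:50                                # rows of query group 1 (test)
    attributes(fold1)$group_test <- 50           # group sizes for the test DMatrix
    attributes(fold1)$group_train <- c(50, 50)   # group sizes for the train DMatrix
    fold2 <- 51:150                              # rows of query groups 2 and 3
    attributes(fold2)$group_test <- c(50, 50)
    attributes(fold2)$group_train <- 50
    # pass list(fold1, fold2) to the 'folds' argument of xgb.cv()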

View File

@@ -367,7 +367,7 @@ test_that("prediction in early-stopping xgb.cv works", {
   expect_output(
     cv <- xgb.cv(param, dtrain, nfold = 5, eta = 0.1, nrounds = 20,
                  early_stopping_rounds = 5, maximize = FALSE, stratified = FALSE,
-                 prediction = TRUE, base_score = 0.5)
+                 prediction = TRUE, base_score = 0.5, verbose = TRUE)
     , "Stopping. Best iteration")

   expect_false(is.null(cv$early_stop$best_iteration))
@@ -387,7 +387,7 @@ test_that("prediction in xgb.cv for softprob works", {
   lb <- as.numeric(iris$Species) - 1
   set.seed(11)
   expect_warning(
-    cv <- xgb.cv(data = as.matrix(iris[, -5]), label = lb, nfold = 4,
+    cv <- xgb.cv(data = xgb.DMatrix(as.matrix(iris[, -5]), label = lb), nfold = 4,
                  eta = 0.5, nrounds = 5, max_depth = 3, nthread = n_threads,
                  subsample = 0.8, gamma = 2, verbose = 0,
                  prediction = TRUE, objective = "multi:softprob", num_class = 3)

View File

@@ -243,7 +243,7 @@ test_that("xgb.DMatrix: print", {
   txt <- capture.output({
     print(dtrain)
   })
-  expect_equal(txt, "xgb.DMatrix  dim: 6513 x 126  info: label weight base_margin  colnames: yes")
+  expect_equal(txt, "xgb.DMatrix  dim: 6513 x 126  info: base_margin, label, weight  colnames: yes")

   # DMatrix with just features
   dtrain <- xgb.DMatrix(
@@ -724,6 +724,44 @@ test_that("xgb.DMatrix: quantile cuts look correct", {
   )
 })

+test_that("xgb.DMatrix: slicing keeps field indicators", {
+  data(mtcars)
+  x <- as.matrix(mtcars[, -1])
+  y <- mtcars[, 1]
+  dm <- xgb.DMatrix(
+    data = x,
+    label_lower_bound = -y,
+    label_upper_bound = y,
+    nthread = 1
+  )
+
+  idx_take <- seq(1, 5)
+  dm_slice <- xgb.slice.DMatrix(dm, idx_take)
+
+  expect_true(xgb.DMatrix.hasinfo(dm_slice, "label_lower_bound"))
+  expect_true(xgb.DMatrix.hasinfo(dm_slice, "label_upper_bound"))
+  expect_false(xgb.DMatrix.hasinfo(dm_slice, "label"))
+
+  expect_equal(getinfo(dm_slice, "label_lower_bound"), -y[idx_take], tolerance = 1e-6)
+  expect_equal(getinfo(dm_slice, "label_upper_bound"), y[idx_take], tolerance = 1e-6)
+})
+
+test_that("xgb.DMatrix: can slice with groups", {
+  data(iris)
+  x <- as.matrix(iris[, -5])
+  set.seed(123)
+  y <- sample(3, size = nrow(x), replace = TRUE)
+  group <- c(50, 50, 50)
+  dm <- xgb.DMatrix(x, label = y, group = group, nthread = 1)
+
+  idx_take <- seq(1, 50)
+  dm_slice <- xgb.slice.DMatrix(dm, idx_take, allow_groups = TRUE)
+
+  expect_true(xgb.DMatrix.hasinfo(dm_slice, "label"))
+  expect_false(xgb.DMatrix.hasinfo(dm_slice, "group"))
+  expect_false(xgb.DMatrix.hasinfo(dm_slice, "qid"))
+  expect_null(getinfo(dm_slice, "group"))
+  expect_equal(getinfo(dm_slice, "label"), y[idx_take], tolerance = 1e-6)
+})
+
 test_that("xgb.DMatrix: can read CSV", {
   txt <- paste(
     "1,2,3",

View File

@@ -40,7 +40,7 @@ def main(client):
     # you can pass output directly into `predict` too.
     prediction = dxgb.predict(client, bst, dtrain)
     print("Evaluation history:", history)
-    return prediction
+    print("Error:", da.sqrt((prediction - y) ** 2).mean().compute())

 if __name__ == "__main__":

View File

@@ -144,6 +144,14 @@ which provides higher flexibility. For example:

   ctest --verbose

+If you need to debug errors on Windows using the debugger from VS, you can append the gtest flags in `test_main.cc`:
+
+.. code-block::
+
+  ::testing::GTEST_FLAG(filter) = "Suite.Test";
+  ::testing::GTEST_FLAG(repeat) = 10;
+
 ***********************************************
 Sanitizers: Detect memory errors and data races
 ***********************************************

View File

@@ -28,7 +28,7 @@ Contents
   Python Package <python/index>
   R Package <R-package/index>
   JVM Package <jvm/index>
-  Ruby Package <https://github.com/ankane/xgb>
+  Ruby Package <https://github.com/ankane/xgboost-ruby>
   Swift Package <https://github.com/kongzii/SwiftXGBoost>
   Julia Package <julia>
   C Package <c>

View File

@@ -52,7 +52,7 @@ Notice that the samples are sorted based on their query index in a non-decreasing order.
     X, y = make_classification(random_state=seed)
     rng = np.random.default_rng(seed)
     n_query_groups = 3
-    qid = rng.integers(0, 3, size=X.shape[0])
+    qid = rng.integers(0, n_query_groups, size=X.shape[0])

     # Sort the inputs based on query index
     sorted_idx = np.argsort(qid)
@@ -65,14 +65,14 @@ The simplest way to train a ranking model is by using the scikit-learn estimator
 .. code-block:: python

     ranker = xgb.XGBRanker(tree_method="hist", lambdarank_num_pair_per_sample=8, objective="rank:ndcg", lambdarank_pair_method="topk")
-    ranker.fit(X, y, qid=qid)
+    ranker.fit(X, y, qid=qid[sorted_idx])

 Please note that, as of writing, there's no learning-to-rank interface in scikit-learn. As a result, the :py:class:`xgboost.XGBRanker` class does not fully conform the scikit-learn estimator guideline and can not be directly used with some of its utility functions. For instances, the ``auc_score`` and ``ndcg_score`` in scikit-learn don't consider query group information nor the pairwise loss. Most of the metrics are implemented as part of XGBoost, but to use scikit-learn utilities like :py:func:`sklearn.model_selection.cross_validation`, we need to make some adjustments in order to pass the ``qid`` as an additional parameter for :py:meth:`xgboost.XGBRanker.score`. Given a data frame ``X`` (either pandas or cuDF), add the column ``qid`` as follows:

 .. code-block:: python

     df = pd.DataFrame(X, columns=[str(i) for i in range(X.shape[1])])
-    df["qid"] = qid
+    df["qid"] = qid[sorted_idx]
     ranker.fit(df, y)  # No need to pass qid as a separate argument

     from sklearn.model_selection import StratifiedGroupKFold, cross_val_score

View File

@@ -1,5 +1,5 @@
 /**
- * Copyright 2015~2023 by XGBoost Contributors
+ * Copyright 2015-2024, XGBoost Contributors
  * \file c_api.h
  * \author Tianqi Chen
  * \brief C API of XGBoost, used for interfacing to other languages.
@@ -639,21 +639,14 @@ XGB_DLL int XGDMatrixSetInfoFromInterface(DMatrixHandle handle,
  * \param len length of array
  * \return 0 when success, -1 when failure happens
  */
-XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle,
-                                  const char *field,
-                                  const float *array,
+XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle, const char *field, const float *array,
                                   bst_ulong len);

-/*!
- * \brief set uint32 vector to a content in info
- * \param handle a instance of data matrix
- * \param field field name
- * \param array pointer to unsigned int vector
- * \param len length of array
- * \return 0 when success, -1 when failure happens
+/**
+ * @deprecated since 2.1.0
+ *
+ * Use @ref XGDMatrixSetInfoFromInterface instead.
  */
-XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle,
-                                 const char *field,
-                                 const unsigned *array,
+XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle, const char *field, const unsigned *array,
                                  bst_ulong len);

 /*!
@@ -725,42 +718,13 @@ XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field,
                                        bst_ulong *size,
                                        const char ***out_features);

-/*!
- * \brief Set meta info from dense matrix. Valid field names are:
- *
- *  - label
- *  - weight
- *  - base_margin
- *  - group
- *  - label_lower_bound
- *  - label_upper_bound
- *  - feature_weights
- *
- * \param handle An instance of data matrix
- * \param field  Field name
- * \param data   Pointer to consecutive memory storing data.
- * \param size   Size of the data, this is relative to size of type. (Meaning NOT number
- *               of bytes.)
- * \param type   Indicator of data type. This is defined in xgboost::DataType enum class.
- *    - float    = 1
- *    - double   = 2
- *    - uint32_t = 3
- *    - uint64_t = 4
- * \return 0 when success, -1 when failure happens
+/**
+ * @deprecated since 2.1.0
+ *
+ * Use @ref XGDMatrixSetInfoFromInterface instead.
  */
-XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field,
-                                  void const *data, bst_ulong size, int type);
+XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field, void const *data,
+                                  bst_ulong size, int type);

-/*!
- * \brief (deprecated) Use XGDMatrixSetUIntInfo instead. Set group of the training matrix
- * \param handle a instance of data matrix
- * \param group pointer to group size
- * \param len length of array
- * \return 0 when success, -1 when failure happens
- */
-XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle,
-                              const unsigned *group,
-                              bst_ulong len);

 /*!
  * \brief get float info vector from matrix.
@@ -1591,7 +1555,7 @@ XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle);
 /**
  * @brief Get the arguments needed for running workers. This should be called after
- *        XGTrackerRun() and XGTrackerWait()
+ *        XGTrackerRun().
  *
  * @param handle The handle to the tracker.
  * @param args The arguments returned as a JSON document.
@@ -1601,16 +1565,19 @@ XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle);
 XGB_DLL int XGTrackerWorkerArgs(TrackerHandle handle, char const **args);

 /**
- * @brief Run the tracker.
+ * @brief Start the tracker. The tracker runs in the background and this function returns
+ *        once the tracker is started.
  *
  * @param handle The handle to the tracker.
+ * @param config Unused at the moment, preserved for the future.
  *
  * @return 0 for success, -1 for failure.
  */
-XGB_DLL int XGTrackerRun(TrackerHandle handle);
+XGB_DLL int XGTrackerRun(TrackerHandle handle, char const *config);

 /**
- * @brief Wait for the tracker to finish, should be called after XGTrackerRun().
+ * @brief Wait for the tracker to finish, should be called after XGTrackerRun(). This
+ *        function will block until the tracker task is finished or timeout is reached.
  *
  * @param handle The handle to the tracker.
  * @param config JSON encoded configuration. No argument is required yet, preserved for
@@ -1618,11 +1585,12 @@ XGB_DLL int XGTrackerRun(TrackerHandle handle);
  *
  * @return 0 for success, -1 for failure.
  */
-XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config);
+XGB_DLL int XGTrackerWaitFor(TrackerHandle handle, char const *config);

 /**
- * @brief Free a tracker instance. XGTrackerWait() is called internally. If the tracker
- *        cannot close properly, manual interruption is required.
+ * @brief Free a tracker instance. This should be called after XGTrackerWaitFor(). If the
+ *        tracker is not properly waited, this function will shutdown all connections with
+ *        the tracker, potentially leading to undefined behavior.
  *
  * @param handle The handle to the tracker.
  *

View File

@@ -3,13 +3,11 @@
  */
 #pragma once
-#include <xgboost/logging.h>
-
-#include <memory>   // for unique_ptr
-#include <sstream>  // for stringstream
-#include <stack>    // for stack
-#include <string>   // for string
-#include <utility>  // for move
+#include <cstdint>       // for int32_t
+#include <memory>        // for unique_ptr
+#include <string>        // for string
+#include <system_error>  // for error_code
+#include <utility>       // for move

 namespace xgboost::collective {
 namespace detail {
@@ -48,48 +46,19 @@ struct ResultImpl {
     return cur_eq;
   }

-  [[nodiscard]] std::string Report() {
-    std::stringstream ss;
-    ss << "\n- " << this->message;
-    if (this->errc != std::error_code{}) {
-      ss << " system error:" << this->errc.message();
-    }
-    auto ptr = prev.get();
-    while (ptr) {
-      ss << "\n- ";
-      ss << ptr->message;
-      if (ptr->errc != std::error_code{}) {
-        ss << " " << ptr->errc.message();
-      }
-      ptr = ptr->prev.get();
-    }
-    return ss.str();
-  }
-  [[nodiscard]] auto Code() const {
-    // Find the root error.
-    std::stack<ResultImpl const*> stack;
-    auto ptr = this;
-    while (ptr) {
-      stack.push(ptr);
-      if (ptr->prev) {
-        ptr = ptr->prev.get();
-      } else {
-        break;
-      }
-    }
-    while (!stack.empty()) {
-      auto frame = stack.top();
-      stack.pop();
-      if (frame->errc != std::error_code{}) {
-        return frame->errc;
-      }
-    }
-    return std::error_code{};
-  }
+  [[nodiscard]] std::string Report() const;
+  [[nodiscard]] std::error_code Code() const;
+
+  void Concat(std::unique_ptr<ResultImpl> rhs);
 };
+
+#if (!defined(__GNUC__) && !defined(__clang__)) || defined(__MINGW32__)
+#define __builtin_FILE() nullptr
+#define __builtin_LINE() (-1)
+std::string MakeMsg(std::string&& msg, char const*, std::int32_t);
+#else
+std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line);
+#endif
 }  // namespace detail

 /**
@@ -131,8 +100,21 @@ struct Result {
     }
     return *impl_ == *that.impl_;
   }
+
+  friend Result operator+(Result&& lhs, Result&& rhs);
 };

+[[nodiscard]] inline Result operator+(Result&& lhs, Result&& rhs) {
+  if (lhs.OK()) {
+    return std::forward<Result>(rhs);
+  }
+  if (rhs.OK()) {
+    return std::forward<Result>(lhs);
+  }
+  lhs.impl_->Concat(std::move(rhs.impl_));
+  return std::forward<Result>(lhs);
+}
+
 /**
  * @brief Return success.
  */
@@ -140,38 +122,43 @@ struct Result {
 /**
  * @brief Return failure.
  */
-[[nodiscard]] inline auto Fail(std::string msg) { return Result{std::move(msg)}; }
+[[nodiscard]] inline auto Fail(std::string msg, char const* file = __builtin_FILE(),
+                               std::int32_t line = __builtin_LINE()) {
+  return Result{detail::MakeMsg(std::move(msg), file, line)};
+}
 /**
  * @brief Return failure with `errno`.
  */
-[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc) {
-  return Result{std::move(msg), std::move(errc)};
+[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc,
+                               char const* file = __builtin_FILE(),
+                               std::int32_t line = __builtin_LINE()) {
+  return Result{detail::MakeMsg(std::move(msg), file, line), std::move(errc)};
 }
 /**
  * @brief Return failure with a previous error.
  */
-[[nodiscard]] inline auto Fail(std::string msg, Result&& prev) {
-  return Result{std::move(msg), std::forward<Result>(prev)};
+[[nodiscard]] inline auto Fail(std::string msg, Result&& prev, char const* file = __builtin_FILE(),
+                               std::int32_t line = __builtin_LINE()) {
+  return Result{detail::MakeMsg(std::move(msg), file, line), std::forward<Result>(prev)};
 }
 /**
  * @brief Return failure with a previous error and a new `errno`.
  */
-[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc, Result&& prev) {
-  return Result{std::move(msg), std::move(errc), std::forward<Result>(prev)};
+[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc, Result&& prev,
+                               char const* file = __builtin_FILE(),
+                               std::int32_t line = __builtin_LINE()) {
+  return Result{detail::MakeMsg(std::move(msg), file, line), std::move(errc),
+                std::forward<Result>(prev)};
 }

 // We don't have monad, a simple helper would do.
 template <typename Fn>
-[[nodiscard]] Result operator<<(Result&& r, Fn&& fn) {
+[[nodiscard]] std::enable_if_t<std::is_invocable_v<Fn>, Result> operator<<(Result&& r, Fn&& fn) {
   if (!r.OK()) {
     return std::forward<Result>(r);
   }
   return fn();
 }

-inline void SafeColl(Result const& rc) {
-  if (!rc.OK()) {
-    LOG(FATAL) << rc.Report();
-  }
-}
+void SafeColl(Result const& rc);
 }  // namespace xgboost::collective

View File

@ -1,5 +1,5 @@
/** /**
* Copyright (c) 2022-2023, XGBoost Contributors * Copyright (c) 2022-2024, XGBoost Contributors
*/ */
#pragma once #pragma once
@ -12,7 +12,6 @@
#include <cstddef> // std::size_t #include <cstddef> // std::size_t
#include <cstdint> // std::int32_t, std::uint16_t #include <cstdint> // std::int32_t, std::uint16_t
#include <cstring> // memset #include <cstring> // memset
#include <limits> // std::numeric_limits
#include <string> // std::string #include <string> // std::string
#include <system_error> // std::error_code, std::system_category #include <system_error> // std::error_code, std::system_category
#include <utility> // std::swap #include <utility> // std::swap
@ -125,6 +124,21 @@ inline std::int32_t CloseSocket(SocketT fd) {
#endif #endif
} }
inline std::int32_t ShutdownSocket(SocketT fd) {
#if defined(_WIN32)
auto rc = shutdown(fd, SD_BOTH);
if (rc != 0 && LastError() == WSANOTINITIALISED) {
return 0;
}
#else
auto rc = shutdown(fd, SHUT_RDWR);
if (rc != 0 && LastError() == ENOTCONN) {
return 0;
}
#endif
return rc;
}
inline bool ErrorWouldBlock(std::int32_t errsv) noexcept(true) { inline bool ErrorWouldBlock(std::int32_t errsv) noexcept(true) {
#ifdef _WIN32 #ifdef _WIN32
return errsv == WSAEWOULDBLOCK; return errsv == WSAEWOULDBLOCK;
@ -468,19 +482,30 @@ class TCPSocket {
*addr = SockAddress{SockAddrV6{caddr}}; *addr = SockAddress{SockAddrV6{caddr}};
*out = TCPSocket{newfd}; *out = TCPSocket{newfd};
} }
// On MacOS, this is automatically set to async socket if the parent socket is async
// We make sure all socket are blocking by default.
//
// On Windows, a closed socket is returned during shutdown. We guard against it when
// setting non-blocking.
if (!out->IsClosed()) {
return out->NonBlocking(false);
}
return Success(); return Success();
} }
~TCPSocket() { ~TCPSocket() {
if (!IsClosed()) { if (!IsClosed()) {
Close(); auto rc = this->Close();
if (!rc.OK()) {
LOG(WARNING) << rc.Report();
}
} }
} }
TCPSocket(TCPSocket const &that) = delete; TCPSocket(TCPSocket const &that) = delete;
TCPSocket(TCPSocket &&that) noexcept(true) { std::swap(this->handle_, that.handle_); } TCPSocket(TCPSocket &&that) noexcept(true) { std::swap(this->handle_, that.handle_); }
TCPSocket &operator=(TCPSocket const &that) = delete; TCPSocket &operator=(TCPSocket const &that) = delete;
TCPSocket &operator=(TCPSocket &&that) { TCPSocket &operator=(TCPSocket &&that) noexcept(true) {
std::swap(this->handle_, that.handle_); std::swap(this->handle_, that.handle_);
return *this; return *this;
} }
@ -489,36 +514,49 @@ class TCPSocket {
*/ */
[[nodiscard]] HandleT const &Handle() const { return handle_; } [[nodiscard]] HandleT const &Handle() const { return handle_; }
/** /**
* \brief Listen to incoming requests. Should be called after bind. * @brief Listen to incoming requests. Should be called after bind.
*/ */
void Listen(std::int32_t backlog = 16) { xgboost_CHECK_SYS_CALL(listen(handle_, backlog), 0); } [[nodiscard]] Result Listen(std::int32_t backlog = 16) {
if (listen(handle_, backlog) != 0) {
return system::FailWithCode("Failed to listen.");
}
return Success();
}
/** /**
* \brief Bind socket to INADDR_ANY, return the port selected by the OS. * @brief Bind socket to INADDR_ANY, return the port selected by the OS.
*/ */
[[nodiscard]] in_port_t BindHost() { [[nodiscard]] Result BindHost(std::int32_t* p_out) {
// Use int32 instead of in_port_t for consistency. We accept the port as a parameter
// from users of other languages, where the port is usually stored and passed around as an int.
if (Domain() == SockDomain::kV6) { if (Domain() == SockDomain::kV6) {
auto addr = SockAddrV6::InaddrAny(); auto addr = SockAddrV6::InaddrAny();
auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle()); auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
xgboost_CHECK_SYS_CALL( if (bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)) != 0) {
bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)), 0); return system::FailWithCode("bind failed.");
}
sockaddr_in6 res_addr; sockaddr_in6 res_addr;
socklen_t addrlen = sizeof(res_addr); socklen_t addrlen = sizeof(res_addr);
xgboost_CHECK_SYS_CALL( if (getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen) != 0) {
getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen), 0); return system::FailWithCode("getsockname failed.");
return ntohs(res_addr.sin6_port); }
*p_out = ntohs(res_addr.sin6_port);
} else { } else {
auto addr = SockAddrV4::InaddrAny(); auto addr = SockAddrV4::InaddrAny();
auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle()); auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
xgboost_CHECK_SYS_CALL( if (bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)) != 0) {
bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)), 0); return system::FailWithCode("bind failed.");
}
sockaddr_in res_addr; sockaddr_in res_addr;
socklen_t addrlen = sizeof(res_addr); socklen_t addrlen = sizeof(res_addr);
xgboost_CHECK_SYS_CALL( if (getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen) != 0) {
getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen), 0); return system::FailWithCode("getsockname failed.");
return ntohs(res_addr.sin_port); }
*p_out = ntohs(res_addr.sin_port);
} }
return Success();
} }
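
BindHost() is built on the standard bind-to-port-0 idiom: binding port 0 lets the OS pick a free port, and getsockname() reports the choice. A POSIX-only IPv4 sketch:

#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>
#include <cstdio>

int main() {
  int fd = socket(AF_INET, SOCK_STREAM, 0);
  if (fd < 0) return 1;
  sockaddr_in addr{};
  addr.sin_family = AF_INET;
  addr.sin_addr.s_addr = htonl(INADDR_ANY);
  addr.sin_port = htons(0);  // 0 = let the OS choose a free port
  if (bind(fd, reinterpret_cast<sockaddr*>(&addr), sizeof(addr)) != 0) return 1;
  socklen_t len = sizeof(addr);
  if (getsockname(fd, reinterpret_cast<sockaddr*>(&addr), &len) != 0) return 1;
  std::printf("bound to port %d\n", ntohs(addr.sin_port));  // the value BindHost reports
  return close(fd);
}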
[[nodiscard]] auto Port() const { [[nodiscard]] auto Port() const {
@ -631,26 +669,49 @@ class TCPSocket {
*/ */
std::size_t Send(StringView str); std::size_t Send(StringView str);
/** /**
* \brief Receive string, format is matched with the Python socket wrapper in RABIT. * @brief Receive string, format is matched with the Python socket wrapper in RABIT.
*/ */
std::size_t Recv(std::string *p_str); [[nodiscard]] Result Recv(std::string *p_str);
/** /**
* \brief Close the socket, called automatically in destructor if the socket is not closed. * @brief Close the socket, called automatically in destructor if the socket is not closed.
*/ */
void Close() { [[nodiscard]] Result Close() {
if (InvalidSocket() != handle_) { if (InvalidSocket() != handle_) {
#if defined(_WIN32)
auto rc = system::CloseSocket(handle_); auto rc = system::CloseSocket(handle_);
#if defined(_WIN32)
// it's possible that we close TCP sockets after finalizing WSA due to detached thread. // it's possible that we close TCP sockets after finalizing WSA due to detached thread.
if (rc != 0 && system::LastError() != WSANOTINITIALISED) { if (rc != 0 && system::LastError() != WSANOTINITIALISED) {
system::ThrowAtError("close", rc); return system::FailWithCode("Failed to close the socket.");
} }
#else #else
xgboost_CHECK_SYS_CALL(system::CloseSocket(handle_), 0); if (rc != 0) {
return system::FailWithCode("Failed to close the socket.");
}
#endif #endif
handle_ = InvalidSocket(); handle_ = InvalidSocket();
} }
return Success();
} }
/**
* @brief Call shutdown on the socket.
*/
[[nodiscard]] Result Shutdown() {
if (this->IsClosed()) {
return Success();
}
auto rc = system::ShutdownSocket(this->Handle());
#if defined(_WIN32)
// Windows cannot shut down a socket if it's not connected.
if (rc == -1 && system::LastError() == WSAENOTCONN) {
return Success();
}
#endif
if (rc != 0) {
return system::FailWithCode("Failed to shutdown socket.");
}
return Success();
}
/** /**
* \brief Create a TCP socket on specified domain. * \brief Create a TCP socket on specified domain.
*/ */


@ -19,7 +19,6 @@
#include <algorithm> #include <algorithm>
#include <limits> #include <limits>
#include <memory> #include <memory>
#include <numeric>
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
@ -137,14 +136,6 @@ class MetaInfo {
* \param fo The output stream. * \param fo The output stream.
*/ */
void SaveBinary(dmlc::Stream* fo) const; void SaveBinary(dmlc::Stream* fo) const;
/*!
* \brief Set information in the meta info.
* \param key The key of the information.
* \param dptr The data pointer of the source array.
* \param dtype The type of the source data.
* \param num Number of elements in the source array.
*/
void SetInfo(Context const& ctx, const char* key, const void* dptr, DataType dtype, size_t num);
/*! /*!
* \brief Set information in the meta info with array interface. * \brief Set information in the meta info with array interface.
* \param key The key of the information. * \param key The key of the information.
@ -517,10 +508,6 @@ class DMatrix {
DMatrix() = default; DMatrix() = default;
/*! \brief meta information of the dataset */ /*! \brief meta information of the dataset */
virtual MetaInfo& Info() = 0; virtual MetaInfo& Info() = 0;
virtual void SetInfo(const char* key, const void* dptr, DataType dtype, size_t num) {
auto const& ctx = *this->Ctx();
this->Info().SetInfo(ctx, key, dptr, dtype, num);
}
virtual void SetInfo(const char* key, std::string const& interface_str) { virtual void SetInfo(const char* key, std::string const& interface_str) {
auto const& ctx = *this->Ctx(); auto const& ctx = *this->Ctx();
this->Info().SetInfo(ctx, key, StringView{interface_str}); this->Info().SetInfo(ctx, key, StringView{interface_str});


@ -190,13 +190,14 @@ constexpr auto ArrToTuple(T (&arr)[N]) {
// uint division optimization inspired by the CIndexer in cupy. Division operation is // uint division optimization inspired by the CIndexer in cupy. Division operation is
// slow on both CPU and GPU, especially 64 bit integer. So here we first try to avoid 64 // slow on both CPU and GPU, especially 64 bit integer. So here we first try to avoid 64
// bit when the index is smaller, then try to avoid division when it's exp of 2. // bit when the index is smaller, then try to avoid division when it's exp of 2.
template <typename I, int32_t D> template <typename I, std::int32_t D>
LINALG_HD auto UnravelImpl(I idx, common::Span<size_t const, D> shape) { LINALG_HD auto UnravelImpl(I idx, common::Span<size_t const, D> shape) {
size_t index[D]{0}; std::size_t index[D]{0};
static_assert(std::is_signed<decltype(D)>::value, static_assert(std::is_signed<decltype(D)>::value,
"Don't change the type without changing the for loop."); "Don't change the type without changing the for loop.");
auto const sptr = shape.data();
for (int32_t dim = D; --dim > 0;) { for (int32_t dim = D; --dim > 0;) {
auto s = static_cast<std::remove_const_t<std::remove_reference_t<I>>>(shape[dim]); auto s = static_cast<std::remove_const_t<std::remove_reference_t<I>>>(sptr[dim]);
if (s & (s - 1)) { if (s & (s - 1)) {
auto t = idx / s; auto t = idx / s;
index[dim] = idx - t * s; index[dim] = idx - t * s;
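
The s & (s - 1) expression above is the classic power-of-two test: when an extent is a power of two, division and modulo collapse to a shift and a mask. A 2-D sketch of the same arithmetic (Unravel2D is a hypothetical helper; __builtin_ctz assumes GCC/Clang):

#include <cassert>
#include <cstdint>

void Unravel2D(std::uint32_t idx, std::uint32_t cols, std::uint32_t* r, std::uint32_t* c) {
  if (cols & (cols - 1)) {  // not a power of two: one division, derive the modulo
    std::uint32_t t = idx / cols;
    *c = idx - t * cols;
    *r = t;
  } else {                  // power of two: mask for the remainder, shift for the quotient
    *c = idx & (cols - 1);
    *r = idx >> __builtin_ctz(cols);
  }
}

int main() {
  std::uint32_t r = 0, c = 0;
  Unravel2D(13, 8, &r, &c);  // 13 = 1 * 8 + 5
  assert(r == 1 && c == 5);
  Unravel2D(13, 6, &r, &c);  // 13 = 2 * 6 + 1
  assert(r == 2 && c == 1);
  return 0;
}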
@ -745,6 +746,14 @@ auto ArrayInterfaceStr(TensorView<T, D> const &t) {
return str; return str;
} }
template <typename T>
auto Make1dInterface(T const *vec, std::size_t len) {
Context ctx;
auto t = linalg::MakeTensorView(&ctx, common::Span{vec, len}, len);
auto str = linalg::ArrayInterfaceStr(t);
return str;
}
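
Make1dInterface wraps a raw buffer in a NumPy-style __array_interface__ JSON string; the JNI changes further below rely on it to hand Java arrays to XGDMatrixSetInfoFromInterface. A sketch of the kind of JSON involved, hand-built for a little-endian float32 buffer (field names follow the NumPy array-interface spec; the real helper may emit additional fields):

#include <cstdint>
#include <cstdio>

int main() {
  float data[4] = {1.f, 2.f, 3.f, 4.f};
  char buf[256];
  // "data" is [address, read-only flag]; "<f4" means little-endian 4-byte float.
  std::snprintf(buf, sizeof(buf),
                R"({"data":[%llu,false],"shape":[4],"typestr":"<f4","version":3})",
                static_cast<unsigned long long>(reinterpret_cast<std::uintptr_t>(data)));
  std::printf("%s\n", buf);
  return 0;
}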
/** /**
* \brief A tensor storage. To use it for other functionality like slicing one needs to * \brief A tensor storage. To use it for other functionality like slicing one needs to
* obtain a view first. This way we can use it on both host and device. * obtain a view first. This way we can use it on both host and device.


@ -30,9 +30,8 @@
#define XGBOOST_SPAN_H_ #define XGBOOST_SPAN_H_
#include <xgboost/base.h> #include <xgboost/base.h>
#include <xgboost/logging.h>
#include <cinttypes> // size_t #include <cstddef> // size_t
#include <cstdio> #include <cstdio>
#include <iterator> #include <iterator>
#include <limits> // numeric_limits #include <limits> // numeric_limits
@ -75,8 +74,7 @@
#endif // defined(_MSC_VER) && _MSC_VER < 1910 #endif // defined(_MSC_VER) && _MSC_VER < 1910
namespace xgboost { namespace xgboost::common {
namespace common {
#if defined(__CUDA_ARCH__) #if defined(__CUDA_ARCH__)
// Usual logging facility is not available inside device code. // Usual logging facility is not available inside device code.
@ -744,8 +742,8 @@ class IterSpan {
return it_ + size(); return it_ + size();
} }
}; };
} // namespace common } // namespace xgboost::common
} // namespace xgboost
#if defined(_MSC_VER) &&_MSC_VER < 1910 #if defined(_MSC_VER) &&_MSC_VER < 1910
#undef constexpr #undef constexpr


@ -33,7 +33,7 @@
<project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding> <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
<maven.compiler.source>1.8</maven.compiler.source> <maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target> <maven.compiler.target>1.8</maven.compiler.target>
<flink.version>1.18.0</flink.version> <flink.version>1.19.0</flink.version>
<junit.version>4.13.2</junit.version> <junit.version>4.13.2</junit.version>
<spark.version>3.4.1</spark.version> <spark.version>3.4.1</spark.version>
<spark.version.gpu>3.4.1</spark.version.gpu> <spark.version.gpu>3.4.1</spark.version.gpu>
@ -46,9 +46,9 @@
<cudf.version>23.12.1</cudf.version> <cudf.version>23.12.1</cudf.version>
<spark.rapids.version>23.12.1</spark.rapids.version> <spark.rapids.version>23.12.1</spark.rapids.version>
<cudf.classifier>cuda12</cudf.classifier> <cudf.classifier>cuda12</cudf.classifier>
<scalatest.version>3.2.18</scalatest.version>
<scala-collection-compat.version>2.12.0</scala-collection-compat.version>
<use.hip>OFF</use.hip> <use.hip>OFF</use.hip>
<scalatest.version>3.2.17</scalatest.version>
<scala-collection-compat.version>2.11.0</scala-collection-compat.version>
<!-- SPARK-36796 for JDK-17 test--> <!-- SPARK-36796 for JDK-17 test-->
<extraJavaTestArgs> <extraJavaTestArgs>
@ -124,7 +124,7 @@
<plugin> <plugin>
<groupId>org.apache.maven.plugins</groupId> <groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId> <artifactId>maven-jar-plugin</artifactId>
<version>3.3.0</version> <version>3.4.0</version>
<executions> <executions>
<execution> <execution>
<id>empty-javadoc-jar</id> <id>empty-javadoc-jar</id>
@ -153,7 +153,7 @@
<plugin> <plugin>
<groupId>org.apache.maven.plugins</groupId> <groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-gpg-plugin</artifactId> <artifactId>maven-gpg-plugin</artifactId>
<version>3.1.0</version> <version>3.2.3</version>
<executions> <executions>
<execution> <execution>
<id>sign-artifacts</id> <id>sign-artifacts</id>
@ -167,7 +167,7 @@
<plugin> <plugin>
<groupId>org.apache.maven.plugins</groupId> <groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-source-plugin</artifactId> <artifactId>maven-source-plugin</artifactId>
<version>3.3.0</version> <version>3.3.1</version>
<executions> <executions>
<execution> <execution>
<id>attach-sources</id> <id>attach-sources</id>
@ -205,7 +205,7 @@
<plugin> <plugin>
<groupId>org.apache.maven.plugins</groupId> <groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-assembly-plugin</artifactId> <artifactId>maven-assembly-plugin</artifactId>
<version>3.6.0</version> <version>3.7.1</version>
<configuration> <configuration>
<descriptorRefs> <descriptorRefs>
<descriptorRef>jar-with-dependencies</descriptorRef> <descriptorRef>jar-with-dependencies</descriptorRef>
@ -446,7 +446,7 @@
<plugin> <plugin>
<groupId>org.apache.maven.plugins</groupId> <groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId> <artifactId>maven-surefire-plugin</artifactId>
<version>3.2.2</version> <version>3.2.5</version>
<configuration> <configuration>
<skipTests>false</skipTests> <skipTests>false</skipTests>
<useSystemClassLoader>false</useSystemClassLoader> <useSystemClassLoader>false</useSystemClassLoader>
@ -488,12 +488,12 @@
<dependency> <dependency>
<groupId>com.esotericsoftware</groupId> <groupId>com.esotericsoftware</groupId>
<artifactId>kryo</artifactId> <artifactId>kryo</artifactId>
<version>5.5.0</version> <version>5.6.0</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>commons-logging</groupId> <groupId>commons-logging</groupId>
<artifactId>commons-logging</artifactId> <artifactId>commons-logging</artifactId>
<version>1.3.0</version> <version>1.3.1</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.scalatest</groupId> <groupId>org.scalatest</groupId>


@ -72,7 +72,7 @@
<plugin> <plugin>
<groupId>org.apache.maven.plugins</groupId> <groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-javadoc-plugin</artifactId> <artifactId>maven-javadoc-plugin</artifactId>
<version>3.6.2</version> <version>3.6.3</version>
<configuration> <configuration>
<show>protected</show> <show>protected</show>
<nohelp>true</nohelp> <nohelp>true</nohelp>
@ -88,7 +88,7 @@
<plugin> <plugin>
<artifactId>exec-maven-plugin</artifactId> <artifactId>exec-maven-plugin</artifactId>
<groupId>org.codehaus.mojo</groupId> <groupId>org.codehaus.mojo</groupId>
<version>3.1.0</version> <version>3.2.0</version>
<executions> <executions>
<execution> <execution>
<id>native</id> <id>native</id>
@ -115,7 +115,7 @@
<plugin> <plugin>
<groupId>org.apache.maven.plugins</groupId> <groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId> <artifactId>maven-jar-plugin</artifactId>
<version>3.3.0</version> <version>3.4.0</version>
<executions> <executions>
<execution> <execution>
<goals> <goals>


@ -22,7 +22,7 @@ pom_template = """
<scala.version>{scala_version}</scala.version> <scala.version>{scala_version}</scala.version>
<scalatest.version>3.2.15</scalatest.version> <scalatest.version>3.2.15</scalatest.version>
<scala.binary.version>{scala_binary_version}</scala.binary.version> <scala.binary.version>{scala_binary_version}</scala.binary.version>
<kryo.version>5.5.0</kryo.version> <kryo.version>5.6.0</kryo.version>
</properties> </properties>
<dependencies> <dependencies>


@ -60,7 +60,7 @@
<plugin> <plugin>
<groupId>org.apache.maven.plugins</groupId> <groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-javadoc-plugin</artifactId> <artifactId>maven-javadoc-plugin</artifactId>
<version>3.6.2</version> <version>3.6.3</version>
<configuration> <configuration>
<show>protected</show> <show>protected</show>
<nohelp>true</nohelp> <nohelp>true</nohelp>
@ -76,7 +76,7 @@
<plugin> <plugin>
<artifactId>exec-maven-plugin</artifactId> <artifactId>exec-maven-plugin</artifactId>
<groupId>org.codehaus.mojo</groupId> <groupId>org.codehaus.mojo</groupId>
<version>3.1.0</version> <version>3.2.0</version>
<executions> <executions>
<execution> <execution>
<id>native</id> <id>native</id>
@ -99,7 +99,7 @@
<plugin> <plugin>
<groupId>org.apache.maven.plugins</groupId> <groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId> <artifactId>maven-jar-plugin</artifactId>
<version>3.3.0</version> <version>3.4.0</version>
<executions> <executions>
<execution> <execution>
<goals> <goals>


@ -408,7 +408,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetFloatI
jfloat* array = jenv->GetFloatArrayElements(jarray, NULL); jfloat* array = jenv->GetFloatArrayElements(jarray, NULL);
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray); bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray);
int ret = XGDMatrixSetFloatInfo(handle, field, (float const *)array, len); auto str = xgboost::linalg::Make1dInterface(array, len);
int ret = XGDMatrixSetInfoFromInterface(handle, field, str.c_str());
JVM_CHECK_CALL(ret); JVM_CHECK_CALL(ret);
//release //release
if (field) jenv->ReleaseStringUTFChars(jfield, field); if (field) jenv->ReleaseStringUTFChars(jfield, field);
@ -427,7 +428,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetUIntIn
const char* field = jenv->GetStringUTFChars(jfield, 0); const char* field = jenv->GetStringUTFChars(jfield, 0);
jint* array = jenv->GetIntArrayElements(jarray, NULL); jint* array = jenv->GetIntArrayElements(jarray, NULL);
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray); bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray);
int ret = XGDMatrixSetUIntInfo(handle, (char const *)field, (unsigned int const *)array, len); auto str = xgboost::linalg::Make1dInterface(array, len);
int ret = XGDMatrixSetInfoFromInterface(handle, field, str.c_str());
JVM_CHECK_CALL(ret); JVM_CHECK_CALL(ret);
//release //release
if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field); if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field);
@ -730,8 +732,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredictFr
if (jmargin) { if (jmargin) {
margin = jenv->GetFloatArrayElements(jmargin, nullptr); margin = jenv->GetFloatArrayElements(jmargin, nullptr);
JVM_CHECK_CALL(XGProxyDMatrixCreate(&proxy)); JVM_CHECK_CALL(XGProxyDMatrixCreate(&proxy));
JVM_CHECK_CALL( auto str = xgboost::linalg::Make1dInterface(margin, jenv->GetArrayLength(jmargin));
XGDMatrixSetFloatInfo(proxy, "base_margin", margin, jenv->GetArrayLength(jmargin))); JVM_CHECK_CALL(XGDMatrixSetInfoFromInterface(proxy, "base_margin", str.c_str()));
} }
bst_ulong const *out_shape; bst_ulong const *out_shape;


@ -89,19 +89,15 @@ Coll *FederatedColl::MakeCUDAVar() {
[[nodiscard]] Result FederatedColl::Broadcast(Comm const &comm, common::Span<std::int8_t> data, [[nodiscard]] Result FederatedColl::Broadcast(Comm const &comm, common::Span<std::int8_t> data,
std::int32_t root) { std::int32_t root) {
if (comm.Rank() == root) { return BroadcastImpl(comm, &this->sequence_number_, data, root);
return BroadcastImpl(comm, &this->sequence_number_, data, root);
} else {
return BroadcastImpl(comm, &this->sequence_number_, data, root);
}
} }
[[nodiscard]] Result FederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data, [[nodiscard]] Result FederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data) {
std::int64_t size) {
using namespace federated; // NOLINT using namespace federated; // NOLINT
auto fed = dynamic_cast<FederatedComm const *>(&comm); auto fed = dynamic_cast<FederatedComm const *>(&comm);
CHECK(fed); CHECK(fed);
auto stub = fed->Handle(); auto stub = fed->Handle();
auto size = data.size_bytes() / comm.World();
auto offset = comm.Rank() * size; auto offset = comm.Rank() * size;
auto segment = data.subspan(offset, size); auto segment = data.subspan(offset, size);
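
With the explicit size parameter dropped, Allgather derives each rank's slice from the buffer itself: size = size_bytes / world and offset = rank * size. A tiny sketch of that arithmetic, assuming the buffer divides evenly across ranks as the code above requires:

#include <cassert>
#include <cstddef>

int main() {
  std::size_t const total = 64;            // bytes in the shared gather buffer
  std::size_t const world = 4;             // number of workers
  std::size_t const rank = 2;              // this worker
  std::size_t const size = total / world;  // bytes each rank contributes
  std::size_t const offset = rank * size;  // where this rank's segment starts
  assert(size == 16 && offset == 32);
  return 0;
}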


@ -53,8 +53,7 @@ Coll *FederatedColl::MakeCUDAVar() {
}; };
} }
[[nodiscard]] Result CUDAFederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data, [[nodiscard]] Result CUDAFederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data) {
std::int64_t size) {
auto cufed = dynamic_cast<CUDAFederatedComm const *>(&comm); auto cufed = dynamic_cast<CUDAFederatedComm const *>(&comm);
CHECK(cufed); CHECK(cufed);
std::vector<std::int8_t> h_data(data.size()); std::vector<std::int8_t> h_data(data.size());
@ -63,7 +62,7 @@ Coll *FederatedColl::MakeCUDAVar() {
return GetCUDAResult( return GetCUDAResult(
cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost)); cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost));
} << [&] { } << [&] {
return p_impl_->Allgather(comm, common::Span{h_data.data(), h_data.size()}, size); return p_impl_->Allgather(comm, common::Span{h_data.data(), h_data.size()});
} << [&] { } << [&] {
return GetCUDAResult(cudaMemcpyAsync(data.data(), h_data.data(), data.size(), return GetCUDAResult(cudaMemcpyAsync(data.data(), h_data.data(), data.size(),
cudaMemcpyHostToDevice, cufed->Stream())); cudaMemcpyHostToDevice, cufed->Stream()));


@ -1,5 +1,5 @@
/** /**
* Copyright 2023, XGBoost contributors * Copyright 2023-2024, XGBoost contributors
*/ */
#include "../../src/collective/comm.h" // for Comm, Coll #include "../../src/collective/comm.h" // for Comm, Coll
#include "federated_coll.h" // for FederatedColl #include "federated_coll.h" // for FederatedColl
@ -16,8 +16,7 @@ class CUDAFederatedColl : public Coll {
ArrayInterfaceHandler::Type type, Op op) override; ArrayInterfaceHandler::Type type, Op op) override;
[[nodiscard]] Result Broadcast(Comm const &comm, common::Span<std::int8_t> data, [[nodiscard]] Result Broadcast(Comm const &comm, common::Span<std::int8_t> data,
std::int32_t root) override; std::int32_t root) override;
[[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data, [[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data) override;
std::int64_t size) override;
[[nodiscard]] Result AllgatherV(Comm const &comm, common::Span<std::int8_t const> data, [[nodiscard]] Result AllgatherV(Comm const &comm, common::Span<std::int8_t const> data,
common::Span<std::int64_t const> sizes, common::Span<std::int64_t const> sizes,
common::Span<std::int64_t> recv_segments, common::Span<std::int64_t> recv_segments,


@ -1,12 +1,9 @@
/** /**
* Copyright 2023, XGBoost contributors * Copyright 2023-2024, XGBoost contributors
*/ */
#pragma once #pragma once
#include "../../src/collective/coll.h" // for Coll #include "../../src/collective/coll.h" // for Coll
#include "../../src/collective/comm.h" // for Comm #include "../../src/collective/comm.h" // for Comm
#include "../../src/common/io.h" // for ReadAll
#include "../../src/common/json_utils.h" // for OptionalArg
#include "xgboost/json.h" // for Json
namespace xgboost::collective { namespace xgboost::collective {
class FederatedColl : public Coll { class FederatedColl : public Coll {
@ -20,8 +17,7 @@ class FederatedColl : public Coll {
ArrayInterfaceHandler::Type type, Op op) override; ArrayInterfaceHandler::Type type, Op op) override;
[[nodiscard]] Result Broadcast(Comm const &comm, common::Span<std::int8_t> data, [[nodiscard]] Result Broadcast(Comm const &comm, common::Span<std::int8_t> data,
std::int32_t root) override; std::int32_t root) override;
[[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data, [[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data) override;
std::int64_t) override;
[[nodiscard]] Result AllgatherV(Comm const &comm, common::Span<std::int8_t const> data, [[nodiscard]] Result AllgatherV(Comm const &comm, common::Span<std::int8_t const> data,
common::Span<std::int64_t const> sizes, common::Span<std::int64_t const> sizes,
common::Span<std::int64_t> recv_segments, common::Span<std::int64_t> recv_segments,


@ -1,5 +1,5 @@
/** /**
* Copyright 2023, XGBoost Contributors * Copyright 2023-2024, XGBoost Contributors
*/ */
#pragma once #pragma once
@ -9,7 +9,6 @@
#include "../../src/common/device_helpers.cuh" // for CUDAStreamView #include "../../src/common/device_helpers.cuh" // for CUDAStreamView
#include "federated_comm.h" // for FederatedComm #include "federated_comm.h" // for FederatedComm
#include "xgboost/context.h" // for Context #include "xgboost/context.h" // for Context
#include "xgboost/logging.h"
namespace xgboost::collective { namespace xgboost::collective {
class CUDAFederatedComm : public FederatedComm { class CUDAFederatedComm : public FederatedComm {


@ -1,5 +1,5 @@
/** /**
* Copyright 2023, XGBoost contributors * Copyright 2023-2024, XGBoost contributors
*/ */
#pragma once #pragma once
@ -11,7 +11,6 @@
#include <string> // for string #include <string> // for string
#include "../../src/collective/comm.h" // for HostComm #include "../../src/collective/comm.h" // for HostComm
#include "../../src/common/json_utils.h" // for OptionalArg
#include "xgboost/json.h" #include "xgboost/json.h"
namespace xgboost::collective { namespace xgboost::collective {
@ -51,6 +50,10 @@ class FederatedComm : public HostComm {
std::int32_t rank) { std::int32_t rank) {
this->Init(host, port, world, rank, {}, {}, {}); this->Init(host, port, world, rank, {}, {}, {});
} }
[[nodiscard]] Result Shutdown() final {
this->ResetState();
return Success();
}
~FederatedComm() override { stub_.reset(); } ~FederatedComm() override { stub_.reset(); }
[[nodiscard]] std::shared_ptr<Channel> Chan(std::int32_t) const override { [[nodiscard]] std::shared_ptr<Channel> Chan(std::int32_t) const override {
@ -65,5 +68,13 @@ class FederatedComm : public HostComm {
[[nodiscard]] federated::Federated::Stub* Handle() const { return stub_.get(); } [[nodiscard]] federated::Federated::Stub* Handle() const { return stub_.get(); }
[[nodiscard]] Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const override; [[nodiscard]] Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const override;
/**
* @brief Get a string ID for the current process.
*/
[[nodiscard]] Result ProcessorName(std::string* out) const final {
auto rank = this->Rank();
*out = "rank:" + std::to_string(rank);
return Success();
};
}; };
} // namespace xgboost::collective } // namespace xgboost::collective


@ -1,22 +1,18 @@
/** /**
* Copyright 2022-2023, XGBoost contributors * Copyright 2022-2024, XGBoost contributors
*/ */
#pragma once #pragma once
#include <federated.old.grpc.pb.h> #include <federated.old.grpc.pb.h>
#include <cstdint> // for int32_t #include <cstdint> // for int32_t
#include <future> // for future
#include "../../src/collective/in_memory_handler.h" #include "../../src/collective/in_memory_handler.h"
#include "../../src/collective/tracker.h" // for Tracker
#include "xgboost/collective/result.h" // for Result
namespace xgboost::federated { namespace xgboost::federated {
class FederatedService final : public Federated::Service { class FederatedService final : public Federated::Service {
public: public:
explicit FederatedService(std::int32_t world_size) explicit FederatedService(std::int32_t world_size) : handler_{world_size} {}
: handler_{static_cast<std::size_t>(world_size)} {}
grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request, grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
AllgatherReply* reply) override; AllgatherReply* reply) override;


@ -125,14 +125,14 @@ Result FederatedTracker::Shutdown() {
[[nodiscard]] Json FederatedTracker::WorkerArgs() const { [[nodiscard]] Json FederatedTracker::WorkerArgs() const {
auto rc = this->WaitUntilReady(); auto rc = this->WaitUntilReady();
CHECK(rc.OK()) << rc.Report(); SafeColl(rc);
std::string host; std::string host;
rc = GetHostAddress(&host); rc = GetHostAddress(&host);
CHECK(rc.OK()); CHECK(rc.OK());
Json args{Object{}}; Json args{Object{}};
args["DMLC_TRACKER_URI"] = String{host}; args["dmlc_tracker_uri"] = String{host};
args["DMLC_TRACKER_PORT"] = this->Port(); args["dmlc_tracker_port"] = this->Port();
return args; return args;
} }
} // namespace xgboost::collective } // namespace xgboost::collective


@ -17,8 +17,7 @@ namespace xgboost::collective {
namespace federated { namespace federated {
class FederatedService final : public Federated::Service { class FederatedService final : public Federated::Service {
public: public:
explicit FederatedService(std::int32_t world_size) explicit FederatedService(std::int32_t world_size) : handler_{world_size} {}
: handler_{static_cast<std::size_t>(world_size)} {}
grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request, grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
AllgatherReply* reply) override; AllgatherReply* reply) override;


@ -0,0 +1,334 @@
/*!
* Copyright 2017-2023 by Contributors
* \file hist_util.cc
*/
#include <vector>
#include <limits>
#include <algorithm>
#include "../data/gradient_index.h"
#include "hist_util.h"
#include <CL/sycl.hpp>
namespace xgboost {
namespace sycl {
namespace common {
/*!
* \brief Fill histogram with zeroes
*/
template<typename GradientSumT>
void InitHist(::sycl::queue qu, GHistRow<GradientSumT, MemoryType::on_device>* hist,
size_t size, ::sycl::event* event) {
*event = qu.fill(hist->Begin(),
xgboost::detail::GradientPairInternal<GradientSumT>(), size, *event);
}
template void InitHist(::sycl::queue qu,
GHistRow<float, MemoryType::on_device>* hist,
size_t size, ::sycl::event* event);
template void InitHist(::sycl::queue qu,
GHistRow<double, MemoryType::on_device>* hist,
size_t size, ::sycl::event* event);
/*!
* \brief Compute Subtraction: dst = src1 - src2
*/
template<typename GradientSumT>
::sycl::event SubtractionHist(::sycl::queue qu,
GHistRow<GradientSumT, MemoryType::on_device>* dst,
const GHistRow<GradientSumT, MemoryType::on_device>& src1,
const GHistRow<GradientSumT, MemoryType::on_device>& src2,
size_t size, ::sycl::event event_priv) {
GradientSumT* pdst = reinterpret_cast<GradientSumT*>(dst->Data());
const GradientSumT* psrc1 = reinterpret_cast<const GradientSumT*>(src1.DataConst());
const GradientSumT* psrc2 = reinterpret_cast<const GradientSumT*>(src2.DataConst());
auto event_final = qu.submit([&](::sycl::handler& cgh) {
cgh.depends_on(event_priv);
cgh.parallel_for<>(::sycl::range<1>(2 * size), [pdst, psrc1, psrc2](::sycl::item<1> pid) {
const size_t i = pid.get_id(0);
pdst[i] = psrc1[i] - psrc2[i];
});
});
return event_final;
}
template ::sycl::event SubtractionHist(::sycl::queue qu,
GHistRow<float, MemoryType::on_device>* dst,
const GHistRow<float, MemoryType::on_device>& src1,
const GHistRow<float, MemoryType::on_device>& src2,
size_t size, ::sycl::event event_priv);
template ::sycl::event SubtractionHist(::sycl::queue qu,
GHistRow<double, MemoryType::on_device>* dst,
const GHistRow<double, MemoryType::on_device>& src1,
const GHistRow<double, MemoryType::on_device>& src2,
size_t size, ::sycl::event event_priv);
// Kernel with buffer using
template<typename FPType, typename BinIdxType, bool isDense>
::sycl::event BuildHistKernel(::sycl::queue qu,
const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
const RowSetCollection::Elem& row_indices,
const GHistIndexMatrix& gmat,
GHistRow<FPType, MemoryType::on_device>* hist,
GHistRow<FPType, MemoryType::on_device>* hist_buffer,
::sycl::event event_priv) {
const size_t size = row_indices.Size();
const size_t* rid = row_indices.begin;
const size_t n_columns = isDense ? gmat.nfeatures : gmat.row_stride;
const GradientPair::ValueT* pgh =
reinterpret_cast<const GradientPair::ValueT*>(gpair_device.DataConst());
const BinIdxType* gradient_index = gmat.index.data<BinIdxType>();
const uint32_t* offsets = gmat.index.Offset();
FPType* hist_data = reinterpret_cast<FPType*>(hist->Data());
const size_t nbins = gmat.nbins;
const size_t max_work_group_size =
qu.get_device().get_info<::sycl::info::device::max_work_group_size>();
const size_t work_group_size = n_columns < max_work_group_size ? n_columns : max_work_group_size;
const size_t max_nblocks = hist_buffer->Size() / (nbins * 2);
const size_t min_block_size = 128;
size_t nblocks = std::min(max_nblocks, size / min_block_size + !!(size % min_block_size));
const size_t block_size = size / nblocks + !!(size % nblocks);
FPType* hist_buffer_data = reinterpret_cast<FPType*>(hist_buffer->Data());
auto event_fill = qu.fill(hist_buffer_data, FPType(0), nblocks * nbins * 2, event_priv);
auto event_main = qu.submit([&](::sycl::handler& cgh) {
cgh.depends_on(event_fill);
cgh.parallel_for<>(::sycl::nd_range<2>(::sycl::range<2>(nblocks, work_group_size),
::sycl::range<2>(1, work_group_size)),
[=](::sycl::nd_item<2> pid) {
size_t block = pid.get_global_id(0);
size_t feat = pid.get_global_id(1);
FPType* hist_local = hist_buffer_data + block * nbins * 2;
for (size_t idx = 0; idx < block_size; ++idx) {
size_t i = block * block_size + idx;
if (i < size) {
const size_t icol_start = n_columns * rid[i];
const size_t idx_gh = rid[i];
pid.barrier(::sycl::access::fence_space::local_space);
const BinIdxType* gr_index_local = gradient_index + icol_start;
for (size_t j = feat; j < n_columns; j += work_group_size) {
uint32_t idx_bin = static_cast<uint32_t>(gr_index_local[j]);
if constexpr (isDense) {
idx_bin += offsets[j];
}
if (idx_bin < nbins) {
hist_local[2 * idx_bin] += pgh[2 * idx_gh];
hist_local[2 * idx_bin+1] += pgh[2 * idx_gh+1];
}
}
}
}
});
});
auto event_save = qu.submit([&](::sycl::handler& cgh) {
cgh.depends_on(event_main);
cgh.parallel_for<>(::sycl::range<1>(nbins), [=](::sycl::item<1> pid) {
size_t idx_bin = pid.get_id(0);
FPType gsum = 0.0f;
FPType hsum = 0.0f;
for (size_t j = 0; j < nblocks; ++j) {
gsum += hist_buffer_data[j * nbins * 2 + 2 * idx_bin];
hsum += hist_buffer_data[j * nbins * 2 + 2 * idx_bin + 1];
}
hist_data[2 * idx_bin] = gsum;
hist_data[2 * idx_bin + 1] = hsum;
});
});
return event_save;
}
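
The buffered kernel above is a two-phase histogram: each block first accumulates into its own private copy (event_main), then the copies are reduced bin by bin (event_save), trading memory for reduced contention. A plain-C++ sketch of the same strategy with simplified shapes, one counter per bin instead of interleaved gradient/hessian pairs:

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  std::size_t const nbins = 4, nblocks = 2;
  std::vector<std::size_t> bin_of_row{0, 1, 1, 3, 2, 0};  // quantized feature values
  std::vector<double> partial(nblocks * nbins, 0.0);      // one private histogram per block
  std::size_t const block_size = (bin_of_row.size() + nblocks - 1) / nblocks;
  for (std::size_t b = 0; b < nblocks; ++b) {             // phase 1: private accumulation
    auto end = std::min((b + 1) * block_size, bin_of_row.size());
    for (std::size_t i = b * block_size; i < end; ++i) {
      partial[b * nbins + bin_of_row[i]] += 1.0;          // the kernel adds grad/hess here
    }
  }
  std::vector<double> hist(nbins, 0.0);
  for (std::size_t bin = 0; bin < nbins; ++bin) {         // phase 2: reduce over blocks
    for (std::size_t b = 0; b < nblocks; ++b) {
      hist[bin] += partial[b * nbins + bin];
    }
  }
  assert(hist[0] == 2 && hist[1] == 2 && hist[2] == 1 && hist[3] == 1);
  return 0;
}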
// Kernel with atomic using
template<typename FPType, typename BinIdxType, bool isDense>
::sycl::event BuildHistKernel(::sycl::queue qu,
const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
const RowSetCollection::Elem& row_indices,
const GHistIndexMatrix& gmat,
GHistRow<FPType, MemoryType::on_device>* hist,
::sycl::event event_priv) {
const size_t size = row_indices.Size();
const size_t* rid = row_indices.begin;
const size_t n_columns = isDense ? gmat.nfeatures : gmat.row_stride;
const GradientPair::ValueT* pgh =
reinterpret_cast<const GradientPair::ValueT*>(gpair_device.DataConst());
const BinIdxType* gradient_index = gmat.index.data<BinIdxType>();
const uint32_t* offsets = gmat.index.Offset();
FPType* hist_data = reinterpret_cast<FPType*>(hist->Data());
const size_t nbins = gmat.nbins;
const size_t max_work_group_size =
qu.get_device().get_info<::sycl::info::device::max_work_group_size>();
const size_t feat_local = n_columns < max_work_group_size ? n_columns : max_work_group_size;
auto event_fill = qu.fill(hist_data, FPType(0), nbins * 2, event_priv);
auto event_main = qu.submit([&](::sycl::handler& cgh) {
cgh.depends_on(event_fill);
cgh.parallel_for<>(::sycl::range<2>(size, feat_local),
[=](::sycl::item<2> pid) {
size_t i = pid.get_id(0);
size_t feat = pid.get_id(1);
const size_t icol_start = n_columns * rid[i];
const size_t idx_gh = rid[i];
const BinIdxType* gr_index_local = gradient_index + icol_start;
for (size_t j = feat; j < n_columns; j += feat_local) {
uint32_t idx_bin = static_cast<uint32_t>(gr_index_local[j]);
if constexpr (isDense) {
idx_bin += offsets[j];
}
if (idx_bin < nbins) {
AtomicRef<FPType> gsum(hist_data[2 * idx_bin]);
AtomicRef<FPType> hsum(hist_data[2 * idx_bin + 1]);
gsum.fetch_add(pgh[2 * idx_gh]);
hsum.fetch_add(pgh[2 * idx_gh + 1]);
}
}
});
});
return event_main;
}
template<typename FPType, typename BinIdxType>
::sycl::event BuildHistDispatchKernel(
::sycl::queue qu,
const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
const RowSetCollection::Elem& row_indices,
const GHistIndexMatrix& gmat,
GHistRow<FPType, MemoryType::on_device>* hist,
bool isDense,
GHistRow<FPType, MemoryType::on_device>* hist_buffer,
::sycl::event events_priv,
bool force_atomic_use) {
const size_t size = row_indices.Size();
const size_t n_columns = isDense ? gmat.nfeatures : gmat.row_stride;
const size_t nbins = gmat.nbins;
// maximum cycle size for which atomics are still effective
const size_t max_cycle_size_atomics = nbins;
const size_t cycle_size = size;
// TODO(razdoburdin): replace the ad-hoc dispatching criterion with a more suitable one
bool use_atomic = (size < nbins) || (gmat.max_num_bins == gmat.nbins / n_columns);
// force_atomic_use flag is used only for testing
use_atomic = use_atomic || force_atomic_use;
if (!use_atomic) {
if (isDense) {
return BuildHistKernel<FPType, BinIdxType, true>(qu, gpair_device, row_indices,
gmat, hist, hist_buffer,
events_priv);
} else {
return BuildHistKernel<FPType, uint32_t, false>(qu, gpair_device, row_indices,
gmat, hist, hist_buffer,
events_priv);
}
} else {
if (isDense) {
return BuildHistKernel<FPType, BinIdxType, true>(qu, gpair_device, row_indices,
gmat, hist, events_priv);
} else {
return BuildHistKernel<FPType, uint32_t, false>(qu, gpair_device, row_indices,
gmat, hist, events_priv);
}
}
}
template<typename FPType>
::sycl::event BuildHistKernel(::sycl::queue qu,
const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
const RowSetCollection::Elem& row_indices,
const GHistIndexMatrix& gmat, const bool isDense,
GHistRow<FPType, MemoryType::on_device>* hist,
GHistRow<FPType, MemoryType::on_device>* hist_buffer,
::sycl::event event_priv,
bool force_atomic_use) {
const bool is_dense = isDense;
switch (gmat.index.GetBinTypeSize()) {
case BinTypeSize::kUint8BinsTypeSize:
return BuildHistDispatchKernel<FPType, uint8_t>(qu, gpair_device, row_indices,
gmat, hist, is_dense, hist_buffer,
event_priv, force_atomic_use);
break;
case BinTypeSize::kUint16BinsTypeSize:
return BuildHistDispatchKernel<FPType, uint16_t>(qu, gpair_device, row_indices,
gmat, hist, is_dense, hist_buffer,
event_priv, force_atomic_use);
break;
case BinTypeSize::kUint32BinsTypeSize:
return BuildHistDispatchKernel<FPType, uint32_t>(qu, gpair_device, row_indices,
gmat, hist, is_dense, hist_buffer,
event_priv, force_atomic_use);
break;
default:
CHECK(false); // no default behavior
}
}
template <typename GradientSumT>
::sycl::event GHistBuilder<GradientSumT>::BuildHist(
const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
const RowSetCollection::Elem& row_indices,
const GHistIndexMatrix &gmat,
GHistRowT<MemoryType::on_device>* hist,
bool isDense,
GHistRowT<MemoryType::on_device>* hist_buffer,
::sycl::event event_priv,
bool force_atomic_use) {
return BuildHistKernel<GradientSumT>(qu_, gpair_device, row_indices, gmat,
isDense, hist, hist_buffer, event_priv,
force_atomic_use);
}
template
::sycl::event GHistBuilder<float>::BuildHist(
const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
const RowSetCollection::Elem& row_indices,
const GHistIndexMatrix& gmat,
GHistRow<float, MemoryType::on_device>* hist,
bool isDense,
GHistRow<float, MemoryType::on_device>* hist_buffer,
::sycl::event event_priv,
bool force_atomic_use);
template
::sycl::event GHistBuilder<double>::BuildHist(
const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
const RowSetCollection::Elem& row_indices,
const GHistIndexMatrix& gmat,
GHistRow<double, MemoryType::on_device>* hist,
bool isDense,
GHistRow<double, MemoryType::on_device>* hist_buffer,
::sycl::event event_priv,
bool force_atomic_use);
template<typename GradientSumT>
void GHistBuilder<GradientSumT>::SubtractionTrick(GHistRowT<MemoryType::on_device>* self,
const GHistRowT<MemoryType::on_device>& sibling,
const GHistRowT<MemoryType::on_device>& parent) {
const size_t size = self->Size();
CHECK_EQ(sibling.Size(), size);
CHECK_EQ(parent.Size(), size);
SubtractionHist(qu_, self, parent, sibling, size, ::sycl::event());
}
template
void GHistBuilder<float>::SubtractionTrick(GHistRow<float, MemoryType::on_device>* self,
const GHistRow<float, MemoryType::on_device>& sibling,
const GHistRow<float, MemoryType::on_device>& parent);
template
void GHistBuilder<double>::SubtractionTrick(GHistRow<double, MemoryType::on_device>* self,
const GHistRow<double, MemoryType::on_device>& sibling,
const GHistRow<double, MemoryType::on_device>& parent);
} // namespace common
} // namespace sycl
} // namespace xgboost
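
SubtractionTrick exploits histogram additivity: a parent's histogram equals the sum of its children's, so the second child's histogram can be obtained as parent minus sibling instead of another pass over the rows. A minimal sketch:

#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  std::vector<double> parent{5, 3, 7};  // bin sums over all rows in the node
  std::vector<double> left{2, 1, 4};    // built from the smaller child's rows
  std::vector<double> right(parent.size());
  for (std::size_t i = 0; i < parent.size(); ++i) {
    right[i] = parent[i] - left[i];     // no second pass over the data
  }
  assert(right[0] == 3 && right[1] == 2 && right[2] == 3);
  return 0;
}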


@ -0,0 +1,89 @@
/*!
* Copyright 2017-2023 by Contributors
* \file hist_util.h
*/
#ifndef PLUGIN_SYCL_COMMON_HIST_UTIL_H_
#define PLUGIN_SYCL_COMMON_HIST_UTIL_H_
#include <vector>
#include <unordered_map>
#include <memory>
#include "../data.h"
#include "row_set.h"
#include "../../src/common/hist_util.h"
#include "../data/gradient_index.h"
#include <CL/sycl.hpp>
namespace xgboost {
namespace sycl {
namespace common {
template<typename GradientSumT, MemoryType memory_type = MemoryType::shared>
using GHistRow = USMVector<xgboost::detail::GradientPairInternal<GradientSumT>, memory_type>;
using BinTypeSize = ::xgboost::common::BinTypeSize;
class ColumnMatrix;
/*!
* \brief Fill histogram with zeroes
*/
template<typename GradientSumT>
void InitHist(::sycl::queue qu,
GHistRow<GradientSumT, MemoryType::on_device>* hist,
size_t size, ::sycl::event* event);
/*!
* \brief Compute subtraction: dst = src1 - src2
*/
template<typename GradientSumT>
::sycl::event SubtractionHist(::sycl::queue qu,
GHistRow<GradientSumT, MemoryType::on_device>* dst,
const GHistRow<GradientSumT, MemoryType::on_device>& src1,
const GHistRow<GradientSumT, MemoryType::on_device>& src2,
size_t size, ::sycl::event event_priv);
/*!
* \brief Builder for histograms of gradient statistics
*/
template<typename GradientSumT>
class GHistBuilder {
public:
template<MemoryType memory_type = MemoryType::shared>
using GHistRowT = GHistRow<GradientSumT, memory_type>;
GHistBuilder() = default;
GHistBuilder(::sycl::queue qu, uint32_t nbins) : qu_{qu}, nbins_{nbins} {}
// Construct a histogram via histogram aggregation
::sycl::event BuildHist(const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
const RowSetCollection::Elem& row_indices,
const GHistIndexMatrix& gmat,
GHistRowT<MemoryType::on_device>* HistCollection,
bool isDense,
GHistRowT<MemoryType::on_device>* hist_buffer,
::sycl::event event,
bool force_atomic_use = false);
// Construct a histogram via subtraction trick
void SubtractionTrick(GHistRowT<MemoryType::on_device>* self,
const GHistRowT<MemoryType::on_device>& sibling,
const GHistRowT<MemoryType::on_device>& parent);
uint32_t GetNumBins() const {
return nbins_;
}
private:
/*! \brief Number of all bins over all features */
uint32_t nbins_ { 0 };
::sycl::queue qu_;
};
} // namespace common
} // namespace sycl
} // namespace xgboost
#endif // PLUGIN_SYCL_COMMON_HIST_UTIL_H_


@ -0,0 +1,55 @@
/*!
* Copyright 2017-2024 by Contributors
* \file updater_quantile_hist.cc
*/
#include <vector>
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
#pragma GCC diagnostic ignored "-W#pragma-messages"
#include "xgboost/tree_updater.h"
#pragma GCC diagnostic pop
#include "xgboost/logging.h"
#include "updater_quantile_hist.h"
#include "../data.h"
namespace xgboost {
namespace sycl {
namespace tree {
DMLC_REGISTRY_FILE_TAG(updater_quantile_hist_sycl);
DMLC_REGISTER_PARAMETER(HistMakerTrainParam);
void QuantileHistMaker::Configure(const Args& args) {
const DeviceOrd device_spec = ctx_->Device();
qu_ = device_manager.GetQueue(device_spec);
param_.UpdateAllowUnknown(args);
hist_maker_param_.UpdateAllowUnknown(args);
}
void QuantileHistMaker::Update(xgboost::tree::TrainParam const *param,
linalg::Matrix<GradientPair>* gpair,
DMatrix *dmat,
xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
const std::vector<RegTree *> &trees) {
LOG(FATAL) << "Not implemented yet";
}
bool QuantileHistMaker::UpdatePredictionCache(const DMatrix* data,
linalg::MatrixView<float> out_preds) {
LOG(FATAL) << "Not implemented yet";
}
XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker_sycl")
.describe("Grow tree using quantized histogram with SYCL.")
.set_body(
[](Context const* ctx, ObjInfo const * task) {
return new QuantileHistMaker(ctx, task);
});
} // namespace tree
} // namespace sycl
} // namespace xgboost


@ -0,0 +1,91 @@
/*!
* Copyright 2017-2024 by Contributors
* \file updater_quantile_hist.h
*/
#ifndef PLUGIN_SYCL_TREE_UPDATER_QUANTILE_HIST_H_
#define PLUGIN_SYCL_TREE_UPDATER_QUANTILE_HIST_H_
#include <dmlc/timer.h>
#include <xgboost/tree_updater.h>
#include <vector>
#include "../data/gradient_index.h"
#include "../common/hist_util.h"
#include "../common/row_set.h"
#include "../common/partition_builder.h"
#include "split_evaluator.h"
#include "../device_manager.h"
#include "xgboost/data.h"
#include "xgboost/json.h"
#include "../../src/tree/constraints.h"
#include "../../src/common/random.h"
namespace xgboost {
namespace sycl {
namespace tree {
// training parameters specific to this algorithm
struct HistMakerTrainParam
: public XGBoostParameter<HistMakerTrainParam> {
bool single_precision_histogram = false;
// declare parameters
DMLC_DECLARE_PARAMETER(HistMakerTrainParam) {
DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe(
"Use single precision to build histograms.");
}
};
/*! \brief Construct a tree using quantized feature values with the SYCL backend. */
class QuantileHistMaker: public TreeUpdater {
public:
QuantileHistMaker(Context const* ctx, ObjInfo const * task) :
TreeUpdater(ctx), task_{task} {
updater_monitor_.Init("SYCLQuantileHistMaker");
}
void Configure(const Args& args) override;
void Update(xgboost::tree::TrainParam const *param,
linalg::Matrix<GradientPair>* gpair,
DMatrix* dmat,
xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
const std::vector<RegTree*>& trees) override;
bool UpdatePredictionCache(const DMatrix* data,
linalg::MatrixView<float> out_preds) override;
void LoadConfig(Json const& in) override {
auto const& config = get<Object const>(in);
FromJson(config.at("train_param"), &this->param_);
FromJson(config.at("sycl_hist_train_param"), &this->hist_maker_param_);
}
void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
out["train_param"] = ToJson(param_);
out["sycl_hist_train_param"] = ToJson(hist_maker_param_);
}
char const* Name() const override {
return "grow_quantile_histmaker_sycl";
}
protected:
HistMakerTrainParam hist_maker_param_;
// training parameter
xgboost::tree::TrainParam param_;
xgboost::common::Monitor updater_monitor_;
::sycl::queue qu_;
DeviceManager device_manager;
ObjInfo const *task_{nullptr};
};
} // namespace tree
} // namespace sycl
} // namespace xgboost
#endif // PLUGIN_SYCL_TREE_UPDATER_QUANTILE_HIST_H_


@ -909,9 +909,19 @@ def _transform_cudf_df(
enable_categorical: bool, enable_categorical: bool,
) -> Tuple[ctypes.c_void_p, list, Optional[FeatureNames], Optional[FeatureTypes]]: ) -> Tuple[ctypes.c_void_p, list, Optional[FeatureNames], Optional[FeatureTypes]]:
try: try:
from cudf.api.types import is_categorical_dtype from cudf.api.types import is_bool_dtype, is_categorical_dtype
except ImportError: except ImportError:
from cudf.utils.dtypes import is_categorical_dtype from cudf.utils.dtypes import is_categorical_dtype
from pandas.api.types import is_bool_dtype
# Work around https://github.com/dmlc/xgboost/issues/10181
if _is_cudf_ser(data):
if is_bool_dtype(data.dtype):
data = data.astype(np.uint8)
else:
data = data.astype(
{col: np.uint8 for col in data.select_dtypes(include="bool")}
)
if _is_cudf_ser(data): if _is_cudf_ser(data):
dtypes = [data.dtype] dtypes = [data.dtype]


@ -347,15 +347,14 @@ class _SparkXGBParams(
predict_params[param.name] = self.getOrDefault(param) predict_params[param.name] = self.getOrDefault(param)
return predict_params return predict_params
def _validate_gpu_params(self) -> None: def _validate_gpu_params(
self, spark_version: str, conf: SparkConf, is_local: bool = False
) -> None:
"""Validate the gpu parameters and gpu configurations""" """Validate the gpu parameters and gpu configurations"""
if self._run_on_gpu(): if self._run_on_gpu():
ss = _get_spark_session() if is_local:
sc = ss.sparkContext # Supporting GPU training in Spark local mode is just for debugging
if _is_local(sc):
# Support GPU training in Spark local mode is just for debugging # purposes, so it's okay to print the below warning instead of
# purposes, so it's okay for printing the below warning instead of # purposes, so it's okay for printing the below warning instead of
# checking the real gpu numbers and raising the exception. # checking the real gpu numbers and raising the exception.
get_logger(self.__class__.__name__).warning( get_logger(self.__class__.__name__).warning(
@ -364,33 +363,41 @@ class _SparkXGBParams(
self.getOrDefault(self.num_workers), self.getOrDefault(self.num_workers),
) )
else: else:
executor_gpus = sc.getConf().get("spark.executor.resource.gpu.amount") executor_gpus = conf.get("spark.executor.resource.gpu.amount")
if executor_gpus is None: if executor_gpus is None:
raise ValueError( raise ValueError(
"The `spark.executor.resource.gpu.amount` is required for training" "The `spark.executor.resource.gpu.amount` is required for training"
" on GPU." " on GPU."
) )
gpu_per_task = conf.get("spark.task.resource.gpu.amount")
if not ( if gpu_per_task is not None and float(gpu_per_task) > 1.0:
ss.version >= "3.4.0" get_logger(self.__class__.__name__).warning(
and _is_standalone_or_localcluster(sc.getConf()) "The configuration assigns %s GPUs to each Spark task, but each "
"XGBoost training task only utilizes 1 GPU, which will lead to "
"unnecessary GPU waste",
gpu_per_task,
)
# For Spark 3.5.1+, task stage-level scheduling is supported on
# YARN/K8s/Standalone/Local cluster.
# From 3.4.0 to 3.5.0, Spark only supports task stage-level scheduling on
# Standalone/Local cluster.
# For Spark below 3.4.0, task stage-level scheduling is not supported.
#
# With stage-level scheduling, spark.task.resource.gpu.amount is not required
# to be set explicitly. Otherwise, spark.task.resource.gpu.amount is required
# and must be set to 1.0.
if spark_version < "3.4.0" or (
"3.4.0" <= spark_version < "3.5.1"
and not _is_standalone_or_localcluster(conf)
): ):
# We will enable stage-level scheduling in spark 3.4.0+ which doesn't
# require spark.task.resource.gpu.amount to be set explicitly
gpu_per_task = sc.getConf().get("spark.task.resource.gpu.amount")
if gpu_per_task is not None: if gpu_per_task is not None:
if float(gpu_per_task) < 1.0: if float(gpu_per_task) < 1.0:
raise ValueError( raise ValueError(
"XGBoost doesn't support GPU fractional configurations. " "XGBoost doesn't support GPU fractional configurations. Please set "
"Please set `spark.task.resource.gpu.amount=spark.executor" "`spark.task.resource.gpu.amount=spark.executor.resource.gpu."
".resource.gpu.amount`" "amount`. To enable GPU fractional configurations, you can try "
) "standalone/localcluster with spark 3.4.0+ and"
"YARN/K8S with spark 3.5.1+"
if float(gpu_per_task) > 1.0:
get_logger(self.__class__.__name__).warning(
"%s GPUs for each Spark task is configured, but each "
"XGBoost training task uses only 1 GPU.",
gpu_per_task,
) )
else: else:
raise ValueError( raise ValueError(
@ -475,7 +482,9 @@ class _SparkXGBParams(
"`pyspark.ml.linalg.Vector` type." "`pyspark.ml.linalg.Vector` type."
) )
self._validate_gpu_params() ss = _get_spark_session()
sc = ss.sparkContext
self._validate_gpu_params(ss.version, sc.getConf(), _is_local(sc))
def _run_on_gpu(self) -> bool: def _run_on_gpu(self) -> bool:
"""If train or transform on the gpu according to the parameters""" """If train or transform on the gpu according to the parameters"""
@ -925,10 +934,14 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
) )
return True return True
if not _is_standalone_or_localcluster(conf): if (
"3.4.0" <= spark_version < "3.5.1"
and not _is_standalone_or_localcluster(conf)
):
self.logger.info( self.logger.info(
"Stage-level scheduling in xgboost requires spark standalone or " "For %s, Stage-level scheduling in xgboost requires spark standalone "
"local-cluster mode" "or local-cluster mode",
spark_version,
) )
return True return True
@ -980,7 +993,9 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
"""Try to enable stage-level scheduling""" """Try to enable stage-level scheduling"""
ss = _get_spark_session() ss = _get_spark_session()
conf = ss.sparkContext.getConf() conf = ss.sparkContext.getConf()
if self._skip_stage_level_scheduling(ss.version, conf): if _is_local(ss.sparkContext) or self._skip_stage_level_scheduling(
ss.version, conf
):
return rdd return rdd
# executor_cores will not be None # executor_cores will not be None
@ -1052,6 +1067,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
dev_ordinal = None dev_ordinal = None
use_qdm = _can_use_qdm(booster_params.get("tree_method", None)) use_qdm = _can_use_qdm(booster_params.get("tree_method", None))
verbosity = booster_params.get("verbosity", 1)
msg = "Training on CPUs" msg = "Training on CPUs"
if run_on_gpu: if run_on_gpu:
dev_ordinal = ( dev_ordinal = (
@ -1089,15 +1105,16 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
evals_result: Dict[str, Any] = {} evals_result: Dict[str, Any] = {}
with CommunicatorContext(context, **_rabit_args): with CommunicatorContext(context, **_rabit_args):
dtrain, dvalid = create_dmatrix_from_partitions( with xgboost.config_context(verbosity=verbosity):
pandas_df_iter, dtrain, dvalid = create_dmatrix_from_partitions(
feature_prop.features_cols_names, pandas_df_iter,
dev_ordinal, feature_prop.features_cols_names,
use_qdm, dev_ordinal,
dmatrix_kwargs, use_qdm,
enable_sparse_data_optim=feature_prop.enable_sparse_data_optim, dmatrix_kwargs,
has_validation_col=feature_prop.has_validation_col, enable_sparse_data_optim=feature_prop.enable_sparse_data_optim,
) has_validation_col=feature_prop.has_validation_col,
)
if dvalid is not None: if dvalid is not None:
dval = [(dtrain, "training"), (dvalid, "validation")] dval = [(dtrain, "training"), (dvalid, "validation")]
else: else:


@ -14,7 +14,8 @@ import pyspark
from pyspark import BarrierTaskContext, SparkConf, SparkContext, SparkFiles, TaskContext from pyspark import BarrierTaskContext, SparkConf, SparkContext, SparkFiles, TaskContext
from pyspark.sql.session import SparkSession from pyspark.sql.session import SparkSession
from xgboost import Booster, XGBModel, collective from xgboost import Booster, XGBModel
from xgboost.collective import CommunicatorContext as CCtx
from xgboost.tracker import RabitTracker from xgboost.tracker import RabitTracker
@ -42,22 +43,12 @@ def _get_default_params_from_func(
return filtered_params_dict return filtered_params_dict
class CommunicatorContext: class CommunicatorContext(CCtx): # pylint: disable=too-few-public-methods
"""A context controlling collective communicator initialization and finalization. """Context with PySpark specific task ID."""
This isn't specificially necessary (note Part 3), but it is more understandable
coding-wise.
"""
def __init__(self, context: BarrierTaskContext, **args: Any) -> None: def __init__(self, context: BarrierTaskContext, **args: Any) -> None:
self.args = args args["DMLC_TASK_ID"] = str(context.partitionId())
self.args["DMLC_TASK_ID"] = str(context.partitionId()) super().__init__(**args)
def __enter__(self) -> None:
collective.init(**self.args)
def __exit__(self, *args: Any) -> None:
collective.finalize()
def _start_tracker(context: BarrierTaskContext, n_workers: int) -> Dict[str, Any]: def _start_tracker(context: BarrierTaskContext, n_workers: int) -> Dict[str, Any]:


@ -429,8 +429,8 @@ def make_categorical(
categories = np.arange(0, n_categories) categories = np.arange(0, n_categories)
for col in df.columns: for col in df.columns:
if rng.binomial(1, cat_ratio, size=1)[0] == 1: if rng.binomial(1, cat_ratio, size=1)[0] == 1:
df.loc[:, col] = df[col].astype("category") df[col] = df[col].astype("category")
df.loc[:, col] = df[col].cat.set_categories(categories) df[col] = df[col].cat.set_categories(categories)
if sparsity > 0.0: if sparsity > 0.0:
for i in range(n_features): for i in range(n_features):


@ -100,6 +100,24 @@ std::enable_if_t<std::is_integral_v<E>, xgboost::collective::Result> PollError(E
if ((revents & POLLNVAL) != 0) { if ((revents & POLLNVAL) != 0) {
return xgboost::system::FailWithCode("Invalid polling request."); return xgboost::system::FailWithCode("Invalid polling request.");
} }
if ((revents & POLLHUP) != 0) {
// Excerpt from the Linux manual:
//
// Note that when reading from a channel such as a pipe or a stream socket, this event
// merely indicates that the peer closed its end of the channel. Subsequent reads from
// the channel will return 0 (end of file) only after all outstanding data in the
// channel has been consumed.
//
// We don't usually have a barrier for exiting workers; it's normal for one end to
// exit while the other is still reading data.
return xgboost::collective::Success();
}
#if defined(POLLRDHUP)
// Linux only flag
if ((revents & POLLRDHUP) != 0) {
return xgboost::system::FailWithCode("Poll hung up on the other end.");
}
#endif // defined(POLLRDHUP)
return xgboost::collective::Success(); return xgboost::collective::Success();
} }
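The comment above captures a subtle poll() contract: POLLHUP only says the peer closed its end, not that the buffered data is gone. A small standalone POSIX sketch (not XGBoost code) demonstrating it with a pipe:

#include <poll.h>
#include <unistd.h>
#include <cstdio>

int main() {
  int fds[2];
  if (pipe(fds) != 0) return 1;
  const char msg[] = "pending bytes";
  write(fds[1], msg, sizeof(msg));
  close(fds[1]);  // peer closes its end of the channel

  pollfd pfd{fds[0], POLLIN, 0};
  poll(&pfd, 1, 1000);
  if (pfd.revents & POLLHUP) {
    std::puts("POLLHUP is set, but buffered data is still readable");
  }

  char buf[64];
  ssize_t n;
  while ((n = read(fds[0], buf, sizeof buf)) > 0) {
    std::printf("read %zd bytes\n", n);
  }
  // read() reports 0 (end of file) only after the channel is drained.
  std::puts(n == 0 ? "clean EOF" : "read error");
  close(fds[0]);
  return 0;
}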
@ -179,9 +197,11 @@ struct PollHelper {
} }
std::int32_t ret = PollImpl(fdset.data(), fdset.size(), timeout); std::int32_t ret = PollImpl(fdset.data(), fdset.size(), timeout);
if (ret == 0) { if (ret == 0) {
return xgboost::collective::Fail("Poll timeout.", std::make_error_code(std::errc::timed_out)); return xgboost::collective::Fail(
"Poll timeout:" + std::to_string(timeout.count()) + " seconds.",
std::make_error_code(std::errc::timed_out));
} else if (ret < 0) { } else if (ret < 0) {
return xgboost::system::FailWithCode("Poll failed."); return xgboost::system::FailWithCode("Poll failed, nfds:" + std::to_string(fdset.size()));
} }
for (auto& pfd : fdset) { for (auto& pfd : fdset) {

View File

@ -132,7 +132,7 @@ bool AllreduceBase::Shutdown() {
try { try {
for (auto &all_link : all_links) { for (auto &all_link : all_links) {
if (!all_link.sock.IsClosed()) { if (!all_link.sock.IsClosed()) {
all_link.sock.Close(); SafeColl(all_link.sock.Close());
} }
} }
all_links.clear(); all_links.clear();
@ -146,7 +146,7 @@ bool AllreduceBase::Shutdown() {
LOG(FATAL) << rc.Report(); LOG(FATAL) << rc.Report();
} }
tracker.Send(xgboost::StringView{"shutdown"}); tracker.Send(xgboost::StringView{"shutdown"});
tracker.Close(); SafeColl(tracker.Close());
xgboost::system::SocketFinalize(); xgboost::system::SocketFinalize();
return true; return true;
} catch (std::exception const &e) { } catch (std::exception const &e) {
@ -167,7 +167,7 @@ void AllreduceBase::TrackerPrint(const std::string &msg) {
tracker.Send(xgboost::StringView{"print"}); tracker.Send(xgboost::StringView{"print"});
tracker.Send(xgboost::StringView{msg}); tracker.Send(xgboost::StringView{msg});
tracker.Close(); SafeColl(tracker.Close());
} }
// util to parse data with unit suffix // util to parse data with unit suffix
@ -332,15 +332,15 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
auto sock_listen{xgboost::collective::TCPSocket::Create(tracker.Domain())}; auto sock_listen{xgboost::collective::TCPSocket::Create(tracker.Domain())};
// create listening socket // create listening socket
int port = sock_listen.BindHost(); std::int32_t port{0};
utils::Check(port != -1, "ReConnectLink fail to bind the ports specified"); SafeColl(sock_listen.BindHost(&port));
sock_listen.Listen(); SafeColl(sock_listen.Listen());
// get number of to connect and number of to accept nodes from tracker // get number of to connect and number of to accept nodes from tracker
int num_conn, num_accept, num_error = 1; int num_conn, num_accept, num_error = 1;
do { do {
for (auto & all_link : all_links) { for (auto & all_link : all_links) {
all_link.sock.Close(); SafeColl(all_link.sock.Close());
} }
// tracker construct goodset // tracker construct goodset
Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn), Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
@ -352,7 +352,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
LinkRecord r; LinkRecord r;
int hport, hrank; int hport, hrank;
std::string hname; std::string hname;
tracker.Recv(&hname); SafeColl(tracker.Recv(&hname));
Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "ReConnectLink failure 9"); Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "ReConnectLink failure 9");
Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), "ReConnectLink failure 10"); Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), "ReConnectLink failure 10");
// connect to peer // connect to peer
@ -360,7 +360,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
timeout_sec, &r.sock) timeout_sec, &r.sock)
.OK()) { .OK()) {
num_error += 1; num_error += 1;
r.sock.Close(); SafeColl(r.sock.Close());
continue; continue;
} }
Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank), Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
@ -386,7 +386,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
// send back socket listening port to tracker // send back socket listening port to tracker
Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port), "ReConnectLink failure 14"); Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port), "ReConnectLink failure 14");
// close connection to tracker // close connection to tracker
tracker.Close(); SafeColl(tracker.Close());
// listen to incoming links // listen to incoming links
for (int i = 0; i < num_accept; ++i) { for (int i = 0; i < num_accept; ++i) {
@ -408,7 +408,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
} }
if (!match) all_links.emplace_back(std::move(r)); if (!match) all_links.emplace_back(std::move(r));
} }
sock_listen.Close(); SafeColl(sock_listen.Close());
this->parent_index = -1; this->parent_index = -1;
// setup tree links and ring structure // setup tree links and ring structure
@ -635,7 +635,7 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
Recv(sendrecvbuf + size_down_in, total_size - size_down_in); Recv(sendrecvbuf + size_down_in, total_size - size_down_in);
if (len == 0) { if (len == 0) {
links[parent_index].sock.Close(); SafeColl(links[parent_index].sock.Close());
return ReportError(&links[parent_index], kRecvZeroLen); return ReportError(&links[parent_index], kRecvZeroLen);
} }
if (len != -1) { if (len != -1) {

View File

@ -270,7 +270,7 @@ class AllreduceBase : public IEngine {
ssize_t len = sock.Recv(buffer_head + offset, nmax); ssize_t len = sock.Recv(buffer_head + offset, nmax);
// length equals 0, remote disconnected // length equals 0, remote disconnected
if (len == 0) { if (len == 0) {
sock.Close(); return kRecvZeroLen; SafeColl(sock.Close()); return kRecvZeroLen;
} }
if (len == -1) return Errno2Return(); if (len == -1) return Errno2Return();
size_read += static_cast<size_t>(len); size_read += static_cast<size_t>(len);
@ -289,7 +289,7 @@ class AllreduceBase : public IEngine {
ssize_t len = sock.Recv(p + size_read, max_size - size_read); ssize_t len = sock.Recv(p + size_read, max_size - size_read);
// length equals 0, remote disconnected // length equals 0, remote disconnected
if (len == 0) { if (len == 0) {
sock.Close(); return kRecvZeroLen; SafeColl(sock.Close()); return kRecvZeroLen;
} }
if (len == -1) return Errno2Return(); if (len == -1) return Errno2Return();
size_read += static_cast<size_t>(len); size_read += static_cast<size_t>(len);

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2014-2024 by XGBoost Contributors * Copyright 2014-2024, XGBoost Contributors
*/ */
#include "xgboost/c_api.h" #include "xgboost/c_api.h"
@ -617,8 +617,8 @@ XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle, const char *field, const
API_BEGIN(); API_BEGIN();
CHECK_HANDLE(); CHECK_HANDLE();
xgboost_CHECK_C_ARG_PTR(field); xgboost_CHECK_C_ARG_PTR(field);
auto const& p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle); auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
p_fmat->SetInfo(field, info, xgboost::DataType::kFloat32, len); p_fmat->SetInfo(field, linalg::Make1dInterface(info, len));
API_END(); API_END();
} }
@ -637,8 +637,9 @@ XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle, const char *field, const
API_BEGIN(); API_BEGIN();
CHECK_HANDLE(); CHECK_HANDLE();
xgboost_CHECK_C_ARG_PTR(field); xgboost_CHECK_C_ARG_PTR(field);
LOG(WARNING) << error::DeprecatedFunc(__func__, "2.1.0", "XGDMatrixSetInfoFromInterface");
auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle); auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
p_fmat->SetInfo(field, info, xgboost::DataType::kUInt32, len); p_fmat->SetInfo(field, linalg::Make1dInterface(info, len));
API_END(); API_END();
} }
@ -682,19 +683,52 @@ XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field, void
xgboost::bst_ulong size, int type) { xgboost::bst_ulong size, int type) {
API_BEGIN(); API_BEGIN();
CHECK_HANDLE(); CHECK_HANDLE();
LOG(WARNING) << error::DeprecatedFunc(__func__, "2.1.0", "XGDMatrixSetInfoFromInterface");
auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle); auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
CHECK(type >= 1 && type <= 4); CHECK(type >= 1 && type <= 4);
xgboost_CHECK_C_ARG_PTR(field); xgboost_CHECK_C_ARG_PTR(field);
p_fmat->SetInfo(field, data, static_cast<DataType>(type), size);
API_END();
}
XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle, const unsigned *group, xgboost::bst_ulong len) { Context ctx;
API_BEGIN(); auto dtype = static_cast<DataType>(type);
CHECK_HANDLE(); std::string str;
LOG(WARNING) << "XGDMatrixSetGroup is deprecated, use `XGDMatrixSetUIntInfo` instead."; auto proc = [&](auto cast_d_ptr) {
auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle); using T = std::remove_pointer_t<decltype(cast_d_ptr)>;
p_fmat->SetInfo("group", group, xgboost::DataType::kUInt32, len); auto t = linalg::TensorView<T, 1>(
common::Span<T>{cast_d_ptr, static_cast<typename common::Span<T>::index_type>(size)},
{size}, DeviceOrd::CPU());
CHECK(t.CContiguous());
Json iface{linalg::ArrayInterface(t)};
CHECK(ArrayInterface<1>{iface}.is_contiguous);
str = Json::Dump(iface);
return str;
};
// Legacy code using XGBoost dtype, which is a small subset of array interface types.
switch (dtype) {
case xgboost::DataType::kFloat32: {
auto cast_ptr = reinterpret_cast<const float *>(data);
p_fmat->Info().SetInfo(ctx, field, proc(cast_ptr));
break;
}
case xgboost::DataType::kDouble: {
auto cast_ptr = reinterpret_cast<const double *>(data);
p_fmat->Info().SetInfo(ctx, field, proc(cast_ptr));
break;
}
case xgboost::DataType::kUInt32: {
auto cast_ptr = reinterpret_cast<const uint32_t *>(data);
p_fmat->Info().SetInfo(ctx, field, proc(cast_ptr));
break;
}
case xgboost::DataType::kUInt64: {
auto cast_ptr = reinterpret_cast<const uint64_t *>(data);
p_fmat->Info().SetInfo(ctx, field, proc(cast_ptr));
break;
}
default:
LOG(FATAL) << "Unknown data type" << static_cast<uint8_t>(dtype);
}
API_END(); API_END();
} }
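For context, the proc lambda above replaces the old dtype/pointer pair with a 1-d array-interface JSON document. Under the standard __array_interface__ convention, a contiguous float32 buffer is described roughly as below; the exact JSON produced by linalg::ArrayInterface may differ in field details, so this sketch only prints the general shape of it:

#include <cstdint>
#include <cstdio>

int main() {
  const float data[4] = {1.f, 2.f, 3.f, 4.f};
  // "data": [address, read-only flag]; "typestr": "<f4" = little-endian float32.
  std::printf(
      "{\"data\": [%llu, true], \"shape\": [4], \"typestr\": \"<f4\", \"version\": 3}\n",
      static_cast<unsigned long long>(reinterpret_cast<std::uintptr_t>(data)));
  return 0;
}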
@ -990,7 +1024,7 @@ XGB_DLL int XGBoosterBoostOneIter(BoosterHandle handle, DMatrixHandle dtrain, bs
bst_float *hess, xgboost::bst_ulong len) { bst_float *hess, xgboost::bst_ulong len) {
API_BEGIN(); API_BEGIN();
CHECK_HANDLE(); CHECK_HANDLE();
error::DeprecatedFunc(__func__, "2.1.0", "XGBoosterTrainOneIter"); LOG(WARNING) << error::DeprecatedFunc(__func__, "2.1.0", "XGBoosterTrainOneIter");
auto *learner = static_cast<Learner *>(handle); auto *learner = static_cast<Learner *>(handle);
auto ctx = learner->Ctx()->MakeCPU(); auto ctx = learner->Ctx()->MakeCPU();

View File

@ -1,17 +1,18 @@
/** /**
* Copyright 2021-2023, XGBoost Contributors * Copyright 2021-2024, XGBoost Contributors
*/ */
#ifndef XGBOOST_C_API_C_API_UTILS_H_ #ifndef XGBOOST_C_API_C_API_UTILS_H_
#define XGBOOST_C_API_C_API_UTILS_H_ #define XGBOOST_C_API_C_API_UTILS_H_
#include <algorithm> #include <algorithm> // for min
#include <cstddef> #include <cstddef> // for size_t
#include <functional> #include <functional> // for multiplies
#include <memory> // for shared_ptr #include <memory> // for shared_ptr
#include <string> // for string #include <numeric> // for accumulate
#include <tuple> // for make_tuple #include <string> // for string
#include <utility> // for move #include <tuple> // for make_tuple
#include <vector> #include <utility> // for move
#include <vector> // for vector
#include "../common/json_utils.h" // for TypeCheck #include "../common/json_utils.h" // for TypeCheck
#include "xgboost/c_api.h" #include "xgboost/c_api.h"

View File

@ -1,15 +1,17 @@
/** /**
* Copyright 2023, XGBoost Contributors * Copyright 2023-2024, XGBoost Contributors
*/ */
#include <chrono> // for seconds #include <chrono> // for seconds
#include <cstddef> // for size_t
#include <future> // for future #include <future> // for future
#include <memory> // for unique_ptr #include <memory> // for unique_ptr
#include <string> // for string #include <string> // for string
#include <thread> // for sleep_for
#include <type_traits> // for is_same_v, remove_pointer_t #include <type_traits> // for is_same_v, remove_pointer_t
#include <utility> // for pair #include <utility> // for pair
#include "../collective/comm.h" // for DefaultTimeoutSec
#include "../collective/tracker.h" // for RabitTracker #include "../collective/tracker.h" // for RabitTracker
#include "../common/timer.h" // for Timer
#include "c_api_error.h" // for API_BEGIN #include "c_api_error.h" // for API_BEGIN
#include "xgboost/c_api.h" #include "xgboost/c_api.h"
#include "xgboost/collective/result.h" // for Result #include "xgboost/collective/result.h" // for Result
@ -26,7 +28,7 @@ using namespace xgboost; // NOLINT
namespace { namespace {
using TrackerHandleT = using TrackerHandleT =
std::pair<std::unique_ptr<collective::Tracker>, std::shared_future<collective::Result>>; std::pair<std::shared_ptr<collective::Tracker>, std::shared_future<collective::Result>>;
TrackerHandleT *GetTrackerHandle(TrackerHandle handle) { TrackerHandleT *GetTrackerHandle(TrackerHandle handle) {
xgboost_CHECK_C_ARG_PTR(handle); xgboost_CHECK_C_ARG_PTR(handle);
@ -40,17 +42,29 @@ struct CollAPIEntry {
}; };
using CollAPIThreadLocalStore = dmlc::ThreadLocalStore<CollAPIEntry>; using CollAPIThreadLocalStore = dmlc::ThreadLocalStore<CollAPIEntry>;
void WaitImpl(TrackerHandleT *ptr) { void WaitImpl(TrackerHandleT *ptr, std::chrono::seconds timeout) {
std::chrono::seconds wait_for{100}; constexpr std::int64_t kDft{collective::DefaultTimeoutSec()};
std::chrono::seconds wait_for{timeout.count() != 0 ? std::min(kDft, timeout.count()) : kDft};
common::Timer timer;
timer.Start();
auto ref = ptr->first;  // hold a reference so that the free function doesn't delete it while waiting.
auto fut = ptr->second; auto fut = ptr->second;
while (fut.valid()) { while (fut.valid()) {
auto res = fut.wait_for(wait_for); auto res = fut.wait_for(wait_for);
CHECK(res != std::future_status::deferred); CHECK(res != std::future_status::deferred);
if (res == std::future_status::ready) { if (res == std::future_status::ready) {
auto const &rc = ptr->second.get(); auto const &rc = ptr->second.get();
CHECK(rc.OK()) << rc.Report(); collective::SafeColl(rc);
break; break;
} }
if (timer.Duration() > timeout && timeout.count() != 0) {
collective::SafeColl(collective::Fail("Timeout waiting for the tracker."));
}
} }
} }
} // namespace } // namespace
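WaitImpl above is the usual slice-and-poll idiom: wait on the future in bounded chunks so a wall-clock deadline can be checked between slices. The same pattern in a standalone form (names are illustrative, not the XGBoost API):

#include <chrono>
#include <future>
#include <iostream>
#include <thread>

int main() {
  using namespace std::chrono_literals;
  std::promise<int> p;
  std::shared_future<int> fut = p.get_future().share();
  std::thread worker([&p] {
    std::this_thread::sleep_for(200ms);  // simulate the tracker finishing
    p.set_value(42);
  });

  auto deadline = std::chrono::steady_clock::now() + 1s;
  while (fut.valid()) {
    // Wait in small slices so the deadline check runs regularly.
    if (fut.wait_for(50ms) == std::future_status::ready) {
      std::cout << "tracker result: " << fut.get() << "\n";
      break;
    }
    if (std::chrono::steady_clock::now() > deadline) {
      std::cout << "timed out waiting for the tracker\n";
      break;
    }
  }
  worker.join();
  return 0;
}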
@ -62,15 +76,15 @@ XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle) {
Json jconfig = Json::Load(config); Json jconfig = Json::Load(config);
auto type = RequiredArg<String>(jconfig, "dmlc_communicator", __func__); auto type = RequiredArg<String>(jconfig, "dmlc_communicator", __func__);
std::unique_ptr<collective::Tracker> tptr; std::shared_ptr<collective::Tracker> tptr;
if (type == "federated") { if (type == "federated") {
#if defined(XGBOOST_USE_FEDERATED) #if defined(XGBOOST_USE_FEDERATED)
tptr = std::make_unique<collective::FederatedTracker>(jconfig); tptr = std::make_shared<collective::FederatedTracker>(jconfig);
#else #else
LOG(FATAL) << error::NoFederated(); LOG(FATAL) << error::NoFederated();
#endif // defined(XGBOOST_USE_FEDERATED) #endif // defined(XGBOOST_USE_FEDERATED)
} else if (type == "rabit") { } else if (type == "rabit") {
tptr = std::make_unique<collective::RabitTracker>(jconfig); tptr = std::make_shared<collective::RabitTracker>(jconfig);
} else { } else {
LOG(FATAL) << "Unknown communicator:" << type; LOG(FATAL) << "Unknown communicator:" << type;
} }
@ -93,7 +107,7 @@ XGB_DLL int XGTrackerWorkerArgs(TrackerHandle handle, char const **args) {
API_END(); API_END();
} }
XGB_DLL int XGTrackerRun(TrackerHandle handle) { XGB_DLL int XGTrackerRun(TrackerHandle handle, char const *) {
API_BEGIN(); API_BEGIN();
auto *ptr = GetTrackerHandle(handle); auto *ptr = GetTrackerHandle(handle);
CHECK(!ptr->second.valid()) << "Tracker is already running."; CHECK(!ptr->second.valid()) << "Tracker is already running.";
@ -101,19 +115,39 @@ XGB_DLL int XGTrackerRun(TrackerHandle handle) {
API_END(); API_END();
} }
XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config) { XGB_DLL int XGTrackerWaitFor(TrackerHandle handle, char const *config) {
API_BEGIN(); API_BEGIN();
auto *ptr = GetTrackerHandle(handle); auto *ptr = GetTrackerHandle(handle);
xgboost_CHECK_C_ARG_PTR(config); xgboost_CHECK_C_ARG_PTR(config);
auto jconfig = Json::Load(StringView{config}); auto jconfig = Json::Load(StringView{config});
WaitImpl(ptr); // Internally, 0 indicates no timeout, which is the default since we don't want to
// interrupt the model training.
xgboost_CHECK_C_ARG_PTR(config);
auto timeout = OptionalArg<Integer>(jconfig, "timeout", std::int64_t{0});
WaitImpl(ptr, std::chrono::seconds{timeout});
API_END(); API_END();
} }
XGB_DLL int XGTrackerFree(TrackerHandle handle) { XGB_DLL int XGTrackerFree(TrackerHandle handle) {
API_BEGIN(); API_BEGIN();
using namespace std::chrono_literals; // NOLINT
auto *ptr = GetTrackerHandle(handle); auto *ptr = GetTrackerHandle(handle);
WaitImpl(ptr); ptr->first->Stop();
// The wait is not necessary since we just called stop; we reuse the function to do
// any potential cleanups.
WaitImpl(ptr, ptr->first->Timeout());
common::Timer timer;
timer.Start();
// Make sure no one else is waiting on the tracker.
while (!ptr->first.unique()) {
auto ela = timer.Duration().count();
if (ela > ptr->first->Timeout().count()) {
LOG(WARNING) << "Time out " << ptr->first->Timeout().count()
<< " seconds reached for TrackerFree, killing the tracker.";
break;
}
std::this_thread::sleep_for(64ms);
}
delete ptr; delete ptr;
API_END(); API_END();
} }
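The loop in XGTrackerFree is plain shared_ptr reference counting: deletion is delayed until this handle is the only remaining owner, with the tracker's timeout as an upper bound. A reduced sketch of the pattern (use_count is only approximate across threads, which is why the deadline matters):

#include <chrono>
#include <iostream>
#include <memory>
#include <thread>

int main() {
  using namespace std::chrono_literals;
  auto tracker = std::make_shared<int>(0);
  std::thread user([copy = tracker]() mutable {
    std::this_thread::sleep_for(100ms);  // pretend to still be using the tracker
    copy.reset();                        // drop the extra reference
  });

  auto deadline = std::chrono::steady_clock::now() + 1s;
  while (tracker.use_count() > 1) {  // someone else still holds a reference
    if (std::chrono::steady_clock::now() > deadline) {
      std::cout << "deadline reached, freeing anyway\n";
      break;
    }
    std::this_thread::sleep_for(10ms);
  }
  user.join();
  std::cout << "safe to delete the handle\n";
  return 0;
}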

View File

@ -165,7 +165,7 @@ template <typename T>
T GlobalRatio(Context const* ctx, MetaInfo const& info, T dividend, T divisor) { T GlobalRatio(Context const* ctx, MetaInfo const& info, T dividend, T divisor) {
std::array<T, 2> results{dividend, divisor}; std::array<T, 2> results{dividend, divisor};
auto rc = GlobalSum(ctx, info, linalg::MakeVec(results.data(), results.size())); auto rc = GlobalSum(ctx, info, linalg::MakeVec(results.data(), results.size()));
collective::SafeColl(rc); SafeColl(rc);
std::tie(dividend, divisor) = std::tuple_cat(results); std::tie(dividend, divisor) = std::tuple_cat(results);
if (divisor <= 0) { if (divisor <= 0) {
return std::numeric_limits<T>::quiet_NaN(); return std::numeric_limits<T>::quiet_NaN();

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2023, XGBoost Contributors * Copyright 2023-2024, XGBoost Contributors
*/ */
#include "allgather.h" #include "allgather.h"
@ -7,6 +7,7 @@
#include <cstddef> // for size_t #include <cstddef> // for size_t
#include <cstdint> // for int8_t, int32_t, int64_t #include <cstdint> // for int8_t, int32_t, int64_t
#include <memory> // for shared_ptr #include <memory> // for shared_ptr
#include <utility> // for move
#include "broadcast.h" #include "broadcast.h"
#include "comm.h" // for Comm, Channel #include "comm.h" // for Comm, Channel
@ -29,16 +30,22 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
auto rc = Success() << [&] { auto rc = Success() << [&] {
auto send_rank = (rank + world - r + worker_off) % world; auto send_rank = (rank + world - r + worker_off) % world;
auto send_off = send_rank * segment_size; auto send_off = send_rank * segment_size;
send_off = std::min(send_off, data.size_bytes()); bool is_last_segment = send_rank == (world - 1);
auto send_seg = data.subspan(send_off, std::min(segment_size, data.size_bytes() - send_off)); auto send_nbytes = is_last_segment ? (data.size_bytes() - send_off) : segment_size;
auto send_seg = data.subspan(send_off, send_nbytes);
CHECK_NE(send_seg.size(), 0);
return next_ch->SendAll(send_seg.data(), send_seg.size_bytes()); return next_ch->SendAll(send_seg.data(), send_seg.size_bytes());
} << [&] { } << [&] {
auto recv_rank = (rank + world - r - 1 + worker_off) % world; auto recv_rank = (rank + world - r - 1 + worker_off) % world;
auto recv_off = recv_rank * segment_size; auto recv_off = recv_rank * segment_size;
recv_off = std::min(recv_off, data.size_bytes()); bool is_last_segment = recv_rank == (world - 1);
auto recv_seg = data.subspan(recv_off, std::min(segment_size, data.size_bytes() - recv_off)); auto recv_nbytes = is_last_segment ? (data.size_bytes() - recv_off) : segment_size;
auto recv_seg = data.subspan(recv_off, recv_nbytes);
CHECK_NE(recv_seg.size(), 0);
return prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes()); return prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
} << [&] { return prev_ch->Block(); }; } << [&] {
return comm.Block();
};
if (!rc.OK()) { if (!rc.OK()) {
return rc; return rc;
} }
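The pattern in both lambdas is the same: every rank owns segment_size bytes, except the last rank, which also absorbs whatever the round-down left over. Worked through for 10 four-byte elements on 4 workers:

#include <cstddef>
#include <cstdio>

int main() {
  const int world = 4;
  const std::size_t n = 10;                             // elements
  const std::size_t total = n * sizeof(float);          // 40 bytes
  const std::size_t seg = (n / world) * sizeof(float);  // round-down: 8 bytes
  for (int rank = 0; rank < world; ++rank) {
    std::size_t off = rank * seg;
    bool is_last = rank == world - 1;
    std::size_t nbytes = is_last ? total - off : seg;  // last rank takes the remainder
    std::printf("rank %d: offset %zu, %zu bytes\n", rank, off, nbytes);
  }
  // ranks 0-2 own 8 bytes each; rank 3 owns 16 bytes, so no segment is empty.
  return 0;
}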
@ -91,7 +98,9 @@ namespace detail {
auto recv_size = sizes[recv_rank]; auto recv_size = sizes[recv_rank];
auto recv_seg = erased_result.subspan(recv_off, recv_size); auto recv_seg = erased_result.subspan(recv_off, recv_size);
return prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes()); return prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
} << [&] { return prev_ch->Block(); }; } << [&] {
return prev_ch->Block();
};
if (!rc.OK()) { if (!rc.OK()) {
return rc; return rc;
} }
@ -99,4 +108,47 @@ namespace detail {
return comm.Block(); return comm.Block();
} }
} // namespace detail } // namespace detail
[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
Context const* ctx, CommGroup const& comm, std::vector<std::vector<char>> const& input) {
auto n_inputs = input.size();
std::vector<std::int64_t> sizes(n_inputs);
std::transform(input.cbegin(), input.cend(), sizes.begin(),
[](auto const& vec) { return vec.size(); });
std::vector<std::int64_t> recv_segments(comm.World() + 1, 0);
HostDeviceVector<std::int8_t> recv;
auto rc =
AllgatherV(ctx, comm, linalg::MakeVec(sizes.data(), sizes.size()), &recv_segments, &recv);
SafeColl(rc);
auto global_sizes = common::RestoreType<std::int64_t const>(recv.ConstHostSpan());
std::vector<std::int64_t> offset(global_sizes.size() + 1);
offset[0] = 0;
for (std::size_t i = 1; i < offset.size(); i++) {
offset[i] = offset[i - 1] + global_sizes[i - 1];
}
std::vector<char> collected;
for (auto const& vec : input) {
collected.insert(collected.end(), vec.cbegin(), vec.cend());
}
rc = AllgatherV(ctx, comm, linalg::MakeVec(collected.data(), collected.size()), &recv_segments,
&recv);
SafeColl(rc);
auto out = common::RestoreType<char const>(recv.ConstHostSpan());
std::vector<std::vector<char>> result;
for (std::size_t i = 1; i < offset.size(); ++i) {
std::vector<char> local(out.cbegin() + offset[i - 1], out.cbegin() + offset[i]);
result.emplace_back(std::move(local));
}
return result;
}
[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
Context const* ctx, std::vector<std::vector<char>> const& input) {
return VectorAllgatherV(ctx, *GlobalCommGroup(), input);
}
} // namespace xgboost::collective } // namespace xgboost::collective
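VectorAllgatherV is thus a two-round protocol: round one gathers the per-vector sizes, round two gathers the concatenated payloads, and a prefix sum over the sizes cuts the flat buffer back into vectors. The reassembly step in isolation, with mock data standing in for the collective calls:

#include <cstddef>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  // Pretend these came back from the two allgather rounds.
  std::vector<long long> global_sizes{2, 3, 1};
  std::vector<char> flat{'a', 'b', 'c', 'd', 'e', 'f'};

  // offset[i] = bytes owned by vectors 0..i-1 (exclusive prefix sum).
  std::vector<long long> offset(global_sizes.size() + 1, 0);
  std::partial_sum(global_sizes.begin(), global_sizes.end(), offset.begin() + 1);

  std::vector<std::vector<char>> result;
  for (std::size_t i = 1; i < offset.size(); ++i) {
    result.emplace_back(flat.begin() + offset[i - 1], flat.begin() + offset[i]);
  }
  for (auto const& v : result) {
    std::printf("chunk of %zu bytes\n", v.size());  // 2, 3, 1
  }
  return 0;
}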

View File

@ -1,25 +1,27 @@
/** /**
* Copyright 2023, XGBoost Contributors * Copyright 2023-2024, XGBoost Contributors
*/ */
#pragma once #pragma once
#include <cstddef> // for size_t #include <cstddef> // for size_t
#include <cstdint> // for int32_t #include <cstdint> // for int32_t
#include <memory> // for shared_ptr #include <memory> // for shared_ptr
#include <numeric> // for accumulate #include <numeric> // for accumulate
#include <string> // for string
#include <type_traits> // for remove_cv_t #include <type_traits> // for remove_cv_t
#include <vector> // for vector #include <vector> // for vector
#include "../common/type.h" // for EraseType #include "../common/type.h" // for EraseType
#include "comm.h" // for Comm, Channel #include "comm.h" // for Comm, Channel
#include "comm_group.h" // for CommGroup
#include "xgboost/collective/result.h" // for Result #include "xgboost/collective/result.h" // for Result
#include "xgboost/linalg.h" #include "xgboost/linalg.h" // for MakeVec
#include "xgboost/span.h" // for Span #include "xgboost/span.h" // for Span
namespace xgboost::collective { namespace xgboost::collective {
namespace cpu_impl { namespace cpu_impl {
/** /**
* @param worker_off Segment offset. For example, if the rank 2 worker specifies * @param worker_off Segment offset. For example, if the rank 2 worker specifies
* worker_off = 1, then it owns the third segment. * worker_off = 1, then it owns the third segment (2 + 1).
*/ */
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, [[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data,
std::size_t segment_size, std::int32_t worker_off, std::size_t segment_size, std::int32_t worker_off,
@ -51,8 +53,10 @@ inline void AllgatherVOffset(common::Span<std::int64_t const> sizes,
} // namespace detail } // namespace detail
template <typename T> template <typename T>
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<T> data, std::size_t size) { [[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<T> data) {
auto n_bytes = sizeof(T) * size; // This function is also used for ring allreduce, hence we allow the last segment to be
// larger due to round-down.
auto n_bytes_per_segment = data.size_bytes() / comm.World();
auto erased = common::EraseType(data); auto erased = common::EraseType(data);
auto rank = comm.Rank(); auto rank = comm.Rank();
@ -61,7 +65,7 @@ template <typename T>
auto prev_ch = comm.Chan(prev); auto prev_ch = comm.Chan(prev);
auto next_ch = comm.Chan(next); auto next_ch = comm.Chan(next);
auto rc = cpu_impl::RingAllgather(comm, erased, n_bytes, 0, prev_ch, next_ch); auto rc = cpu_impl::RingAllgather(comm, erased, n_bytes_per_segment, 0, prev_ch, next_ch);
if (!rc.OK()) { if (!rc.OK()) {
return rc; return rc;
} }
@ -76,7 +80,7 @@ template <typename T>
std::vector<std::int64_t> sizes(world, 0); std::vector<std::int64_t> sizes(world, 0);
sizes[rank] = data.size_bytes(); sizes[rank] = data.size_bytes();
auto rc = RingAllgather(comm, common::Span{sizes.data(), sizes.size()}, 1); auto rc = RingAllgather(comm, common::Span{sizes.data(), sizes.size()});
if (!rc.OK()) { if (!rc.OK()) {
return rc; return rc;
} }
@ -98,4 +102,115 @@ template <typename T>
return detail::RingAllgatherV(comm, sizes, s_segments, erased_result); return detail::RingAllgatherV(comm, sizes, s_segments, erased_result);
} }
template <typename T>
[[nodiscard]] Result Allgather(Context const* ctx, CommGroup const& comm,
linalg::VectorView<T> data) {
if (!comm.IsDistributed()) {
return Success();
}
CHECK(data.Contiguous());
auto erased = common::EraseType(data.Values());
auto const& cctx = comm.Ctx(ctx, data.Device());
auto backend = comm.Backend(data.Device());
return backend->Allgather(cctx, erased);
}
/**
* @brief Gather all data from all workers.
*
* @param data The input and output buffer, needs to be pre-allocated by the caller.
*/
template <typename T>
[[nodiscard]] Result Allgather(Context const* ctx, linalg::VectorView<T> data) {
auto const& cg = *GlobalCommGroup();
if (data.Size() % cg.World() != 0) {
return Fail("The total number of elements should be multiple of the number of workers.");
}
return Allgather(ctx, cg, data);
}
template <typename T>
[[nodiscard]] Result AllgatherV(Context const* ctx, CommGroup const& comm,
linalg::VectorView<T> data,
std::vector<std::int64_t>* recv_segments,
HostDeviceVector<std::int8_t>* recv) {
if (!comm.IsDistributed()) {
return Success();
}
std::vector<std::int64_t> sizes(comm.World(), 0);
sizes[comm.Rank()] = data.Values().size_bytes();
auto erased_sizes = common::EraseType(common::Span{sizes.data(), sizes.size()});
auto rc = comm.Backend(DeviceOrd::CPU())
->Allgather(comm.Ctx(ctx, DeviceOrd::CPU()), erased_sizes);
if (!rc.OK()) {
return rc;
}
recv_segments->resize(sizes.size() + 1);
detail::AllgatherVOffset(sizes, common::Span{recv_segments->data(), recv_segments->size()});
auto total_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), 0LL);
recv->SetDevice(data.Device());
recv->Resize(total_bytes);
auto s_segments = common::Span{recv_segments->data(), recv_segments->size()};
auto backend = comm.Backend(data.Device());
auto erased = common::EraseType(data.Values());
return backend->AllgatherV(
comm.Ctx(ctx, data.Device()), erased, common::Span{sizes.data(), sizes.size()}, s_segments,
data.Device().IsCUDA() ? recv->DeviceSpan() : recv->HostSpan(), AllgatherVAlgo::kBcast);
}
/**
* @brief Allgather with variable length data.
*
* @param data The input data.
* @param recv_segments segment size for each worker. [0, 2, 5] means [0, 2) elements are
* from the first worker, [2, 5) elements are from the second one.
* @param recv The buffer storing the result.
*/
template <typename T>
[[nodiscard]] Result AllgatherV(Context const* ctx, linalg::VectorView<T> data,
std::vector<std::int64_t>* recv_segments,
HostDeviceVector<std::int8_t>* recv) {
return AllgatherV(ctx, *GlobalCommGroup(), data, recv_segments, recv);
}
[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
Context const* ctx, CommGroup const& comm, std::vector<std::vector<char>> const& input);
/**
* @brief Gathers variable-length data from all processes and distributes it to all processes.
*
* @param inputs All the inputs from the local worker. The number of inputs can vary
* across different workers. In addition, the size of each vector in
* the input can also vary.
*
* @return The AllgatherV result, containing vectors from all workers.
*/
[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
Context const* ctx, std::vector<std::vector<char>> const& input);
/**
* @brief Gathers variable-length strings from all processes and distributes them to all processes.
* @param input Variable-length list of variable-length strings.
*/
[[nodiscard]] inline Result AllgatherStrings(std::vector<std::string> const& input,
std::vector<std::string>* p_result) {
std::vector<std::vector<char>> inputs(input.size());
for (std::size_t i = 0; i < input.size(); ++i) {
inputs[i] = {input[i].cbegin(), input[i].cend()};
}
Context ctx;
auto out = VectorAllgatherV(&ctx, *GlobalCommGroup(), inputs);
auto& result = *p_result;
result.resize(out.size());
for (std::size_t i = 0; i < out.size(); ++i) {
result[i] = {out[i].cbegin(), out[i].cend()};
}
return Success();
}
} // namespace xgboost::collective } // namespace xgboost::collective

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2023, XGBoost Contributors * Copyright 2023-2024, XGBoost Contributors
*/ */
#include "allreduce.h" #include "allreduce.h"
@ -16,7 +16,44 @@
#include "xgboost/span.h" // for Span #include "xgboost/span.h" // for Span
namespace xgboost::collective::cpu_impl { namespace xgboost::collective::cpu_impl {
namespace {
template <typename T> template <typename T>
Result RingAllreduceSmall(Comm const& comm, common::Span<std::int8_t> data, Func const& op) {
auto rank = comm.Rank();
auto world = comm.World();
auto next_ch = comm.Chan(BootstrapNext(rank, world));
auto prev_ch = comm.Chan(BootstrapPrev(rank, world));
std::vector<std::int8_t> buffer(data.size_bytes() * world, 0);
auto s_buffer = common::Span{buffer.data(), buffer.size()};
auto offset = data.size_bytes() * rank;
auto self = s_buffer.subspan(offset, data.size_bytes());
std::copy_n(data.data(), data.size_bytes(), self.data());
auto typed = common::RestoreType<T>(s_buffer);
auto rc = RingAllgather(comm, typed);
if (!rc.OK()) {
return rc;
}
auto first = s_buffer.subspan(0, data.size_bytes());
CHECK_EQ(first.size(), data.size());
for (std::int32_t r = 1; r < world; ++r) {
auto offset = data.size_bytes() * r;
auto buf = s_buffer.subspan(offset, data.size_bytes());
op(buf, first);
}
std::copy_n(first.data(), first.size(), data.data());
return Success();
}
} // namespace
template <typename T>
// note that n_bytes_in_seg is calculated with round-down.
Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data, Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
std::size_t n_bytes_in_seg, Func const& op) { std::size_t n_bytes_in_seg, Func const& op) {
auto rank = comm.Rank(); auto rank = comm.Rank();
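RingAllreduceSmall exists because the round-down segment size hits zero when there are fewer elements than workers; the fallback gathers every rank's full buffer and reduces locally. The local reduction step, mocked with plain vectors (three workers, two floats each):

#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  const int world = 3;
  const std::size_t n = 2;  // fewer elements than workers
  // Pretend the allgather already ran: one full copy of the buffer per rank.
  std::vector<float> gathered{1, 2, 10, 20, 100, 200};

  // Accumulate every other rank's copy into the first segment.
  for (int r = 1; r < world; ++r) {
    for (std::size_t i = 0; i < n; ++i) {
      gathered[i] += gathered[r * n + i];
    }
  }
  std::printf("%g %g\n", gathered[0], gathered[1]);  // 111 222
  return 0;
}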
@ -27,33 +64,39 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
auto next_ch = comm.Chan(dst_rank); auto next_ch = comm.Chan(dst_rank);
auto prev_ch = comm.Chan(src_rank); auto prev_ch = comm.Chan(src_rank);
std::vector<std::int8_t> buffer(n_bytes_in_seg, 0); std::vector<std::int8_t> buffer(data.size_bytes() - (world - 1) * n_bytes_in_seg, 0);
auto s_buf = common::Span{buffer.data(), buffer.size()}; auto s_buf = common::Span{buffer.data(), buffer.size()};
for (std::int32_t r = 0; r < world - 1; ++r) { for (std::int32_t r = 0; r < world - 1; ++r) {
// send to ring next common::Span<std::int8_t> seg, recv_seg;
auto send_off = ((rank + world - r) % world) * n_bytes_in_seg; auto rc = Success() << [&] {
send_off = std::min(send_off, data.size_bytes()); // send to ring next
auto seg_nbytes = std::min(data.size_bytes() - send_off, n_bytes_in_seg); auto send_rank = (rank + world - r) % world;
auto send_seg = data.subspan(send_off, seg_nbytes); auto send_off = send_rank * n_bytes_in_seg;
auto rc = next_ch->SendAll(send_seg); bool is_last_segment = send_rank == (world - 1);
if (!rc.OK()) {
return rc;
}
// receive from ring prev auto seg_nbytes = is_last_segment ? data.size_bytes() - send_off : n_bytes_in_seg;
auto recv_off = ((rank + world - r - 1) % world) * n_bytes_in_seg; CHECK_EQ(seg_nbytes % sizeof(T), 0);
recv_off = std::min(recv_off, data.size_bytes());
seg_nbytes = std::min(data.size_bytes() - recv_off, n_bytes_in_seg);
CHECK_EQ(seg_nbytes % sizeof(T), 0);
auto recv_seg = data.subspan(recv_off, seg_nbytes);
auto seg = s_buf.subspan(0, recv_seg.size());
rc = std::move(rc) << [&] { return prev_ch->RecvAll(seg); } << [&] { return comm.Block(); }; auto send_seg = data.subspan(send_off, seg_nbytes);
if (!rc.OK()) { return next_ch->SendAll(send_seg);
return rc; } << [&] {
} // receive from ring prev
auto recv_rank = (rank + world - r - 1) % world;
auto recv_off = recv_rank * n_bytes_in_seg;
bool is_last_segment = recv_rank == (world - 1);
auto seg_nbytes = is_last_segment ? (data.size_bytes() - recv_off) : n_bytes_in_seg;
CHECK_EQ(seg_nbytes % sizeof(T), 0);
recv_seg = data.subspan(recv_off, seg_nbytes);
seg = s_buf.subspan(0, recv_seg.size());
return prev_ch->RecvAll(seg);
} << [&] {
return comm.Block();
};
// accumulate to recv_seg // accumulate to recv_seg
CHECK_EQ(seg.size(), recv_seg.size()); CHECK_EQ(seg.size(), recv_seg.size());
@ -68,6 +111,9 @@ Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func cons
if (comm.World() == 1) { if (comm.World() == 1) {
return Success(); return Success();
} }
if (data.size_bytes() == 0) {
return Success();
}
return DispatchDType(type, [&](auto t) { return DispatchDType(type, [&](auto t) {
using T = decltype(t); using T = decltype(t);
// Divide the data into segments according to the number of workers. // Divide the data into segments according to the number of workers.
@ -75,7 +121,11 @@ Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func cons
CHECK_EQ(data.size_bytes() % n_bytes_elem, 0); CHECK_EQ(data.size_bytes() % n_bytes_elem, 0);
auto n = data.size_bytes() / n_bytes_elem; auto n = data.size_bytes() / n_bytes_elem;
auto world = comm.World(); auto world = comm.World();
auto n_bytes_in_seg = common::DivRoundUp(n, world) * sizeof(T); if (n < static_cast<decltype(n)>(world)) {
return RingAllreduceSmall<T>(comm, data, op);
}
auto n_bytes_in_seg = (n / world) * sizeof(T);
auto rc = RingScatterReduceTyped<T>(comm, data, n_bytes_in_seg, op); auto rc = RingScatterReduceTyped<T>(comm, data, n_bytes_in_seg, op);
if (!rc.OK()) { if (!rc.OK()) {
return rc; return rc;
@ -88,7 +138,9 @@ Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func cons
return std::move(rc) << [&] { return std::move(rc) << [&] {
return RingAllgather(comm, data, n_bytes_in_seg, 1, prev_ch, next_ch); return RingAllgather(comm, data, n_bytes_in_seg, 1, prev_ch, next_ch);
} << [&] { return comm.Block(); }; } << [&] {
return comm.Block();
};
}); });
} }
} // namespace xgboost::collective::cpu_impl } // namespace xgboost::collective::cpu_impl

View File

@ -1,15 +1,18 @@
/** /**
* Copyright 2023, XGBoost Contributors * Copyright 2023-2024, XGBoost Contributors
*/ */
#pragma once #pragma once
#include <cstdint> // for int8_t #include <cstdint> // for int8_t
#include <functional> // for function #include <functional> // for function
#include <type_traits> // for is_invocable_v, enable_if_t #include <type_traits> // for is_invocable_v, enable_if_t
#include <vector> // for vector
#include "../common/type.h" // for EraseType, RestoreType #include "../common/type.h" // for EraseType, RestoreType
#include "../data/array_interface.h" // for ArrayInterfaceHandler #include "../data/array_interface.h" // for ToDType, ArrayInterfaceHandler
#include "comm.h" // for Comm, RestoreType #include "comm.h" // for Comm, RestoreType
#include "comm_group.h" // for GlobalCommGroup
#include "xgboost/collective/result.h" // for Result #include "xgboost/collective/result.h" // for Result
#include "xgboost/context.h" // for Context
#include "xgboost/span.h" // for Span #include "xgboost/span.h" // for Span
namespace xgboost::collective { namespace xgboost::collective {
@ -27,8 +30,7 @@ std::enable_if_t<std::is_invocable_v<Fn, common::Span<T const>, common::Span<T>>
auto erased = common::EraseType(data); auto erased = common::EraseType(data);
auto type = ToDType<T>::kType; auto type = ToDType<T>::kType;
auto erased_fn = [type, redop](common::Span<std::int8_t const> lhs, auto erased_fn = [redop](common::Span<std::int8_t const> lhs, common::Span<std::int8_t> out) {
common::Span<std::int8_t> out) {
CHECK_EQ(lhs.size(), out.size()) << "Invalid input for reduction."; CHECK_EQ(lhs.size(), out.size()) << "Invalid input for reduction.";
auto lhs_t = common::RestoreType<T const>(lhs); auto lhs_t = common::RestoreType<T const>(lhs);
auto rhs_t = common::RestoreType<T>(out); auto rhs_t = common::RestoreType<T>(out);
@ -37,4 +39,40 @@ std::enable_if_t<std::is_invocable_v<Fn, common::Span<T const>, common::Span<T>>
return cpu_impl::RingAllreduce(comm, erased, erased_fn, type); return cpu_impl::RingAllreduce(comm, erased, erased_fn, type);
} }
template <typename T, std::int32_t kDim>
[[nodiscard]] Result Allreduce(Context const* ctx, CommGroup const& comm,
linalg::TensorView<T, kDim> data, Op op) {
if (!comm.IsDistributed()) {
return Success();
}
CHECK(data.Contiguous());
auto erased = common::EraseType(data.Values());
auto type = ToDType<T>::kType;
auto backend = comm.Backend(data.Device());
return backend->Allreduce(comm.Ctx(ctx, data.Device()), erased, type, op);
}
template <typename T, std::int32_t kDim>
[[nodiscard]] Result Allreduce(Context const* ctx, linalg::TensorView<T, kDim> data, Op op) {
return Allreduce(ctx, *GlobalCommGroup(), data, op);
}
/**
* @brief Specialization for std::vector.
*/
template <typename T, typename Alloc>
[[nodiscard]] Result Allreduce(Context const* ctx, std::vector<T, Alloc>* data, Op op) {
return Allreduce(ctx, linalg::MakeVec(data->data(), data->size()), op);
}
/**
* @brief Specialization for scalar value.
*/
template <typename T>
[[nodiscard]] std::enable_if_t<std::is_standard_layout_v<T> && std::is_trivial_v<T>, Result>
Allreduce(Context const* ctx, T* data, Op op) {
return Allreduce(ctx, linalg::MakeVec(data, 1), op);
}
} // namespace xgboost::collective } // namespace xgboost::collective
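Together the overloads give one entry point for tensors, vectors, and scalars. A hedged sketch of calling code (compiles only inside the XGBoost source tree; the include paths and an already-initialized global communicator are assumptions of this example):

#include <cstdint>
#include <vector>

#include "../collective/allreduce.h"  // assumed in-tree path
#include "xgboost/context.h"

void ExampleSums(xgboost::Context const* ctx) {
  namespace coll = xgboost::collective;

  std::vector<double> partial{1.0, 2.0, 3.0};
  // Vector specialization: element-wise sum across all workers.
  coll::SafeColl(coll::Allreduce(ctx, &partial, coll::Op::kSum));

  std::int64_t n_local = 128;
  // Scalar specialization: a single trivially-copyable value.
  coll::SafeColl(coll::Allreduce(ctx, &n_local, coll::Op::kSum));
}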

View File

@ -1,11 +1,15 @@
/** /**
* Copyright 2023, XGBoost Contributors * Copyright 2023-2024, XGBoost Contributors
*/ */
#pragma once #pragma once
#include <cstdint> // for int32_t, int8_t #include <cstdint> // for int32_t, int8_t
#include "comm.h" // for Comm #include "../common/type.h"
#include "xgboost/collective/result.h" // for #include "comm.h" // for Comm, EraseType
#include "comm_group.h" // for CommGroup
#include "xgboost/collective/result.h" // for Result
#include "xgboost/context.h" // for Context
#include "xgboost/linalg.h" // for VectorView
#include "xgboost/span.h" // for Span #include "xgboost/span.h" // for Span
namespace xgboost::collective { namespace xgboost::collective {
@ -23,4 +27,21 @@ template <typename T>
common::Span<std::int8_t>{reinterpret_cast<std::int8_t*>(data.data()), n_total_bytes}; common::Span<std::int8_t>{reinterpret_cast<std::int8_t*>(data.data()), n_total_bytes};
return cpu_impl::Broadcast(comm, erased, root); return cpu_impl::Broadcast(comm, erased, root);
} }
template <typename T>
[[nodiscard]] Result Broadcast(Context const* ctx, CommGroup const& comm,
linalg::VectorView<T> data, std::int32_t root) {
if (!comm.IsDistributed()) {
return Success();
}
CHECK(data.Contiguous());
auto erased = common::EraseType(data.Values());
auto backend = comm.Backend(data.Device());
return backend->Broadcast(comm.Ctx(ctx, data.Device()), erased, root);
}
template <typename T>
[[nodiscard]] Result Broadcast(Context const* ctx, linalg::VectorView<T> data, std::int32_t root) {
return Broadcast(ctx, *GlobalCommGroup(), data, root);
}
} // namespace xgboost::collective } // namespace xgboost::collective
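Usage mirrors Allreduce: wrap the buffer in a linalg::MakeVec view and name the root rank. Another in-tree sketch under the same assumptions as above:

#include <vector>

#include "../collective/broadcast.h"  // assumed in-tree path
#include "xgboost/context.h"
#include "xgboost/linalg.h"

void ExampleBcast(xgboost::Context const* ctx) {
  namespace coll = xgboost::collective;
  std::vector<float> buf(16, 0.0f);  // rank 0 fills this; every rank receives it
  auto rc = coll::Broadcast(ctx, xgboost::linalg::MakeVec(buf.data(), buf.size()), /*root=*/0);
  coll::SafeColl(rc);
}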

View File

@ -42,6 +42,10 @@ bool constexpr IsFloatingPointV() {
auto redop_fn = [](auto lhs, auto out, auto elem_op) { auto redop_fn = [](auto lhs, auto out, auto elem_op) {
auto p_lhs = lhs.data(); auto p_lhs = lhs.data();
auto p_out = out.data(); auto p_out = out.data();
#if defined(__GNUC__) || defined(__clang__)
// For the sum op, one can verify the simd by: addps %xmm15, %xmm14
#pragma omp simd
#endif
for (std::size_t i = 0; i < lhs.size(); ++i) { for (std::size_t i = 0; i < lhs.size(); ++i) {
p_out[i] = elem_op(p_lhs[i], p_out[i]); p_out[i] = elem_op(p_lhs[i], p_out[i]);
} }
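As the comment notes, the way to trust #pragma omp simd is to read the generated assembly. A minimal loop to experiment with; compiling it with, e.g., g++ -O2 -fopenmp-simd -S and grepping for packed instructions such as addps/vaddps shows whether the reduction vectorized:

#include <cstddef>

void SumInto(float const* lhs, float* out, std::size_t n) {
#if defined(__GNUC__) || defined(__clang__)
#pragma omp simd
#endif
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = lhs[i] + out[i];  // the elementwise reduction op
  }
}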
@ -108,9 +112,8 @@ bool constexpr IsFloatingPointV() {
return cpu_impl::Broadcast(comm, data, root); return cpu_impl::Broadcast(comm, data, root);
} }
[[nodiscard]] Result Coll::Allgather(Comm const& comm, common::Span<std::int8_t> data, [[nodiscard]] Result Coll::Allgather(Comm const& comm, common::Span<std::int8_t> data) {
std::int64_t size) { return RingAllgather(comm, data);
return RingAllgather(comm, data, size);
} }
[[nodiscard]] Result Coll::AllgatherV(Comm const& comm, common::Span<std::int8_t const> data, [[nodiscard]] Result Coll::AllgatherV(Comm const& comm, common::Span<std::int8_t const> data,

View File

@ -1,10 +1,9 @@
/** /**
* Copyright 2023, XGBoost Contributors * Copyright 2023-2024, XGBoost Contributors
*/ */
#if defined(XGBOOST_USE_NCCL) || defined(XGBOOST_USE_RCCL) #if defined(XGBOOST_USE_NCCL) || defined(XGBOOST_USE_RCCL)
#include <cstdint> // for int8_t, int64_t #include <cstdint> // for int8_t, int64_t
#include "../common/cuda_context.cuh"
#include "../common/device_helpers.cuh" #include "../common/device_helpers.cuh"
#include "../data/array_interface.h" #include "../data/array_interface.h"
#include "allgather.h" // for AllgatherVOffset #include "allgather.h" // for AllgatherVOffset
@ -166,14 +165,14 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) {
} << [&] { return nccl->Block(); }; } << [&] { return nccl->Block(); };
} }
[[nodiscard]] Result NCCLColl::Allgather(Comm const& comm, common::Span<std::int8_t> data, [[nodiscard]] Result NCCLColl::Allgather(Comm const& comm, common::Span<std::int8_t> data) {
std::int64_t size) {
if (!comm.IsDistributed()) { if (!comm.IsDistributed()) {
return Success(); return Success();
} }
auto nccl = dynamic_cast<NCCLComm const*>(&comm); auto nccl = dynamic_cast<NCCLComm const*>(&comm);
CHECK(nccl); CHECK(nccl);
auto stub = nccl->Stub(); auto stub = nccl->Stub();
auto size = data.size_bytes() / comm.World();
auto send = data.subspan(comm.Rank() * size, size); auto send = data.subspan(comm.Rank() * size, size);
return Success() << [&] { return Success() << [&] {

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2023, XGBoost Contributors * Copyright 2023-2024, XGBoost Contributors
*/ */
#pragma once #pragma once
@ -8,8 +8,7 @@
#include "../data/array_interface.h" // for ArrayInterfaceHandler #include "../data/array_interface.h" // for ArrayInterfaceHandler
#include "coll.h" // for Coll #include "coll.h" // for Coll
#include "comm.h" // for Comm #include "comm.h" // for Comm
#include "nccl_stub.h" #include "xgboost/span.h" // for Span
#include "xgboost/span.h" // for Span
namespace xgboost::collective { namespace xgboost::collective {
class NCCLColl : public Coll { class NCCLColl : public Coll {
@ -20,8 +19,7 @@ class NCCLColl : public Coll {
ArrayInterfaceHandler::Type type, Op op) override; ArrayInterfaceHandler::Type type, Op op) override;
[[nodiscard]] Result Broadcast(Comm const& comm, common::Span<std::int8_t> data, [[nodiscard]] Result Broadcast(Comm const& comm, common::Span<std::int8_t> data,
std::int32_t root) override; std::int32_t root) override;
[[nodiscard]] Result Allgather(Comm const& comm, common::Span<std::int8_t> data, [[nodiscard]] Result Allgather(Comm const& comm, common::Span<std::int8_t> data) override;
std::int64_t size) override;
[[nodiscard]] Result AllgatherV(Comm const& comm, common::Span<std::int8_t const> data, [[nodiscard]] Result AllgatherV(Comm const& comm, common::Span<std::int8_t const> data,
common::Span<std::int64_t const> sizes, common::Span<std::int64_t const> sizes,
common::Span<std::int64_t> recv_segments, common::Span<std::int64_t> recv_segments,

View File

@ -48,10 +48,8 @@ class Coll : public std::enable_shared_from_this<Coll> {
* @brief Allgather * @brief Allgather
* *
* @param [in,out] data Data buffer for input and output. * @param [in,out] data Data buffer for input and output.
* @param [in] size Size of data for each worker.
*/ */
[[nodiscard]] virtual Result Allgather(Comm const& comm, common::Span<std::int8_t> data, [[nodiscard]] virtual Result Allgather(Comm const& comm, common::Span<std::int8_t> data);
std::int64_t size);
/** /**
* @brief Allgather with variable length. * @brief Allgather with variable length.
* *

View File

@ -1,16 +1,19 @@
/** /**
* Copyright 2023, XGBoost Contributors * Copyright 2023-2024, XGBoost Contributors
*/ */
#include "comm.h" #include "comm.h"
#include <algorithm> // for copy #include <algorithm> // for copy
#include <chrono> // for seconds #include <chrono> // for seconds
#include <cstdint> // for int32_t
#include <cstdlib> // for exit #include <cstdlib> // for exit
#include <memory> // for shared_ptr #include <memory> // for shared_ptr
#include <string> // for string #include <string> // for string
#include <thread> // for thread
#include <utility> // for move, forward #include <utility> // for move, forward
#if !defined(XGBOOST_USE_NCCL)
#include "../common/common.h" // for AssertGPUSupport #include "../common/common.h" // for AssertNCCLSupport
#endif // !defined(XGBOOST_USE_NCCL)
#include "allgather.h" // for RingAllgather #include "allgather.h" // for RingAllgather
#include "protocol.h" // for kMagic #include "protocol.h" // for kMagic
#include "xgboost/base.h" // for XGBOOST_STRICT_R_MODE #include "xgboost/base.h" // for XGBOOST_STRICT_R_MODE
@ -21,11 +24,7 @@
namespace xgboost::collective { namespace xgboost::collective {
Comm::Comm(std::string const& host, std::int32_t port, std::chrono::seconds timeout, Comm::Comm(std::string const& host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t retry, std::string task_id) std::int32_t retry, std::string task_id)
: timeout_{timeout}, : timeout_{timeout}, retry_{retry}, tracker_{host, port, -1}, task_id_{std::move(task_id)} {}
retry_{retry},
tracker_{host, port, -1},
task_id_{std::move(task_id)},
loop_{std::shared_ptr<Loop>{new Loop{timeout}}} {}
Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, std::int32_t retry, Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, std::int32_t retry,
std::string const& task_id, TCPSocket* out, std::int32_t rank, std::string const& task_id, TCPSocket* out, std::int32_t rank,
@ -187,12 +186,30 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
return Success(); return Success();
} }
RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::seconds timeout, namespace {
std::int32_t retry, std::string task_id, StringView nccl_path) std::string InitLog(std::string task_id, std::int32_t rank) {
: HostComm{std::move(host), port, timeout, retry, std::move(task_id)}, if (task_id.empty()) {
return "Rank " + std::to_string(rank);
}
return "Task " + task_id + " got rank " + std::to_string(rank);
}
} // namespace
RabitComm::RabitComm(std::string const& tracker_host, std::int32_t tracker_port,
std::chrono::seconds timeout, std::int32_t retry, std::string task_id,
StringView nccl_path)
: HostComm{tracker_host, tracker_port, timeout, retry, std::move(task_id)},
nccl_path_{std::move(nccl_path)} { nccl_path_{std::move(nccl_path)} {
if (this->TrackerInfo().host.empty()) {
// Not in a distributed environment.
LOG(CONSOLE) << InitLog(task_id_, rank_);
return;
}
loop_.reset(new Loop{std::chrono::seconds{timeout_}}); // NOLINT
auto rc = this->Bootstrap(timeout_, retry_, task_id_); auto rc = this->Bootstrap(timeout_, retry_, task_id_);
if (!rc.OK()) { if (!rc.OK()) {
this->ResetState();
SafeColl(Fail("Failed to bootstrap the communication group.", std::move(rc))); SafeColl(Fail("Failed to bootstrap the communication group.", std::move(rc)));
} }
} }
@ -219,20 +236,54 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
// Start command // Start command
TCPSocket listener = TCPSocket::Create(tracker.Domain()); TCPSocket listener = TCPSocket::Create(tracker.Domain());
std::int32_t lport = listener.BindHost(); std::int32_t lport{0};
listener.Listen(); rc = std::move(rc) << [&] {
return listener.BindHost(&lport);
} << [&] {
return listener.Listen();
};
if (!rc.OK()) {
return rc;
}
// create worker for listening to error notice. // create worker for listening to error notice.
auto domain = tracker.Domain(); auto domain = tracker.Domain();
std::shared_ptr<TCPSocket> error_sock{TCPSocket::CreatePtr(domain)}; std::shared_ptr<TCPSocket> error_sock{TCPSocket::CreatePtr(domain)};
auto eport = error_sock->BindHost(); std::int32_t eport{0};
error_sock->Listen(); rc = std::move(rc) << [&] {
return error_sock->BindHost(&eport);
} << [&] {
return error_sock->Listen();
};
if (!rc.OK()) {
return rc;
}
error_port_ = eport;
error_worker_ = std::thread{[error_sock = std::move(error_sock)] { error_worker_ = std::thread{[error_sock = std::move(error_sock)] {
auto conn = error_sock->Accept(); TCPSocket conn;
SockAddress addr;
auto rc = error_sock->Accept(&conn, &addr);
// On Linux, a shutdown causes an invalid argument error.
if (rc.Code() == std::errc::invalid_argument) {
return;
}
// On Windows, accept returns a closed socket after finalize. // On Windows, accept returns a closed socket after finalize.
if (conn.IsClosed()) { if (conn.IsClosed()) {
return; return;
} }
// The error signal is from the tracker, while the shutdown signal is from the shutdown method
// of the RabitComm class (this).
bool is_error{false};
rc = proto::Error{}.RecvSignal(&conn, &is_error);
if (!rc.OK()) {
LOG(WARNING) << rc.Report();
return;
}
if (!is_error) {
return; // shutdown
}
LOG(WARNING) << "Another worker is running into error."; LOG(WARNING) << "Another worker is running into error.";
#if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE == 0 #if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE == 0
// exit is nicer than abort as the former performs cleanups. // exit is nicer than abort as the former performs cleanups.
@ -241,6 +292,9 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
LOG(FATAL) << "abort"; LOG(FATAL) << "abort";
#endif #endif
}}; }};
// The worker thread is detached here to avoid the need to handle it later during
// destruction. In C++, destroying a std::thread that is still joinable (neither
// joined nor detached) calls std::terminate.
error_worker_.detach(); error_worker_.detach();
proto::Start start; proto::Start start;
@ -253,7 +307,7 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
// get ring neighbors // get ring neighbors
std::string snext; std::string snext;
tracker.Recv(&snext); rc = tracker.Recv(&snext);
if (!rc.OK()) { if (!rc.OK()) {
return Fail("Failed to receive the rank for the next worker.", std::move(rc)); return Fail("Failed to receive the rank for the next worker.", std::move(rc));
} }
@ -273,14 +327,21 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
CHECK(this->channels_.empty()); CHECK(this->channels_.empty());
for (auto& w : workers) { for (auto& w : workers) {
if (w) { if (w) {
rc = std::move(rc) << [&] { return w->SetNoDelay(); } << [&] { return w->NonBlocking(true); } rc = std::move(rc) << [&] {
<< [&] { return w->SetKeepAlive(); }; return w->SetNoDelay();
} << [&] {
return w->NonBlocking(true);
} << [&] {
return w->SetKeepAlive();
};
} }
if (!rc.OK()) { if (!rc.OK()) {
return rc; return rc;
} }
this->channels_.emplace_back(std::make_shared<Channel>(*this, w)); this->channels_.emplace_back(std::make_shared<Channel>(*this, w));
} }
LOG(CONSOLE) << InitLog(task_id_, rank_);
return rc; return rc;
} }
@ -288,6 +349,8 @@ RabitComm::~RabitComm() noexcept(false) {
if (!this->IsDistributed()) { if (!this->IsDistributed()) {
return; return;
} }
LOG(WARNING) << "The communicator is being destroyed without a call to shutdown first. This can "
"lead to undefined behaviour.";
auto rc = this->Shutdown(); auto rc = this->Shutdown();
if (!rc.OK()) { if (!rc.OK()) {
LOG(WARNING) << rc.Report(); LOG(WARNING) << rc.Report();
@ -295,24 +358,52 @@ RabitComm::~RabitComm() noexcept(false) {
} }
[[nodiscard]] Result RabitComm::Shutdown() { [[nodiscard]] Result RabitComm::Shutdown() {
if (!this->IsDistributed()) {
return Success();
}
// Tell the tracker that this worker is shutting down.
TCPSocket tracker; TCPSocket tracker;
// Tell the error handling thread that we are shutting down.
TCPSocket err_client;
return Success() << [&] { return Success() << [&] {
return ConnectTrackerImpl(tracker_, timeout_, retry_, task_id_, &tracker, Rank(), World()); return ConnectTrackerImpl(tracker_, timeout_, retry_, task_id_, &tracker, Rank(), World());
} << [&] { } << [&] {
return this->Block(); return this->Block();
} << [&] { } << [&] {
Json jcmd{Object{}}; return proto::ShutdownCMD{}.Send(&tracker);
jcmd["cmd"] = Integer{static_cast<std::int32_t>(proto::CMD::kShutdown)}; } << [&] {
auto scmd = Json::Dump(jcmd); this->channels_.clear();
auto n_bytes = tracker.Send(scmd);
if (n_bytes != scmd.size()) {
return Fail("Faled to send cmd.");
}
return Success(); return Success();
} << [&] {
// Use tracker address to determine whether we want to use IPv6.
auto taddr = MakeSockAddress(xgboost::StringView{this->tracker_.host}, this->tracker_.port);
// Shutdown the error handling thread. We signal the thread through a socket;
// alternatively, we could get the native handle and use pthread_cancel. But using a
// socket seems clearer, as we know what's happening.
auto const& addr = taddr.IsV4() ? SockAddrV4::Loopback().Addr() : SockAddrV6::Loopback().Addr();
// We use hardcoded 10 seconds and 1 retry here since we are just connecting to a
// local socket. For a normal OS, this should be enough time to schedule the
// connection.
auto rc = Connect(StringView{addr}, this->error_port_, 1,
std::min(std::chrono::seconds{10}, timeout_), &err_client);
this->ResetState();
if (!rc.OK()) {
return Fail("Failed to connect to the error socket.", std::move(rc));
}
return rc;
} << [&] {
// We put error thread shutdown at the end so that we have a better chance to finish
// the previous more important steps.
return proto::Error{}.SignalShutdown(&err_client);
}; };
} }
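The shutdown path leans on a classic trick: a thread blocked in accept() can always be woken by connecting to its own listening port over loopback, which keeps the control flow explicit instead of reaching for pthread_cancel. A self-contained POSIX sketch of the pattern (error handling elided):

#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>
#include <cstdio>
#include <thread>

int main() {
  int listener = socket(AF_INET, SOCK_STREAM, 0);
  sockaddr_in addr{};
  addr.sin_family = AF_INET;
  addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
  addr.sin_port = 0;  // let the OS pick a free port
  bind(listener, reinterpret_cast<sockaddr*>(&addr), sizeof(addr));
  socklen_t len = sizeof(addr);
  getsockname(listener, reinterpret_cast<sockaddr*>(&addr), &len);  // learn the port
  listen(listener, 1);

  std::thread waiter([listener] {
    int conn = accept(listener, nullptr, nullptr);  // blocks until signalled
    std::puts("error-handling thread woken, exiting cleanly");
    close(conn);
  });

  // "Signal" the thread by connecting to our own listener over loopback.
  int client = socket(AF_INET, SOCK_STREAM, 0);
  connect(client, reinterpret_cast<sockaddr*>(&addr), sizeof(addr));
  waiter.join();
  close(client);
  close(listener);
  return 0;
}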
[[nodiscard]] Result RabitComm::LogTracker(std::string msg) const { [[nodiscard]] Result RabitComm::LogTracker(std::string msg) const {
if (!this->IsDistributed()) {
LOG(CONSOLE) << msg;
return Success();
}
TCPSocket out; TCPSocket out;
proto::Print print; proto::Print print;
return Success() << [&] { return this->ConnectTracker(&out); } return Success() << [&] { return this->ConnectTracker(&out); }
@ -320,8 +411,11 @@ RabitComm::~RabitComm() noexcept(false) {
} }
[[nodiscard]] Result RabitComm::SignalError(Result const& res) { [[nodiscard]] Result RabitComm::SignalError(Result const& res) {
TCPSocket out; TCPSocket tracker;
return Success() << [&] { return this->ConnectTracker(&out); } return Success() << [&] {
<< [&] { return proto::ErrorCMD{}.WorkerSend(&out, res); }; return this->ConnectTracker(&tracker);
} << [&] {
return proto::ErrorCMD{}.WorkerSend(&tracker, res);
};
} }
} // namespace xgboost::collective } // namespace xgboost::collective
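The shutdown and error paths above lean on the Result-chaining idiom: `Success() << [&] {...} << [&] {...}` runs each lambda only while the accumulated result is still OK, so the sequence short-circuits at the first failure. A minimal standalone sketch of the idiom follows (simplified Result type for illustration, not the actual xgboost::collective implementation):

    #include <iostream>
    #include <string>
    #include <utility>

    struct Result {
      bool ok{true};
      std::string msg;
    };

    Result Success() { return {}; }
    Result Fail(std::string m) { return {false, std::move(m)}; }

    // Run `fn` only while the accumulated result is still OK.
    template <typename Fn>
    Result operator<<(Result&& r, Fn&& fn) {
      if (!r.ok) {
        return std::move(r);  // short-circuit: later steps are skipped
      }
      return fn();
    }

    int main() {
      auto rc = Success() << [] { return Success(); }               // runs
                          << [] { return Fail("connect failed"); }  // runs, fails
                          << [] { return Fail("never reached"); };  // skipped
      std::cout << rc.msg << "\n";  // prints "connect failed"
    }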
View File
@ -27,7 +27,7 @@ Result GetUniqueId(Comm const& comm, std::shared_ptr<NcclStub> stub, std::shared
ncclUniqueId id; ncclUniqueId id;
if (comm.Rank() == kRootRank) { if (comm.Rank() == kRootRank) {
auto rc = stub->GetUniqueId(&id); auto rc = stub->GetUniqueId(&id);
CHECK(rc.OK()) << rc.Report(); SafeColl(rc);
} }
auto rc = coll->Broadcast( auto rc = coll->Broadcast(
comm, common::Span{reinterpret_cast<std::int8_t*>(&id), sizeof(ncclUniqueId)}, kRootRank); comm, common::Span{reinterpret_cast<std::int8_t*>(&id), sizeof(ncclUniqueId)}, kRootRank);
@ -90,9 +90,8 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> p
auto s_this_uuid = s_uuid.subspan(root.Rank() * kUuidLength, kUuidLength); auto s_this_uuid = s_uuid.subspan(root.Rank() * kUuidLength, kUuidLength);
GetCudaUUID(s_this_uuid, ctx->Device()); GetCudaUUID(s_this_uuid, ctx->Device());
auto rc = pimpl->Allgather(root, common::EraseType(s_uuid), s_this_uuid.size_bytes()); auto rc = pimpl->Allgather(root, common::EraseType(s_uuid));
SafeColl(rc);
CHECK(rc.OK()) << rc.Report();
std::vector<xgboost::common::Span<std::uint64_t, kUuidLength>> converted(root.World()); std::vector<xgboost::common::Span<std::uint64_t, kUuidLength>> converted(root.World());
std::size_t j = 0; std::size_t j = 0;
@ -113,7 +112,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> p
[&] { [&] {
return this->stub_->CommInitRank(&nccl_comm_, root.World(), nccl_unique_id_, root.Rank()); return this->stub_->CommInitRank(&nccl_comm_, root.World(), nccl_unique_id_, root.Rank());
}; };
CHECK(rc.OK()) << rc.Report(); SafeColl(rc);
for (std::int32_t r = 0; r < root.World(); ++r) { for (std::int32_t r = 0; r < root.World(); ++r) {
this->channels_.emplace_back( this->channels_.emplace_back(
@ -124,7 +123,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> p
NCCLComm::~NCCLComm() { NCCLComm::~NCCLComm() {
if (nccl_comm_) { if (nccl_comm_) {
auto rc = stub_->CommDestroy(nccl_comm_); auto rc = stub_->CommDestroy(nccl_comm_);
CHECK(rc.OK()) << rc.Report(); SafeColl(rc);
} }
} }
} // namespace xgboost::collective } // namespace xgboost::collective
View File
@ -53,6 +53,10 @@ class NCCLComm : public Comm {
auto rc = this->Stream().Sync(false); auto rc = this->Stream().Sync(false);
return GetCUDAResult(rc); return GetCUDAResult(rc);
} }
[[nodiscard]] Result Shutdown() final {
this->ResetState();
return Success();
}
}; };
class NCCLChannel : public Channel { class NCCLChannel : public Channel {
View File
@ -1,10 +1,10 @@
/** /**
* Copyright 2023, XGBoost Contributors * Copyright 2023-2024, XGBoost Contributors
*/ */
#pragma once #pragma once
#include <chrono> // for seconds #include <chrono> // for seconds
#include <cstddef> // for size_t #include <cstddef> // for size_t
#include <cstdint> // for int32_t #include <cstdint> // for int32_t, int64_t
#include <memory> // for shared_ptr #include <memory> // for shared_ptr
#include <string> // for string #include <string> // for string
#include <thread> // for thread #include <thread> // for thread
@ -14,13 +14,13 @@
#include "loop.h" // for Loop #include "loop.h" // for Loop
#include "protocol.h" // for PeerInfo #include "protocol.h" // for PeerInfo
#include "xgboost/collective/result.h" // for Result #include "xgboost/collective/result.h" // for Result
#include "xgboost/collective/socket.h" // for TCPSocket #include "xgboost/collective/socket.h" // for TCPSocket, GetHostName
#include "xgboost/context.h" // for Context #include "xgboost/context.h" // for Context
#include "xgboost/span.h" // for Span #include "xgboost/span.h" // for Span
namespace xgboost::collective { namespace xgboost::collective {
inline constexpr std::int32_t DefaultTimeoutSec() { return 300; } // 5min inline constexpr std::int64_t DefaultTimeoutSec() { return 300; } // 5min
inline constexpr std::int32_t DefaultRetry() { return 3; } inline constexpr std::int32_t DefaultRetry() { return 3; }
// indexing into the ring // indexing into the ring
@ -51,11 +51,25 @@ class Comm : public std::enable_shared_from_this<Comm> {
proto::PeerInfo tracker_; proto::PeerInfo tracker_;
SockDomain domain_{SockDomain::kV4}; SockDomain domain_{SockDomain::kV4};
std::thread error_worker_; std::thread error_worker_;
std::int32_t error_port_;
std::string task_id_; std::string task_id_;
std::vector<std::shared_ptr<Channel>> channels_; std::vector<std::shared_ptr<Channel>> channels_;
std::shared_ptr<Loop> loop_{new Loop{std::chrono::seconds{ std::shared_ptr<Loop> loop_{nullptr}; // fixme: require federated comm to have a timeout
DefaultTimeoutSec()}}}; // fixme: require federated comm to have a timeout
void ResetState() {
this->world_ = -1;
this->rank_ = 0;
this->timeout_ = std::chrono::seconds{DefaultTimeoutSec()};
tracker_ = proto::PeerInfo{};
this->task_id_.clear();
channels_.clear();
loop_.reset();
}
public: public:
Comm() = default; Comm() = default;
@ -75,10 +89,13 @@ class Comm : public std::enable_shared_from_this<Comm> {
[[nodiscard]] auto Retry() const { return retry_; } [[nodiscard]] auto Retry() const { return retry_; }
[[nodiscard]] auto TaskID() const { return task_id_; } [[nodiscard]] auto TaskID() const { return task_id_; }
[[nodiscard]] auto Rank() const { return rank_; } [[nodiscard]] auto Rank() const noexcept { return rank_; }
[[nodiscard]] auto World() const { return IsDistributed() ? world_ : 1; } [[nodiscard]] auto World() const noexcept { return IsDistributed() ? world_ : 1; }
[[nodiscard]] bool IsDistributed() const { return world_ != -1; } [[nodiscard]] bool IsDistributed() const noexcept { return world_ != -1; }
void Submit(Loop::Op op) const { loop_->Submit(op); } void Submit(Loop::Op op) const {
CHECK(loop_);
loop_->Submit(op);
}
[[nodiscard]] virtual Result Block() const { return loop_->Block(); } [[nodiscard]] virtual Result Block() const { return loop_->Block(); }
[[nodiscard]] virtual std::shared_ptr<Channel> Chan(std::int32_t rank) const { [[nodiscard]] virtual std::shared_ptr<Channel> Chan(std::int32_t rank) const {
@ -88,6 +105,14 @@ class Comm : public std::enable_shared_from_this<Comm> {
[[nodiscard]] virtual Result LogTracker(std::string msg) const = 0; [[nodiscard]] virtual Result LogTracker(std::string msg) const = 0;
[[nodiscard]] virtual Result SignalError(Result const&) { return Success(); } [[nodiscard]] virtual Result SignalError(Result const&) { return Success(); }
/**
* @brief Get a string ID for the current process.
*/
[[nodiscard]] virtual Result ProcessorName(std::string* out) const {
auto rc = GetHostName(out);
return rc;
}
[[nodiscard]] virtual Result Shutdown() = 0;
}; };
/** /**
@ -105,20 +130,20 @@ class RabitComm : public HostComm {
[[nodiscard]] Result Bootstrap(std::chrono::seconds timeout, std::int32_t retry, [[nodiscard]] Result Bootstrap(std::chrono::seconds timeout, std::int32_t retry,
std::string task_id); std::string task_id);
[[nodiscard]] Result Shutdown();
public: public:
// bootstrapping construction. // bootstrapping construction.
RabitComm() = default; RabitComm() = default;
// ctor for testing where environment is known. RabitComm(std::string const& tracker_host, std::int32_t tracker_port,
RabitComm(std::string const& host, std::int32_t port, std::chrono::seconds timeout, std::chrono::seconds timeout, std::int32_t retry, std::string task_id,
std::int32_t retry, std::string task_id, StringView nccl_path); StringView nccl_path);
~RabitComm() noexcept(false) override; ~RabitComm() noexcept(false) override;
[[nodiscard]] bool IsFederated() const override { return false; } [[nodiscard]] bool IsFederated() const override { return false; }
[[nodiscard]] Result LogTracker(std::string msg) const override; [[nodiscard]] Result LogTracker(std::string msg) const override;
[[nodiscard]] Result SignalError(Result const&) override; [[nodiscard]] Result SignalError(Result const&) override;
[[nodiscard]] Result Shutdown() final;
[[nodiscard]] Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const override; [[nodiscard]] Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const override;
}; };
View File
@ -1,22 +1,21 @@
/** /**
* Copyright 2023, XGBoost Contributors * Copyright 2023-2024, XGBoost Contributors
*/ */
#include "comm_group.h" #include "comm_group.h"
#include <algorithm> // for transform #include <algorithm> // for transform
#include <cctype> // for tolower
#include <chrono> // for seconds #include <chrono> // for seconds
#include <cstdint> // for int32_t #include <cstdint> // for int32_t
#include <iterator> // for back_inserter
#include <memory> // for shared_ptr, unique_ptr #include <memory> // for shared_ptr, unique_ptr
#include <string> // for string #include <string> // for string
#include <vector> // for vector
#include "../common/json_utils.h" // for OptionalArg #include "../common/json_utils.h" // for OptionalArg
#include "coll.h" // for Coll #include "coll.h" // for Coll
#include "comm.h" // for Comm #include "comm.h" // for Comm
#include "tracker.h" // for GetHostAddress #include "xgboost/context.h" // for DeviceOrd
#include "xgboost/collective/result.h" // for Result #include "xgboost/json.h" // for Json
#include "xgboost/context.h" // for DeviceOrd
#include "xgboost/json.h" // for Json
#if defined(XGBOOST_USE_FEDERATED) #if defined(XGBOOST_USE_FEDERATED)
#include "../../plugin/federated/federated_coll.h" #include "../../plugin/federated/federated_coll.h"
@ -65,6 +64,9 @@ CommGroup::CommGroup()
auto const& obj = get<Object const>(config); auto const& obj = get<Object const>(config);
auto it = obj.find(upper); auto it = obj.find(upper);
if (it != obj.cend() && obj.find(name) != obj.cend()) {
LOG(FATAL) << "Duplicated parameter:" << name;
}
if (it != obj.cend()) { if (it != obj.cend()) {
return OptionalArg<decltype(t)>(config, upper, dft); return OptionalArg<decltype(t)>(config, upper, dft);
} else { } else {
@ -78,14 +80,14 @@ CommGroup::CommGroup()
auto task_id = get_param("dmlc_task_id", std::string{}, String{}); auto task_id = get_param("dmlc_task_id", std::string{}, String{});
if (type == "rabit") { if (type == "rabit") {
auto host = get_param("dmlc_tracker_uri", std::string{}, String{}); auto tracker_host = get_param("dmlc_tracker_uri", std::string{}, String{});
auto port = get_param("dmlc_tracker_port", static_cast<std::int64_t>(0), Integer{}); auto tracker_port = get_param("dmlc_tracker_port", static_cast<std::int64_t>(0), Integer{});
auto nccl = get_param("dmlc_nccl_path", std::string{DefaultNcclName()}, String{}); auto nccl = get_param("dmlc_nccl_path", std::string{DefaultNcclName()}, String{});
auto ptr = auto ptr = new CommGroup{
new CommGroup{std::shared_ptr<RabitComm>{new RabitComm{ // NOLINT std::shared_ptr<RabitComm>{new RabitComm{ // NOLINT
host, static_cast<std::int32_t>(port), std::chrono::seconds{timeout}, tracker_host, static_cast<std::int32_t>(tracker_port), std::chrono::seconds{timeout},
static_cast<std::int32_t>(retry), task_id, nccl}}, static_cast<std::int32_t>(retry), task_id, nccl}},
std::shared_ptr<Coll>(new Coll{})}; // NOLINT std::shared_ptr<Coll>(new Coll{})}; // NOLINT
return ptr; return ptr;
} else if (type == "federated") { } else if (type == "federated") {
#if defined(XGBOOST_USE_FEDERATED) #if defined(XGBOOST_USE_FEDERATED)
@ -117,6 +119,8 @@ void GlobalCommGroupInit(Json config) {
void GlobalCommGroupFinalize() { void GlobalCommGroupFinalize() {
auto& sptr = GlobalCommGroup(); auto& sptr = GlobalCommGroup();
auto rc = sptr->Finalize();
sptr.reset(); sptr.reset();
SafeColl(rc);
} }
} // namespace xgboost::collective } // namespace xgboost::collective
View File
@ -9,7 +9,6 @@
#include "coll.h" // for Comm #include "coll.h" // for Comm
#include "comm.h" // for Coll #include "comm.h" // for Coll
#include "xgboost/collective/result.h" // for Result #include "xgboost/collective/result.h" // for Result
#include "xgboost/collective/socket.h" // for GetHostName
namespace xgboost::collective { namespace xgboost::collective {
/** /**
@ -31,19 +30,35 @@ class CommGroup {
public: public:
CommGroup(); CommGroup();
[[nodiscard]] auto World() const { return comm_->World(); } [[nodiscard]] auto World() const noexcept { return comm_->World(); }
[[nodiscard]] auto Rank() const { return comm_->Rank(); } [[nodiscard]] auto Rank() const noexcept { return comm_->Rank(); }
[[nodiscard]] bool IsDistributed() const { return comm_->IsDistributed(); } [[nodiscard]] bool IsDistributed() const noexcept { return comm_->IsDistributed(); }
[[nodiscard]] Result Finalize() const {
return Success() << [this] {
if (gpu_comm_) {
return gpu_comm_->Shutdown();
}
return Success();
} << [&] {
return comm_->Shutdown();
};
}
[[nodiscard]] static CommGroup* Create(Json config); [[nodiscard]] static CommGroup* Create(Json config);
[[nodiscard]] std::shared_ptr<Coll> Backend(DeviceOrd device) const; [[nodiscard]] std::shared_ptr<Coll> Backend(DeviceOrd device) const;
/**
* @brief Decide the context to use for communication.
*
* @param ctx Global context, provides the CUDA stream and ordinal.
* @param device The device used by the data to be communicated.
*/
[[nodiscard]] Comm const& Ctx(Context const* ctx, DeviceOrd device) const; [[nodiscard]] Comm const& Ctx(Context const* ctx, DeviceOrd device) const;
[[nodiscard]] Result SignalError(Result const& res) { return comm_->SignalError(res); } [[nodiscard]] Result SignalError(Result const& res) { return comm_->SignalError(res); }
[[nodiscard]] Result ProcessorName(std::string* out) const { [[nodiscard]] Result ProcessorName(std::string* out) const {
auto rc = GetHostName(out); return this->comm_->ProcessorName(out);
return rc;
} }
}; };
View File
@ -32,7 +32,8 @@ class InMemoryHandler {
* *
* This is used when the handler only needs to be initialized once with a known world size. * This is used when the handler only needs to be initialized once with a known world size.
*/ */
explicit InMemoryHandler(std::size_t worldSize) : world_size_{worldSize} {} explicit InMemoryHandler(std::int32_t worldSize)
: world_size_{static_cast<std::size_t>(worldSize)} {}
/** /**
* @brief Initialize the handler with the world size and rank. * @brief Initialize the handler with the world size and rank.
View File
@ -18,9 +18,11 @@
#include "xgboost/logging.h" // for CHECK #include "xgboost/logging.h" // for CHECK
namespace xgboost::collective { namespace xgboost::collective {
Result Loop::EmptyQueue(std::queue<Op>* p_queue) const { Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
timer_.Start(__func__); timer_.Start(__func__);
auto error = [this] { timer_.Stop(__func__); }; auto error = [this] {
timer_.Stop(__func__);
};
if (stop_) { if (stop_) {
timer_.Stop(__func__); timer_.Stop(__func__);
@ -48,6 +50,9 @@ Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
poll.WatchWrite(*op.sock); poll.WatchWrite(*op.sock);
break; break;
} }
case Op::kSleep: {
break;
}
default: { default: {
error(); error();
return Fail("Invalid socket operation."); return Fail("Invalid socket operation.");
@ -59,12 +64,14 @@ Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
// poll, work on fds that are ready. // poll, work on fds that are ready.
timer_.Start("poll"); timer_.Start("poll");
auto rc = poll.Poll(timeout_); if (!poll.fds.empty()) {
timer_.Stop("poll"); auto rc = poll.Poll(timeout_);
if (!rc.OK()) { if (!rc.OK()) {
error(); error();
return rc; return rc;
}
} }
timer_.Stop("poll");
// we wouldn't be here if the queue is empty. // we wouldn't be here if the queue is empty.
CHECK(!qcopy.empty()); CHECK(!qcopy.empty());
@ -75,12 +82,20 @@ Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
qcopy.pop(); qcopy.pop();
std::int32_t n_bytes_done{0}; std::int32_t n_bytes_done{0};
CHECK(op.sock->NonBlocking()); if (!op.sock) {
CHECK(op.code == Op::kSleep);
} else {
CHECK(op.sock->NonBlocking());
}
switch (op.code) { switch (op.code) {
case Op::kRead: { case Op::kRead: {
if (poll.CheckRead(*op.sock)) { if (poll.CheckRead(*op.sock)) {
n_bytes_done = op.sock->Recv(op.ptr + op.off, op.n - op.off); n_bytes_done = op.sock->Recv(op.ptr + op.off, op.n - op.off);
if (n_bytes_done == 0) {
error();
return Fail("Encountered EOF. The other end is likely closed.");
}
} }
break; break;
} }
@ -90,6 +105,12 @@ Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
} }
break; break;
} }
case Op::kSleep: {
// For testing only.
std::this_thread::sleep_for(std::chrono::seconds{op.n});
n_bytes_done = op.n;
break;
}
default: { default: {
error(); error();
return Fail("Invalid socket operation."); return Fail("Invalid socket operation.");
@ -110,6 +131,10 @@ Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
qcopy.push(op); qcopy.push(op);
} }
} }
if (!blocking) {
break;
}
} }
timer_.Stop(__func__); timer_.Stop(__func__);
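One detail worth calling out in the reworked queue: a recv that returns 0 on a connected TCP socket means the peer performed an orderly shutdown (EOF), not that no data is available; on a non-blocking POSIX socket the no-data case surfaces as -1 with EAGAIN/EWOULDBLOCK instead. A hedged sketch of the distinction, assuming POSIX sockets:

    #include <cerrno>
    #include <sys/socket.h>
    #include <sys/types.h>

    // Returns bytes read (>0), 0 if the call would block, -1 on peer EOF, -2 on error.
    ssize_t NonBlockingRead(int fd, void* buf, size_t n) {
      ssize_t got = recv(fd, buf, n, 0);
      if (got > 0) { return got; }                                // made progress
      if (got == 0) { return -1; }                                // orderly shutdown: EOF
      if (errno == EAGAIN || errno == EWOULDBLOCK) { return 0; }  // retry later
      return -2;                                                  // real error
    }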
@ -128,6 +153,15 @@ void Loop::Process() {
while (true) { while (true) {
try { try {
std::unique_lock lock{mu_}; std::unique_lock lock{mu_};
// This can handle missed notification: wait(lock, predicate) is equivalent to:
//
// while (!predicate()) {
// cv.wait(lock);
// }
//
// As a result, if there's a missed notification, the queue wouldn't be empty, hence
// the predicate would be false and the actual wait wouldn't be invoked. Therefore,
// the blocking call can never go unanswered.
cv_.wait(lock, [this] { return !this->queue_.empty() || stop_; }); cv_.wait(lock, [this] { return !this->queue_.empty() || stop_; });
if (stop_) { if (stop_) {
break; // only point where this loop can exit. break; // only point where this loop can exit.
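For reference, the predicate overload used here expands to the classic guarded wait loop, which is exactly why a notification fired before the consumer starts waiting cannot be lost. A minimal self-contained sketch:

    #include <condition_variable>
    #include <mutex>
    #include <queue>
    #include <thread>

    std::mutex mu;
    std::condition_variable cv;
    std::queue<int> q;

    void Producer() {
      {
        std::lock_guard lock{mu};
        q.push(42);
      }
      cv.notify_one();  // may fire before the consumer starts waiting
    }

    void Consumer() {
      std::unique_lock lock{mu};
      // The predicate is checked first; if the queue is already non-empty,
      // the wait is skipped and the earlier notification cannot be lost.
      cv.wait(lock, [] { return !q.empty(); });
    }

    int main() {
      std::thread c{Consumer}, p{Producer};
      p.join();
      c.join();
    }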
@ -142,26 +176,27 @@ void Loop::Process() {
queue_.pop(); queue_.pop();
if (op.code == Op::kBlock) { if (op.code == Op::kBlock) {
is_blocking = true; is_blocking = true;
// Block must be the last op in the current batch since no further submit can be
// issued until the blocking call is finished.
CHECK(queue_.empty());
} else { } else {
qcopy.push(op); qcopy.push(op);
} }
} }
if (!is_blocking) { lock.unlock();
// Unblock, we can write to the global queue again. // Clear the local queue, if `is_blocking` is true, this is blocking the current
lock.unlock(); // worker thread (but not the client thread), wait until all operations are
// finished.
auto rc = this->ProcessQueue(&qcopy, is_blocking);
if (is_blocking && rc.OK()) {
CHECK(qcopy.empty());
} }
// Push back the remaining operations.
// Clear the local queue, this is blocking the current worker thread (but not the if (rc.OK()) {
// client thread), wait until all operations are finished. std::unique_lock lock{mu_};
auto rc = this->EmptyQueue(&qcopy); while (!qcopy.empty()) {
queue_.push(qcopy.front());
if (is_blocking) { qcopy.pop();
// The unlock is delayed if this is a blocking call }
lock.unlock();
} }
// Notify the client thread who called block after all error conditions are set. // Notify the client thread who called block after all error conditions are set.
@ -228,7 +263,6 @@ Result Loop::Stop() {
} }
this->Submit(Op{Op::kBlock}); this->Submit(Op{Op::kBlock});
{ {
// Wait for the block call to finish. // Wait for the block call to finish.
std::unique_lock lock{mu_}; std::unique_lock lock{mu_};
@ -243,8 +277,20 @@ Result Loop::Stop() {
} }
} }
void Loop::Submit(Op op) {
std::unique_lock lock{mu_};
if (op.code != Op::kBlock) {
CHECK_NE(op.n, 0);
}
queue_.push(op);
lock.unlock();
cv_.notify_one();
}
Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} { Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} {
timer_.Init(__func__); timer_.Init(__func__);
worker_ = std::thread{[this] { this->Process(); }}; worker_ = std::thread{[this] {
this->Process();
}};
} }
} // namespace xgboost::collective } // namespace xgboost::collective
View File
@ -19,20 +19,27 @@ namespace xgboost::collective {
class Loop { class Loop {
public: public:
struct Op { struct Op {
enum Code : std::int8_t { kRead = 0, kWrite = 1, kBlock = 2 } code; // kSleep is only for testing
enum Code : std::int8_t { kRead = 0, kWrite = 1, kBlock = 2, kSleep = 4 } code;
std::int32_t rank{-1}; std::int32_t rank{-1};
std::int8_t* ptr{nullptr}; std::int8_t* ptr{nullptr};
std::size_t n{0}; std::size_t n{0};
TCPSocket* sock{nullptr}; TCPSocket* sock{nullptr};
std::size_t off{0}; std::size_t off{0};
explicit Op(Code c) : code{c} { CHECK(c == kBlock); } explicit Op(Code c) : code{c} { CHECK(c == kBlock || c == kSleep); }
Op(Code c, std::int32_t rank, std::int8_t* ptr, std::size_t n, TCPSocket* sock, std::size_t off) Op(Code c, std::int32_t rank, std::int8_t* ptr, std::size_t n, TCPSocket* sock, std::size_t off)
: code{c}, rank{rank}, ptr{ptr}, n{n}, sock{sock}, off{off} {} : code{c}, rank{rank}, ptr{ptr}, n{n}, sock{sock}, off{off} {}
Op(Op const&) = default; Op(Op const&) = default;
Op& operator=(Op const&) = default; Op& operator=(Op const&) = default;
Op(Op&&) = default; Op(Op&&) = default;
Op& operator=(Op&&) = default; Op& operator=(Op&&) = default;
// For testing purpose only
[[nodiscard]] static Op Sleep(std::size_t seconds) {
Op op{kSleep};
op.n = seconds;
return op;
}
}; };
private: private:
@ -54,7 +61,7 @@ class Loop {
std::exception_ptr curr_exce_{nullptr}; std::exception_ptr curr_exce_{nullptr};
common::Monitor mutable timer_; common::Monitor mutable timer_;
Result EmptyQueue(std::queue<Op>* p_queue) const; Result ProcessQueue(std::queue<Op>* p_queue, bool blocking) const;
// The consumer function that runs inside a worker thread. // The consumer function that runs inside a worker thread.
void Process(); void Process();
@ -64,12 +71,7 @@ class Loop {
*/ */
Result Stop(); Result Stop();
void Submit(Op op) { void Submit(Op op);
std::unique_lock lock{mu_};
queue_.push(op);
lock.unlock();
cv_.notify_one();
}
/** /**
* @brief Block the event loop until all ops are finished. In the case of failure, this * @brief Block the event loop until all ops are finished. In the case of failure, this
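A hypothetical test for the new sleep op, using only the API shown in this header plus SafeColl from result.cc; it exercises the submit/notify/block path without touching a socket:

    Loop loop{std::chrono::seconds{10}};
    loop.Submit(Loop::Op::Sleep(1));  // worker thread sleeps for one second
    auto rc = loop.Block();           // wait until the queued op completes
    SafeColl(rc);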
View File
@ -2,6 +2,8 @@
* Copyright 2023 XGBoost contributors * Copyright 2023 XGBoost contributors
*/ */
#if defined(XGBOOST_USE_NCCL) || defined(XGBOOST_USE_RCCL) #if defined(XGBOOST_USE_NCCL) || defined(XGBOOST_USE_RCCL)
#include <numeric> // for accumulate
#include "comm.cuh" #include "comm.cuh"
#include "nccl_device_communicator.cuh" #include "nccl_device_communicator.cuh"
View File
@ -1,5 +1,5 @@
/** /**
* Copyright 2023, XGBoost Contributors * Copyright 2023-2024, XGBoost Contributors
*/ */
#pragma once #pragma once
#include <cstdint> // for int32_t #include <cstdint> // for int32_t
@ -58,6 +58,7 @@ struct Magic {
} }
}; };
// Basic commands for communication between workers and the tracker.
enum class CMD : std::int32_t { enum class CMD : std::int32_t {
kInvalid = 0, kInvalid = 0,
kStart = 1, kStart = 1,
@ -84,7 +85,10 @@ struct Connect {
[[nodiscard]] Result TrackerRecv(TCPSocket* sock, std::int32_t* world, std::int32_t* rank, [[nodiscard]] Result TrackerRecv(TCPSocket* sock, std::int32_t* world, std::int32_t* rank,
std::string* task_id) const { std::string* task_id) const {
std::string init; std::string init;
sock->Recv(&init); auto rc = sock->Recv(&init);
if (!rc.OK()) {
return Fail("Connect protocol failed.", std::move(rc));
}
auto jinit = Json::Load(StringView{init}); auto jinit = Json::Load(StringView{init});
*world = get<Integer const>(jinit["world_size"]); *world = get<Integer const>(jinit["world_size"]);
*rank = get<Integer const>(jinit["rank"]); *rank = get<Integer const>(jinit["rank"]);
@ -122,9 +126,9 @@ class Start {
} }
[[nodiscard]] Result WorkerRecv(TCPSocket* tracker, std::int32_t* p_world) const { [[nodiscard]] Result WorkerRecv(TCPSocket* tracker, std::int32_t* p_world) const {
std::string scmd; std::string scmd;
auto n_bytes = tracker->Recv(&scmd); auto rc = tracker->Recv(&scmd);
if (n_bytes <= 0) { if (!rc.OK()) {
return Fail("Failed to recv init command from tracker."); return Fail("Failed to recv init command from tracker.", std::move(rc));
} }
auto jcmd = Json::Load(scmd); auto jcmd = Json::Load(scmd);
auto world = get<Integer const>(jcmd["world_size"]); auto world = get<Integer const>(jcmd["world_size"]);
@ -132,7 +136,7 @@ class Start {
return Fail("Invalid world size."); return Fail("Invalid world size.");
} }
*p_world = world; *p_world = world;
return Success(); return rc;
} }
[[nodiscard]] Result TrackerHandle(Json jcmd, std::int32_t* recv_world, std::int32_t world, [[nodiscard]] Result TrackerHandle(Json jcmd, std::int32_t* recv_world, std::int32_t world,
std::int32_t* p_port, TCPSocket* p_sock, std::int32_t* p_port, TCPSocket* p_sock,
@ -150,6 +154,7 @@ class Start {
} }
}; };
// Protocol for communicating with the tracker for printing message.
struct Print { struct Print {
[[nodiscard]] Result WorkerSend(TCPSocket* tracker, std::string msg) const { [[nodiscard]] Result WorkerSend(TCPSocket* tracker, std::string msg) const {
Json jcmd{Object{}}; Json jcmd{Object{}};
@ -172,6 +177,7 @@ struct Print {
} }
}; };
// Protocol for communicating with the tracker during error.
struct ErrorCMD { struct ErrorCMD {
[[nodiscard]] Result WorkerSend(TCPSocket* tracker, Result const& res) const { [[nodiscard]] Result WorkerSend(TCPSocket* tracker, Result const& res) const {
auto msg = res.Report(); auto msg = res.Report();
@ -199,6 +205,7 @@ struct ErrorCMD {
} }
}; };
// Protocol for communicating with the tracker during shutdown.
struct ShutdownCMD { struct ShutdownCMD {
[[nodiscard]] Result Send(TCPSocket* peer) const { [[nodiscard]] Result Send(TCPSocket* peer) const {
Json jcmd{Object{}}; Json jcmd{Object{}};
@ -211,4 +218,40 @@ struct ShutdownCMD {
return Success(); return Success();
} }
}; };
// Protocol for communicating with the local error handler during error or shutdown. This is
// the only protocol that doesn't involve the tracker.
struct Error {
constexpr static std::int32_t ShutdownSignal() { return 0; }
constexpr static std::int32_t ErrorSignal() { return -1; }
[[nodiscard]] Result SignalError(TCPSocket* worker) const {
std::int32_t err{ErrorSignal()};
auto n_sent = worker->SendAll(&err, sizeof(err));
if (n_sent == sizeof(err)) {
return Success();
}
return Fail("Failed to send error signal");
}
// self is localhost, we are sending the signal to the error handling thread for it to
// close.
[[nodiscard]] Result SignalShutdown(TCPSocket* self) const {
std::int32_t err{ShutdownSignal()};
auto n_sent = self->SendAll(&err, sizeof(err));
if (n_sent == sizeof(err)) {
return Success();
}
return Fail("Failed to send shutdown signal");
}
// get signal, either for error or for shutdown.
[[nodiscard]] Result RecvSignal(TCPSocket* peer, bool* p_is_error) const {
std::int32_t err{ShutdownSignal()};
auto n_recv = peer->RecvAll(&err, sizeof(err));
if (n_recv == sizeof(err)) {
*p_is_error = err == ErrorSignal();
return Success();
}
return Fail("Failed to receive error signal.");
}
};
} // namespace xgboost::collective::proto } // namespace xgboost::collective::proto
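All three signal exchanges above share one fixed-width frame: a single int32_t that distinguishes shutdown (0) from error (-1). The real code loops via SendAll/RecvAll because a single send or recv may transfer fewer bytes; this standalone POSIX socketpair sketch skips the loop for brevity:

    #include <cstdint>
    #include <sys/socket.h>
    #include <unistd.h>

    int main() {
      int fds[2];
      if (socketpair(AF_UNIX, SOCK_STREAM, 0, fds) != 0) { return 1; }

      std::int32_t sig = -1;               // ErrorSignal()
      send(fds[0], &sig, sizeof(sig), 0);  // worker reports an error

      std::int32_t got = 0;
      recv(fds[1], &got, sizeof(got), 0);  // error thread reads the signal
      bool is_error = (got == -1);         // 0 would mean a clean shutdown

      close(fds[0]);
      close(fds[1]);
      return is_error ? 0 : 1;
    }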
src/collective/result.cc Normal file
View File
@ -0,0 +1,86 @@
/**
* Copyright 2024, XGBoost Contributors
*/
#include "xgboost/collective/result.h"
#include <filesystem> // for path
#include <sstream> // for stringstream
#include <stack> // for stack
#include "xgboost/logging.h"
namespace xgboost::collective {
namespace detail {
[[nodiscard]] std::string ResultImpl::Report() const {
std::stringstream ss;
ss << "\n- " << this->message;
if (this->errc != std::error_code{}) {
ss << " system error:" << this->errc.message();
}
auto ptr = prev.get();
while (ptr) {
ss << "\n- ";
ss << ptr->message;
if (ptr->errc != std::error_code{}) {
ss << " " << ptr->errc.message();
}
ptr = ptr->prev.get();
}
return ss.str();
}
[[nodiscard]] std::error_code ResultImpl::Code() const {
// Find the root error.
std::stack<ResultImpl const*> stack;
auto ptr = this;
while (ptr) {
stack.push(ptr);
if (ptr->prev) {
ptr = ptr->prev.get();
} else {
break;
}
}
while (!stack.empty()) {
auto frame = stack.top();
stack.pop();
if (frame->errc != std::error_code{}) {
return frame->errc;
}
}
return std::error_code{};
}
void ResultImpl::Concat(std::unique_ptr<ResultImpl> rhs) {
auto ptr = this;
while (ptr->prev) {
ptr = ptr->prev.get();
}
ptr->prev = std::move(rhs);
}
#if (!defined(__GNUC__) && !defined(__clang__)) || defined(__MINGW32__)
std::string MakeMsg(std::string&& msg, char const*, std::int32_t) {
return std::forward<std::string>(msg);
}
#else
std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line) {
auto name = std::filesystem::path{file}.filename();
if (file && line != -1) {
return "[" + name.string() + ":" + std::to_string(line) + // NOLINT
"]: " + std::forward<std::string>(msg);
}
return std::forward<std::string>(msg);
}
#endif
} // namespace detail
void SafeColl(Result const& rc) {
if (!rc.OK()) {
LOG(FATAL) << rc.Report();
}
}
} // namespace xgboost::collective
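Putting Report() and Code() together: each stacked frame is rendered on its own "- " line, the system error message is appended whenever a frame carries an error code, and Code() walks to the deepest frame to return the root cause. Illustrative output for a connection failure wrapped once (hypothetical file, line, and address, not captured from a real run):

    - [comm.cc:372]: Failed to connect to the error socket.
    - Failed to connect to 127.0.0.1:36000 Connection refused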
View File
@ -1,5 +1,5 @@
/** /**
* Copyright 2022-2023 by XGBoost Contributors * Copyright 2022-2024, XGBoost Contributors
*/ */
#include "xgboost/collective/socket.h" #include "xgboost/collective/socket.h"
@ -8,7 +8,8 @@
#include <cstdint> // std::int32_t #include <cstdint> // std::int32_t
#include <cstring> // std::memcpy, std::memset #include <cstring> // std::memcpy, std::memset
#include <filesystem> // for path #include <filesystem> // for path
#include <system_error> // std::error_code, std::system_category #include <system_error> // for error_code, system_category
#include <thread> // for sleep_for
#include "rabit/internal/socket.h" // for PollHelper #include "rabit/internal/socket.h" // for PollHelper
#include "xgboost/collective/result.h" // for Result #include "xgboost/collective/result.h" // for Result
@ -65,14 +66,18 @@ std::size_t TCPSocket::Send(StringView str) {
return bytes; return bytes;
} }
std::size_t TCPSocket::Recv(std::string *p_str) { [[nodiscard]] Result TCPSocket::Recv(std::string *p_str) {
CHECK(!this->IsClosed()); CHECK(!this->IsClosed());
std::int32_t len; std::int32_t len;
CHECK_EQ(this->RecvAll(&len, sizeof(len)), sizeof(len)) << "Failed to recv string length."; if (this->RecvAll(&len, sizeof(len)) != sizeof(len)) {
return Fail("Failed to recv string length.");
}
p_str->resize(len); p_str->resize(len);
auto bytes = this->RecvAll(&(*p_str)[0], len); auto bytes = this->RecvAll(&(*p_str)[0], len);
CHECK_EQ(bytes, len) << "Failed to recv string."; if (static_cast<decltype(len)>(bytes) != len) {
return bytes; return Fail("Failed to recv string.");
}
return Success();
} }
[[nodiscard]] Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry, [[nodiscard]] Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry,
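The string Send/Recv pair above uses a simple length-prefixed frame: a 4-byte int32 length in native byte order followed by the payload, which is why Recv first reads exactly sizeof(len) bytes and then resizes the string. A minimal in-memory sketch of the framing (assumes both ends share endianness, as worker processes in one cluster are expected to; C++17 for the non-const data()):

    #include <cstdint>
    #include <cstring>
    #include <iostream>
    #include <string>

    std::string Encode(std::string const& payload) {
      std::int32_t len = static_cast<std::int32_t>(payload.size());
      std::string frame(sizeof(len), '\0');
      std::memcpy(frame.data(), &len, sizeof(len));  // 4-byte length prefix
      return frame + payload;
    }

    std::string Decode(std::string const& frame) {
      std::int32_t len = 0;
      std::memcpy(&len, frame.data(), sizeof(len));
      return frame.substr(sizeof(len), len);  // exactly `len` payload bytes
    }

    int main() {
      auto frame = Encode("{\"cmd\": 4}");
      std::cout << Decode(frame) << "\n";  // prints {"cmd": 4}
    }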
@ -110,11 +115,7 @@ std::size_t TCPSocket::Recv(std::string *p_str) {
for (std::int32_t attempt = 0; attempt < std::max(retry, 1); ++attempt) { for (std::int32_t attempt = 0; attempt < std::max(retry, 1); ++attempt) {
if (attempt > 0) { if (attempt > 0) {
LOG(WARNING) << "Retrying connection to " << host << " for the " << attempt << " time."; LOG(WARNING) << "Retrying connection to " << host << " for the " << attempt << " time.";
#if defined(_MSC_VER) || defined(__MINGW32__) std::this_thread::sleep_for(std::chrono::seconds{attempt << 1});
Sleep(attempt << 1);
#else
sleep(attempt << 1);
#endif
} }
auto rc = connect(conn.Handle(), addr_handle, addr_len); auto rc = connect(conn.Handle(), addr_handle, addr_len);
@ -158,8 +159,8 @@ std::size_t TCPSocket::Recv(std::string *p_str) {
std::stringstream ss; std::stringstream ss;
ss << "Failed to connect to " << host << ":" << port; ss << "Failed to connect to " << host << ":" << port;
conn.Close(); auto close_rc = conn.Close();
return Fail(ss.str(), std::move(last_error)); return Fail(ss.str(), std::move(close_rc) + std::move(last_error));
} }
[[nodiscard]] Result GetHostName(std::string *p_out) { [[nodiscard]] Result GetHostName(std::string *p_out) {
View File
@ -1,6 +1,7 @@
/** /**
* Copyright 2023-2024, XGBoost Contributors * Copyright 2023-2024, XGBoost Contributors
*/ */
#include "rabit/internal/socket.h"
#if defined(__unix__) || defined(__APPLE__) #if defined(__unix__) || defined(__APPLE__)
#include <netdb.h> // gethostbyname #include <netdb.h> // gethostbyname
#include <sys/socket.h> // socket, AF_INET6, AF_INET, connect, getsockname #include <sys/socket.h> // socket, AF_INET6, AF_INET, connect, getsockname
@ -70,10 +71,13 @@ RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockA
return proto::Connect{}.TrackerRecv(&sock_, &world_, &rank, &task_id_); return proto::Connect{}.TrackerRecv(&sock_, &world_, &rank, &task_id_);
} << [&] { } << [&] {
std::string cmd; std::string cmd;
sock_.Recv(&cmd); auto rc = sock_.Recv(&cmd);
if (!rc.OK()) {
return rc;
}
jcmd = Json::Load(StringView{cmd}); jcmd = Json::Load(StringView{cmd});
cmd_ = static_cast<proto::CMD>(get<Integer const>(jcmd["cmd"])); cmd_ = static_cast<proto::CMD>(get<Integer const>(jcmd["cmd"]));
return Success(); return rc;
} << [&] { } << [&] {
if (cmd_ == proto::CMD::kStart) { if (cmd_ == proto::CMD::kStart) {
proto::Start start; proto::Start start;
@ -100,14 +104,18 @@ RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockA
RabitTracker::RabitTracker(Json const& config) : Tracker{config} { RabitTracker::RabitTracker(Json const& config) : Tracker{config} {
std::string self; std::string self;
auto rc = collective::GetHostAddress(&self); auto rc = Success() << [&] {
host_ = OptionalArg<String>(config, "host", self); return collective::GetHostAddress(&self);
} << [&] {
host_ = OptionalArg<String>(config, "host", self);
auto addr = MakeSockAddress(xgboost::StringView{host_}, 0); auto addr = MakeSockAddress(xgboost::StringView{host_}, 0);
listener_ = TCPSocket::Create(addr.IsV4() ? SockDomain::kV4 : SockDomain::kV6); listener_ = TCPSocket::Create(addr.IsV4() ? SockDomain::kV4 : SockDomain::kV6);
rc = listener_.Bind(host_, &this->port_); return listener_.Bind(host_, &this->port_);
} << [&] {
return listener_.Listen();
};
SafeColl(rc); SafeColl(rc);
listener_.Listen();
} }
Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) { Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
@ -220,9 +228,13 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
// //
// retry is set to 1, just let the worker timeout or error. Otherwise the // retry is set to 1, just let the worker timeout or error. Otherwise the
// tracker and the worker might be waiting for each other. // tracker and the worker might be waiting for each other.
auto rc = Connect(w.first, w.second, 1, timeout_, &out); auto rc = Success() << [&] {
return Connect(w.first, w.second, 1, timeout_, &out);
} << [&] {
return proto::Error{}.SignalError(&out);
};
if (!rc.OK()) { if (!rc.OK()) {
return Fail("Failed to inform workers to stop."); return Fail("Failed to inform worker:" + w.first + " for error.", std::move(rc));
} }
} }
return Success(); return Success();
@ -231,13 +243,37 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
return std::async(std::launch::async, [this, handle_error] { return std::async(std::launch::async, [this, handle_error] {
State state{this->n_workers_}; State state{this->n_workers_};
auto select_accept = [&](TCPSocket* sock, auto* addr) {
// accept with poll so that we can enable timeout and interruption.
rabit::utils::PollHelper poll;
auto rc = Success() << [&] {
std::lock_guard lock{listener_mu_};
return listener_.NonBlocking(true);
} << [&] {
std::lock_guard lock{listener_mu_};
poll.WatchRead(listener_);
if (state.running) {
// Don't timeout if the communicator group is up and running.
return poll.Poll(std::chrono::seconds{-1});
} else {
// Have timeout for workers to bootstrap.
return poll.Poll(timeout_);
}
} << [&] {
// this->Stop() closes the socket with a lock. Therefore, when the accept returns
// due to shutdown, the state is still valid (closed).
return listener_.Accept(sock, addr);
};
return rc;
};
while (state.ShouldContinue()) { while (state.ShouldContinue()) {
TCPSocket sock; TCPSocket sock;
SockAddress addr; SockAddress addr;
this->ready_ = true; this->ready_ = true;
auto rc = listener_.Accept(&sock, &addr); auto rc = select_accept(&sock, &addr);
if (!rc.OK()) { if (!rc.OK()) {
return Fail("Failed to accept connection.", std::move(rc)); return Fail("Failed to accept connection.", this->Stop() + std::move(rc));
} }
auto worker = WorkerProxy{n_workers_, std::move(sock), std::move(addr)}; auto worker = WorkerProxy{n_workers_, std::move(sock), std::move(addr)};
@ -252,7 +288,7 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
state.Error(); state.Error();
rc = handle_error(worker); rc = handle_error(worker);
if (!rc.OK()) { if (!rc.OK()) {
return Fail("Failed to handle abort.", std::move(rc)); return Fail("Failed to handle abort.", this->Stop() + std::move(rc));
} }
} }
@ -262,7 +298,7 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
state.Bootstrap(); state.Bootstrap();
} }
if (!rc.OK()) { if (!rc.OK()) {
return rc; return this->Stop() + std::move(rc);
} }
continue; continue;
} }
@ -289,12 +325,11 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
} }
case proto::CMD::kInvalid: case proto::CMD::kInvalid:
default: { default: {
return Fail("Invalid command received."); return Fail("Invalid command received.", this->Stop());
} }
} }
} }
ready_ = false; return this->Stop();
return Success();
}); });
} }
@ -303,11 +338,30 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
SafeColl(rc); SafeColl(rc);
Json args{Object{}}; Json args{Object{}};
args["DMLC_TRACKER_URI"] = String{host_}; args["dmlc_tracker_uri"] = String{host_};
args["DMLC_TRACKER_PORT"] = this->Port(); args["dmlc_tracker_port"] = this->Port();
return args; return args;
} }
[[nodiscard]] Result RabitTracker::Stop() {
if (!this->Ready()) {
return Success();
}
ready_ = false;
std::lock_guard lock{listener_mu_};
if (this->listener_.IsClosed()) {
return Success();
}
return Success() << [&] {
// This should have the effect of stopping the `accept` call.
return this->listener_.Shutdown();
} << [&] {
return listener_.Close();
};
}
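Stop() can interrupt a pending accept precisely because the run loop now polls the listener first: shutting the listening socket down makes its fd readable, the poll returns, and the subsequent accept fails immediately instead of blocking forever. A hedged POSIX sketch of the same pattern:

    #include <poll.h>
    #include <sys/socket.h>

    // Wait for a connection with a timeout; returns the accepted fd,
    // -1 on timeout, -2 on error or shutdown.
    int AcceptWithTimeout(int listen_fd, int timeout_ms) {
      pollfd pfd{listen_fd, POLLIN, 0};
      int n = poll(&pfd, 1, timeout_ms);  // timeout_ms == -1 blocks indefinitely
      if (n == 0) { return -1; }          // timed out
      if (n < 0) { return -2; }           // poll error
      // Readable: either a pending connection or the listener was shut down,
      // in which case accept returns an error immediately instead of hanging.
      int fd = accept(listen_fd, nullptr, nullptr);
      return fd >= 0 ? fd : -2;
    }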
[[nodiscard]] Result GetHostAddress(std::string* out) { [[nodiscard]] Result GetHostAddress(std::string* out) {
auto rc = GetHostName(out); auto rc = GetHostName(out);
if (!rc.OK()) { if (!rc.OK()) {
View File
@ -36,15 +36,18 @@ namespace xgboost::collective {
* signal an error to the tracker and the tracker will notify other workers. * signal an error to the tracker and the tracker will notify other workers.
*/ */
class Tracker { class Tracker {
public:
enum class SortBy : std::int8_t {
kHost = 0,
kTask = 1,
};
protected: protected:
// How to sort the workers, either by host name or by task ID. When using a multi-GPU // How to sort the workers, either by host name or by task ID. When using a multi-GPU
// setting, multiple workers can occupy the same host, in which case one should sort // setting, multiple workers can occupy the same host, in which case one should sort
// workers by task. Due to compatibility reason, the task ID is not always available, so // workers by task. Due to compatibility reason, the task ID is not always available, so
// we use host as the default. // we use host as the default.
enum class SortBy : std::int8_t { SortBy sortby_;
kHost = 0,
kTask = 1,
} sortby_;
protected: protected:
std::int32_t n_workers_{0}; std::int32_t n_workers_{0};
@ -54,10 +57,7 @@ class Tracker {
public: public:
explicit Tracker(Json const& config); explicit Tracker(Json const& config);
Tracker(std::int32_t n_worders, std::int32_t port, std::chrono::seconds timeout) virtual ~Tracker() = default;
: n_workers_{n_worders}, port_{port}, timeout_{timeout} {}
virtual ~Tracker() noexcept(false){}; // NOLINT
[[nodiscard]] Result WaitUntilReady() const; [[nodiscard]] Result WaitUntilReady() const;
@ -69,6 +69,11 @@ class Tracker {
* @brief Flag to indicate whether the server is running. * @brief Flag to indicate whether the server is running.
*/ */
[[nodiscard]] bool Ready() const { return ready_; } [[nodiscard]] bool Ready() const { return ready_; }
/**
* @brief Shut down the tracker; it cannot be restarted. Useful when the tracker hangs while
* calling accept.
*/
virtual Result Stop() { return Success(); }
}; };
class RabitTracker : public Tracker { class RabitTracker : public Tracker {
@ -127,28 +132,22 @@ class RabitTracker : public Tracker {
// record for how to reach out to workers if error happens. // record for how to reach out to workers if error happens.
std::vector<std::pair<std::string, std::int32_t>> worker_error_handles_; std::vector<std::pair<std::string, std::int32_t>> worker_error_handles_;
// listening socket for incoming workers. // listening socket for incoming workers.
//
// At the moment, the listener calls accept without first polling. We can add an
// additional unix domain socket to allow cancelling the accept.
TCPSocket listener_; TCPSocket listener_;
// mutex for protecting the listener, used to prevent race when it's listening while
// another thread tries to shut it down.
std::mutex listener_mu_;
Result Bootstrap(std::vector<WorkerProxy>* p_workers); Result Bootstrap(std::vector<WorkerProxy>* p_workers);
public: public:
explicit RabitTracker(StringView host, std::int32_t n_worders, std::int32_t port,
std::chrono::seconds timeout)
: Tracker{n_worders, port, timeout}, host_{host.c_str(), host.size()} {
listener_ = TCPSocket::Create(SockDomain::kV4);
auto rc = listener_.Bind(host, &this->port_);
CHECK(rc.OK()) << rc.Report();
listener_.Listen();
}
explicit RabitTracker(Json const& config); explicit RabitTracker(Json const& config);
~RabitTracker() noexcept(false) override = default; ~RabitTracker() override = default;
std::future<Result> Run() override; std::future<Result> Run() override;
[[nodiscard]] Json WorkerArgs() const override; [[nodiscard]] Json WorkerArgs() const override;
// Stop the tracker without waiting. This is to prevent the tracker from hanging when
// one of the workers fails to start.
[[nodiscard]] Result Stop() override;
}; };
// Probe the public IP address of the host, need a better method. // Probe the public IP address of the host, need a better method.
View File
@ -14,7 +14,6 @@
#include <thrust/iterator/transform_output_iterator.h> // make_transform_output_iterator #include <thrust/iterator/transform_output_iterator.h> // make_transform_output_iterator
#include <thrust/logical.h> #include <thrust/logical.h>
#include <thrust/sequence.h> #include <thrust/sequence.h>
#include <thrust/sort.h>
#include <thrust/system/cuda/error.h> #include <thrust/system/cuda/error.h>
#include <thrust/system_error.h> #include <thrust/system_error.h>
#include <thrust/transform_scan.h> #include <thrust/transform_scan.h>
@ -301,21 +300,22 @@ class MemoryLogger {
void RegisterAllocation(void *ptr, size_t n) { void RegisterAllocation(void *ptr, size_t n) {
device_allocations[ptr] = n; device_allocations[ptr] = n;
currently_allocated_bytes += n; currently_allocated_bytes += n;
peak_allocated_bytes = peak_allocated_bytes = std::max(peak_allocated_bytes, currently_allocated_bytes);
std::max(peak_allocated_bytes, currently_allocated_bytes);
num_allocations++; num_allocations++;
CHECK_GT(num_allocations, num_deallocations); CHECK_GT(num_allocations, num_deallocations);
} }
void RegisterDeallocation(void *ptr, size_t n, int current_device) { void RegisterDeallocation(void *ptr, size_t n, int current_device) {
auto itr = device_allocations.find(ptr); auto itr = device_allocations.find(ptr);
if (itr == device_allocations.end()) { if (itr == device_allocations.end()) {
LOG(WARNING) << "Attempting to deallocate " << n << " bytes on device " LOG(WARNING) << "Attempting to deallocate " << n << " bytes on device " << current_device
<< current_device << " that was never allocated "; << " that was never allocated\n"
<< dmlc::StackTrace();
} else {
num_deallocations++;
CHECK_LE(num_deallocations, num_allocations);
currently_allocated_bytes -= itr->second;
device_allocations.erase(itr);
} }
num_deallocations++;
CHECK_LE(num_deallocations, num_allocations);
currently_allocated_bytes -= itr->second;
device_allocations.erase(itr);
} }
}; };
DeviceStats stats_; DeviceStats stats_;
View File
@ -11,7 +11,7 @@
#include "xgboost/logging.h" #include "xgboost/logging.h"
namespace xgboost::error { namespace xgboost::error {
std::string DeprecatedFunc(StringView old, StringView since, StringView replacement) { [[nodiscard]] std::string DeprecatedFunc(StringView old, StringView since, StringView replacement) {
std::stringstream ss; std::stringstream ss;
ss << "`" << old << "` is deprecated since" << since << ", use `" << replacement << "` instead."; ss << "`" << old << "` is deprecated since" << since << ", use `" << replacement << "` instead.";
return ss.str(); return ss.str();
View File
@ -89,7 +89,7 @@ void WarnDeprecatedGPUId();
void WarnEmptyDataset(); void WarnEmptyDataset();
std::string DeprecatedFunc(StringView old, StringView since, StringView replacement); [[nodiscard]] std::string DeprecatedFunc(StringView old, StringView since, StringView replacement);
constexpr StringView InvalidCUDAOrdinal() { constexpr StringView InvalidCUDAOrdinal() {
return "Invalid device. `device` is required to be CUDA and there must be at least one GPU " return "Invalid device. `device` is required to be CUDA and there must be at least one GPU "
View File
@ -8,6 +8,7 @@
#define COMMON_HIST_UTIL_CUH_ #define COMMON_HIST_UTIL_CUH_
#include <thrust/host_vector.h> #include <thrust/host_vector.h>
#include <thrust/sort.h> // for sort
#include <cstddef> // for size_t #include <cstddef> // for size_t
View File
@ -6,7 +6,6 @@
#include <algorithm> #include <algorithm>
#include <cstdint> #include <cstdint>
#include <mutex>
#include "xgboost/data.h" #include "xgboost/data.h"
View File

@ -4,6 +4,7 @@
#include "quantile.h" #include "quantile.h"
#include <limits> #include <limits>
#include <numeric> // for partial_sum
#include <utility> #include <utility>
#include "../collective/aggregator.h" #include "../collective/aggregator.h"
View File
@ -1,5 +1,5 @@
/** /**
* Copyright 2020-2023 by XGBoost Contributors * Copyright 2020-2024, XGBoost Contributors
*/ */
#include <thrust/binary_search.h> #include <thrust/binary_search.h>
#include <thrust/execution_policy.h> #include <thrust/execution_policy.h>
@ -8,8 +8,8 @@
#include <thrust/transform_scan.h> #include <thrust/transform_scan.h>
#include <thrust/unique.h> #include <thrust/unique.h>
#include <limits> // std::numeric_limits #include <limits> // std::numeric_limits
#include <memory> #include <numeric> // for partial_sum
#include <utility> #include <utility>
#include "../collective/communicator-inl.cuh" #include "../collective/communicator-inl.cuh"
View File
@ -1,8 +1,9 @@
/**
* Copyright 2020-2024, XGBoost Contributors
*/
#ifndef XGBOOST_COMMON_QUANTILE_CUH_ #ifndef XGBOOST_COMMON_QUANTILE_CUH_
#define XGBOOST_COMMON_QUANTILE_CUH_ #define XGBOOST_COMMON_QUANTILE_CUH_
#include <memory>
#include "xgboost/span.h" #include "xgboost/span.h"
#include "xgboost/data.h" #include "xgboost/data.h"
#include "device_helpers.cuh" #include "device_helpers.cuh"
View File
@ -1,9 +1,8 @@
/*! /**
* Copyright by Contributors 2019 * Copyright 2019-2024, XGBoost Contributors
*/ */
#include "timer.h" #include "timer.h"
#include <sstream>
#include <utility> #include <utility>
#include "../collective/communicator-inl.h" #include "../collective/communicator-inl.h"
@ -61,6 +60,9 @@ void Monitor::Print() const {
kv.second.timer.elapsed) kv.second.timer.elapsed)
.count()); .count());
} }
if (stat_map.empty()) {
return;
}
LOG(CONSOLE) << "======== Monitor (" << rank << "): " << label_ << " ========"; LOG(CONSOLE) << "======== Monitor (" << rank << "): " << label_ << " ========";
this->PrintStatistics(stat_map); this->PrintStatistics(stat_map);
} }
View File
@ -11,7 +11,6 @@
#include <cmath> // for abs #include <cmath> // for abs
#include <cstdint> // for uint64_t, int32_t, uint8_t, uint32_t #include <cstdint> // for uint64_t, int32_t, uint8_t, uint32_t
#include <cstring> // for size_t, strcmp, memcpy #include <cstring> // for size_t, strcmp, memcpy
#include <exception> // for exception
#include <iostream> // for operator<<, basic_ostream, basic_ostream::op... #include <iostream> // for operator<<, basic_ostream, basic_ostream::op...
#include <map> // for map, operator!= #include <map> // for map, operator!=
#include <numeric> // for accumulate, partial_sum #include <numeric> // for accumulate, partial_sum
@ -22,7 +21,6 @@
#include "../collective/communicator.h" // for Operation #include "../collective/communicator.h" // for Operation
#include "../common/algorithm.h" // for StableSort #include "../common/algorithm.h" // for StableSort
#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry #include "../common/api_entry.h" // for XGBAPIThreadLocalEntry
#include "../common/common.h" // for Split
#include "../common/error_msg.h" // for GroupSize, GroupWeight, InfInData #include "../common/error_msg.h" // for GroupSize, GroupWeight, InfInData
#include "../common/group_data.h" // for ParallelGroupBuilder #include "../common/group_data.h" // for ParallelGroupBuilder
#include "../common/io.h" // for PeekableInStream #include "../common/io.h" // for PeekableInStream
@ -473,11 +471,11 @@ void MetaInfo::SetInfo(Context const& ctx, StringView key, StringView interface_
<< ", must have at least 1 column even if it's empty."; << ", must have at least 1 column even if it's empty.";
auto const& first = get<Object const>(array.front()); auto const& first = get<Object const>(array.front());
auto ptr = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(first); auto ptr = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(first);
is_cuda = ArrayInterfaceHandler::IsCudaPtr(ptr); is_cuda = first.find("stream") != first.cend() || ArrayInterfaceHandler::IsCudaPtr(ptr);
} else { } else {
auto const& first = get<Object const>(j_interface); auto const& first = get<Object const>(j_interface);
auto ptr = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(first); auto ptr = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(first);
is_cuda = ArrayInterfaceHandler::IsCudaPtr(ptr); is_cuda = first.find("stream") != first.cend() || ArrayInterfaceHandler::IsCudaPtr(ptr);
} }
if (is_cuda) { if (is_cuda) {
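The "stream" check works because that key only exists in CUDA's __cuda_array_interface__ (version 3); NumPy's __array_interface__ never defines it, so its presence identifies device data without probing the pointer. An illustrative CUDA interface dict (hypothetical pointer value, comments added):

    {
      "data": [139637976727552, false],  // [device pointer, read-only flag]
      "shape": [100],
      "typestr": "<f4",
      "stream": 2,                       // 2 = per-thread default stream
      "version": 3
    }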
@ -567,46 +565,6 @@ void MetaInfo::SetInfoFromHost(Context const& ctx, StringView key, Json arr) {
} }
} }
void MetaInfo::SetInfo(Context const& ctx, const char* key, const void* dptr, DataType dtype,
size_t num) {
CHECK(key);
auto proc = [&](auto cast_d_ptr) {
using T = std::remove_pointer_t<decltype(cast_d_ptr)>;
auto t = linalg::TensorView<T, 1>(common::Span<T>{cast_d_ptr, num}, {num}, DeviceOrd::CPU());
CHECK(t.CContiguous());
Json interface {
linalg::ArrayInterface(t)
};
assert(ArrayInterface<1>{interface}.is_contiguous);
return interface;
};
// Legacy code using XGBoost dtype, which is a small subset of array interface types.
switch (dtype) {
case xgboost::DataType::kFloat32: {
auto cast_ptr = reinterpret_cast<const float*>(dptr);
this->SetInfoFromHost(ctx, key, proc(cast_ptr));
break;
}
case xgboost::DataType::kDouble: {
auto cast_ptr = reinterpret_cast<const double*>(dptr);
this->SetInfoFromHost(ctx, key, proc(cast_ptr));
break;
}
case xgboost::DataType::kUInt32: {
auto cast_ptr = reinterpret_cast<const uint32_t*>(dptr);
this->SetInfoFromHost(ctx, key, proc(cast_ptr));
break;
}
case xgboost::DataType::kUInt64: {
auto cast_ptr = reinterpret_cast<const uint64_t*>(dptr);
this->SetInfoFromHost(ctx, key, proc(cast_ptr));
break;
}
default:
LOG(FATAL) << "Unknown data type" << static_cast<uint8_t>(dtype);
}
}
void MetaInfo::GetInfo(char const* key, bst_ulong* out_len, DataType dtype, void MetaInfo::GetInfo(char const* key, bst_ulong* out_len, DataType dtype,
const void** out_dptr) const { const void** out_dptr) const {
if (dtype == DataType::kFloat32) { if (dtype == DataType::kFloat32) {
View File
@ -1,5 +1,5 @@
/** /**
* Copyright 2021-2023, XGBoost contributors * Copyright 2021-2024, XGBoost contributors
*/ */
#include "file_iterator.h" #include "file_iterator.h"
@ -10,7 +10,10 @@
#include <ostream> // for operator<<, basic_ostream, istringstream #include <ostream> // for operator<<, basic_ostream, istringstream
#include <vector> // for vector #include <vector> // for vector
#include "../common/common.h" // for Split #include "../common/common.h" // for Split
#include "xgboost/linalg.h" // for ArrayInterfaceStr, MakeVec
#include "xgboost/linalg.h"
#include "xgboost/logging.h" // for CHECK
#include "xgboost/string_view.h" // for operator<<, StringView #include "xgboost/string_view.h" // for operator<<, StringView
namespace xgboost::data { namespace xgboost::data {
@ -28,10 +31,10 @@ std::string ValidateFileFormat(std::string const& uri) {
for (size_t i = 0; i < arg_list.size(); ++i) { for (size_t i = 0; i < arg_list.size(); ++i) {
std::istringstream is(arg_list[i]); std::istringstream is(arg_list[i]);
std::pair<std::string, std::string> kv; std::pair<std::string, std::string> kv;
CHECK(std::getline(is, kv.first, '=')) << "Invalid uri argument format" CHECK(std::getline(is, kv.first, '='))
<< " for key in arg " << i + 1; << "Invalid uri argument format" << " for key in arg " << i + 1;
CHECK(std::getline(is, kv.second)) << "Invalid uri argument format" CHECK(std::getline(is, kv.second))
<< " for value in arg " << i + 1; << "Invalid uri argument format" << " for value in arg " << i + 1;
args.insert(kv); args.insert(kv);
} }
if (args.find("format") == args.cend()) { if (args.find("format") == args.cend()) {
@ -48,4 +51,41 @@ std::string ValidateFileFormat(std::string const& uri) {
return name_args[0] + "?" + name_args[1] + '#' + name_args_cache[1]; return name_args[0] + "?" + name_args[1] + '#' + name_args_cache[1];
} }
} }
int FileIterator::Next() {
CHECK(parser_);
if (parser_->Next()) {
row_block_ = parser_->Value();
indptr_ = linalg::Make1dInterface(row_block_.offset, row_block_.size + 1);
values_ = linalg::Make1dInterface(row_block_.value, row_block_.offset[row_block_.size]);
indices_ = linalg::Make1dInterface(row_block_.index, row_block_.offset[row_block_.size]);
size_t n_columns =
*std::max_element(row_block_.index, row_block_.index + row_block_.offset[row_block_.size]);
// dmlc parser converts 1-based indexing back to 0-based indexing so we can ignore
// this condition and just add 1 to n_columns
n_columns += 1;
XGProxyDMatrixSetDataCSR(proxy_, indptr_.c_str(), indices_.c_str(), values_.c_str(), n_columns);
if (row_block_.label) {
auto str = linalg::Make1dInterface(row_block_.label, row_block_.size);
XGDMatrixSetInfoFromInterface(proxy_, "label", str.c_str());
}
if (row_block_.qid) {
auto str = linalg::Make1dInterface(row_block_.qid, row_block_.size);
XGDMatrixSetInfoFromInterface(proxy_, "qid", str.c_str());
}
if (row_block_.weight) {
auto str = linalg::Make1dInterface(row_block_.weight, row_block_.size);
XGDMatrixSetInfoFromInterface(proxy_, "weight", str.c_str());
}
// Continue iteration
return true;
} else {
// Stop iteration
return false;
}
}
} // namespace xgboost::data } // namespace xgboost::data
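For a concrete picture of the CSR triplet that Next() hands to XGProxyDMatrixSetDataCSR, take the 2x3 matrix [[1, 0, 2], [0, 3, 0]]:

    offset = [0, 2, 3]        // row i spans [offset[i], offset[i+1])
    index  = [0, 2, 1]        // column id of each stored value
    value  = [1, 2, 3]
    n_columns = max(index) + 1 = 3

The final +1 mirrors the n_columns computation above, since column ids are 0-based after the dmlc parser normalizes them.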
View File
@ -1,20 +1,16 @@
/** /**
* Copyright 2021-2023, XGBoost contributors * Copyright 2021-2024, XGBoost contributors
*/ */
#ifndef XGBOOST_DATA_FILE_ITERATOR_H_ #ifndef XGBOOST_DATA_FILE_ITERATOR_H_
#define XGBOOST_DATA_FILE_ITERATOR_H_ #define XGBOOST_DATA_FILE_ITERATOR_H_
#include <algorithm> // for max_element
#include <cstddef> // for size_t
#include <cstdint> // for uint32_t #include <cstdint> // for uint32_t
#include <memory> // for unique_ptr #include <memory> // for unique_ptr
#include <string> // for string #include <string> // for string
#include <utility> // for move #include <utility> // for move
#include "dmlc/data.h" // for RowBlock, Parser #include "dmlc/data.h" // for RowBlock, Parser
#include "xgboost/c_api.h" // for XGDMatrixSetDenseInfo, XGDMatrixFree, XGProxyDMatrixCreate #include "xgboost/c_api.h" // for XGDMatrixFree, XGProxyDMatrixCreate
#include "xgboost/linalg.h" // for ArrayInterfaceStr, MakeVec
#include "xgboost/logging.h" // for CHECK
namespace xgboost::data { namespace xgboost::data {
[[nodiscard]] std::string ValidateFileFormat(std::string const& uri); [[nodiscard]] std::string ValidateFileFormat(std::string const& uri);
@ -53,41 +49,7 @@ class FileIterator {
XGDMatrixFree(proxy_); XGDMatrixFree(proxy_);
} }
int Next() { int Next();
CHECK(parser_);
if (parser_->Next()) {
row_block_ = parser_->Value();
using linalg::MakeVec;
indptr_ = ArrayInterfaceStr(MakeVec(row_block_.offset, row_block_.size + 1));
values_ = ArrayInterfaceStr(MakeVec(row_block_.value, row_block_.offset[row_block_.size]));
indices_ = ArrayInterfaceStr(MakeVec(row_block_.index, row_block_.offset[row_block_.size]));
size_t n_columns = *std::max_element(row_block_.index,
row_block_.index + row_block_.offset[row_block_.size]);
      // The dmlc parser already converts 1-based indexing back to 0-based indexing,
      // so the maximum column index plus one gives the number of columns.
n_columns += 1;
XGProxyDMatrixSetDataCSR(proxy_, indptr_.c_str(), indices_.c_str(),
values_.c_str(), n_columns);
if (row_block_.label) {
XGDMatrixSetDenseInfo(proxy_, "label", row_block_.label, row_block_.size, 1);
}
if (row_block_.qid) {
XGDMatrixSetDenseInfo(proxy_, "qid", row_block_.qid, row_block_.size, 1);
}
if (row_block_.weight) {
XGDMatrixSetDenseInfo(proxy_, "weight", row_block_.weight, row_block_.size, 1);
}
// Continue iteration
return true;
} else {
// Stop iteration
return false;
}
}
auto Proxy() -> decltype(proxy_) { return proxy_; } auto Proxy() -> decltype(proxy_) { return proxy_; }
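With the body moved to file_iterator.cc, the header keeps only the declaration, so translation units that include it no longer pull in <algorithm>, linalg.h, or logging.h. A generic sketch of this header/implementation split (the Widget names are illustrative, not from the commit):

// widget.h -- declaration only; implementation-only includes stay out of the header.
class Widget {
 public:
  int Next();  // defined out of line in widget.cc
};

// widget.cc -- heavy dependencies are private to this translation unit.
#include <algorithm>  // needed only by the definition
int Widget::Next() { return 0; }  // placeholder body

This keeps rebuilds cheap: editing the implementation no longer recompiles every file that includes the header.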

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2014-2023, XGBoost Contributors * Copyright 2014-2024, XGBoost Contributors
* \file sparse_page_source.h * \file sparse_page_source.h
*/ */
#ifndef XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_ #ifndef XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_
@ -7,23 +7,26 @@
#include <algorithm> // for min #include <algorithm> // for min
#include <atomic> // for atomic #include <atomic> // for atomic
#include <cstdio> // for remove
#include <future> // for async #include <future> // for async
#include <map> #include <memory> // for unique_ptr
#include <memory> #include <mutex> // for mutex
#include <mutex> // for mutex #include <string> // for string
#include <string> #include <utility> // for pair, move
#include <thread> #include <vector> // for vector
#include <utility> // for pair, move
#include <vector>
#include "../common/common.h" #if !defined(XGBOOST_USE_CUDA)
#include "../common/io.h" // for PrivateMmapConstStream #include "../common/common.h" // for AssertGPUSupport
#include "../common/timer.h" // for Monitor, Timer #endif // !defined(XGBOOST_USE_CUDA)
#include "adapter.h"
#include "proxy_dmatrix.h" // for DMatrixProxy #include "../common/io.h" // for PrivateMmapConstStream
#include "sparse_page_writer.h" // for SparsePageFormat #include "../common/timer.h" // for Monitor, Timer
#include "xgboost/base.h" #include "proxy_dmatrix.h" // for DMatrixProxy
#include "xgboost/data.h" #include "sparse_page_writer.h" // for SparsePageFormat
#include "xgboost/base.h" // for bst_feature_t
#include "xgboost/data.h" // for SparsePage, CSCPage
#include "xgboost/global_config.h" // for GlobalConfigThreadLocalStore
#include "xgboost/logging.h" // for CHECK_EQ
namespace xgboost::data { namespace xgboost::data {
inline void TryDeleteCacheFile(const std::string& file) { inline void TryDeleteCacheFile(const std::string& file) {
@ -185,6 +188,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
exce_.Rethrow(); exce_.Rethrow();
auto const config = *GlobalConfigThreadLocalStore::Get();
for (std::int32_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) { for (std::int32_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) {
fetch_it %= n_batches_; // ring fetch_it %= n_batches_; // ring
if (ring_->at(fetch_it).valid()) { if (ring_->at(fetch_it).valid()) {
@ -192,7 +196,8 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
} }
auto const* self = this; // make sure it's const auto const* self = this; // make sure it's const
CHECK_LT(fetch_it, cache_info_->offset.size()); CHECK_LT(fetch_it, cache_info_->offset.size());
ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self, this]() { ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self, config, this]() {
*GlobalConfigThreadLocalStore::Get() = config;
auto page = std::make_shared<S>(); auto page = std::make_shared<S>();
this->exce_.Run([&] { this->exce_.Run([&] {
std::unique_ptr<SparsePageFormat<S>> fmt{CreatePageFormat<S>("raw")}; std::unique_ptr<SparsePageFormat<S>> fmt{CreatePageFormat<S>("raw")};
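The sparse_page_source.h hunk above fixes a subtle thread-locality issue: GlobalConfigThreadLocalStore is thread-local, so a prefetch worker spawned via std::async starts with default settings (e.g. verbosity) unless the launching thread's config is copied in explicitly. A generic sketch of the capture-and-restore pattern, with a made-up Config type standing in for XGBoost's global configuration:

#include <future>
#include <iostream>

struct Config { int verbosity{1}; };  // stand-in for the real global config
thread_local Config g_config;        // each thread gets its own copy

int main() {
  g_config.verbosity = 3;  // the caller customizes its thread-local config
  // Capture the caller's config by value, then restore it inside the worker
  // so the worker's thread-local state matches the launching thread's.
  auto config = g_config;
  auto fut = std::async(std::launch::async, [config] {
    g_config = config;  // without this, the worker would see verbosity == 1
    std::cout << g_config.verbosity << "\n";  // prints: 3
  });
  fut.wait();
  return 0;
}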

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2014-2023 by Contributors * Copyright 2014-2024, XGBoost Contributors
* \file gbtree.cc * \file gbtree.cc
* \brief gradient boosted tree implementation. * \brief gradient boosted tree implementation.
* \author Tianqi Chen * \author Tianqi Chen
@ -11,14 +11,12 @@
#include <algorithm> #include <algorithm>
#include <cstdint> // std::int32_t #include <cstdint> // std::int32_t
#include <map>
#include <memory> #include <memory>
#include <numeric> // for iota
#include <string> #include <string>
#include <unordered_map>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "../common/common.h"
#include "../common/timer.h" #include "../common/timer.h"
#include "../tree/param.h" // TrainParam #include "../tree/param.h" // TrainParam
#include "gbtree_model.h" #include "gbtree_model.h"

View File

@ -10,15 +10,15 @@
#include <array> #include <array>
#include <cmath> #include <cmath>
#include <numeric> // for accumulate
#include "../collective/communicator-inl.h" #include "../common/common.h" // for AssertGPUSupport
#include "../common/common.h" // MetricNoCache
#include "../common/math.h" #include "../common/math.h"
#include "../common/optional_weight.h" // OptionalWeights #include "../common/optional_weight.h" // OptionalWeights
#include "../common/pseudo_huber.h" #include "../common/pseudo_huber.h"
#include "../common/quantile_loss_utils.h" // QuantileLossParam #include "../common/quantile_loss_utils.h" // QuantileLossParam
#include "../common/threading_utils.h" #include "../common/threading_utils.h"
#include "metric_common.h" #include "metric_common.h" // MetricNoCache
#include "xgboost/collective/result.h" // for SafeColl #include "xgboost/collective/result.h" // for SafeColl
#include "xgboost/metric.h" #include "xgboost/metric.h"

View File

@ -9,8 +9,6 @@
#include <string> #include <string>
#include "../collective/aggregator.h" #include "../collective/aggregator.h"
#include "../collective/communicator-inl.h"
#include "../common/common.h"
#include "xgboost/metric.h" #include "xgboost/metric.h"
namespace xgboost { namespace xgboost {

View File

@ -9,8 +9,8 @@
#include <array> #include <array>
#include <atomic> #include <atomic>
#include <cmath> #include <cmath>
#include <numeric> // for accumulate
#include "../collective/communicator-inl.h"
#include "../common/math.h" #include "../common/math.h"
#include "../common/threading_utils.h" #include "../common/threading_utils.h"
#include "metric_common.h" // MetricNoCache #include "metric_common.h" // MetricNoCache

View File

@ -9,10 +9,9 @@
#include <array> #include <array>
#include <memory> #include <memory>
#include <numeric> // for accumulate
#include <vector> #include <vector>
#include "../collective/communicator-inl.h"
#include "../common/math.h"
#include "../common/survival_util.h" #include "../common/survival_util.h"
#include "../common/threading_utils.h" #include "../common/threading_utils.h"
#include "metric_common.h" // MetricNoCache #include "metric_common.h" // MetricNoCache

View File

@ -1,13 +1,13 @@
/** /**
* Copyright 2019-2023 by XGBoost Contributors * Copyright 2019-2024, XGBoost Contributors
*/ */
#include <thrust/functional.h> #include <thrust/functional.h>
#include <thrust/random.h> #include <thrust/random.h>
#include <thrust/sort.h> // for sort
#include <thrust/transform.h> #include <thrust/transform.h>
#include <xgboost/host_device_vector.h> #include <xgboost/host_device_vector.h>
#include <xgboost/logging.h> #include <xgboost/logging.h>
#include <algorithm>
#include <cstddef> // for size_t #include <cstddef> // for size_t
#include <limits> #include <limits>
#include <utility> #include <utility>

Some files were not shown because too many files have changed in this diff