merge latest change from upstream

This commit is contained in:
Hui Liu 2024-04-22 09:35:31 -07:00
commit 8b75204fed
146 changed files with 3111 additions and 1027 deletions

View File

@ -8,7 +8,7 @@ updates:
- package-ecosystem: "maven"
directory: "/jvm-packages"
schedule:
interval: "daily"
interval: "monthly"
- package-ecosystem: "maven"
directory: "/jvm-packages/xgboost4j"
schedule:
@ -16,11 +16,11 @@ updates:
- package-ecosystem: "maven"
directory: "/jvm-packages/xgboost4j-gpu"
schedule:
interval: "daily"
interval: "monthly"
- package-ecosystem: "maven"
directory: "/jvm-packages/xgboost4j-example"
schedule:
interval: "daily"
interval: "monthly"
- package-ecosystem: "maven"
directory: "/jvm-packages/xgboost4j-spark"
schedule:
@ -28,4 +28,4 @@ updates:
- package-ecosystem: "maven"
directory: "/jvm-packages/xgboost4j-spark-gpu"
schedule:
interval: "daily"
interval: "monthly"

View File

@ -110,7 +110,7 @@ jobs:
name: Test R package on Debian
runs-on: ubuntu-latest
container:
image: rhub/debian-gcc-devel
image: rhub/debian-gcc-release
steps:
- name: Install system dependencies
@ -130,12 +130,12 @@ jobs:
- name: Install dependencies
shell: bash -l {0}
run: |
/tmp/R-devel/bin/Rscript -e "source('./R-package/tests/helper_scripts/install_deps.R')"
Rscript -e "source('./R-package/tests/helper_scripts/install_deps.R')"
- name: Test R
shell: bash -l {0}
run: |
python3 tests/ci_build/test_r_package.py --r=/tmp/R-devel/bin/R --build-tool=autotools --task=check
python3 tests/ci_build/test_r_package.py --r=/usr/bin/R --build-tool=autotools --task=check
- uses: dorny/paths-filter@v2
id: changes
@ -147,4 +147,4 @@ jobs:
- name: Run document check
if: steps.changes.outputs.r_package == 'true'
run: |
python3 tests/ci_build/test_r_package.py --r=/tmp/R-devel/bin/R --task=doc
python3 tests/ci_build/test_r_package.py --r=/usr/bin/R --task=doc

View File

@ -3,7 +3,7 @@ name: update-rapids
on:
workflow_dispatch:
schedule:
- cron: "0 20 * * *" # Run once daily
- cron: "0 20 * * 1" # Run once weekly
permissions:
pull-requests: write
@ -32,7 +32,7 @@ jobs:
run: |
bash tests/buildkite/update-rapids.sh
- name: Create Pull Request
uses: peter-evans/create-pull-request@v5
uses: peter-evans/create-pull-request@v6
if: github.ref == 'refs/heads/master'
with:
add-paths: |

View File

@ -2101,7 +2101,7 @@ This release marks a major milestone for the XGBoost project.
## v0.90 (2019.05.18)
### XGBoost Python package drops Python 2.x (#4379, #4381)
Python 2.x is reaching its end-of-life at the end of this year. [Many scientific Python packages are now moving to drop Python 2.x](https://python3statement.org/).
Python 2.x is reaching its end-of-life at the end of this year. [Many scientific Python packages are now moving to drop Python 2.x](https://python3statement.github.io/).
### XGBoost4J-Spark now requires Spark 2.4.x (#4377)
* Spark 2.3 is reaching its end-of-life soon. See discussion at #4389.

View File

@ -26,6 +26,11 @@ NVL <- function(x, val) {
'multi:softprob', 'rank:pairwise', 'rank:ndcg', 'rank:map'))
}
.RANKING_OBJECTIVES <- function() {
return(c('binary:logistic', 'binary:logitraw', 'binary:hinge', 'multi:softmax',
'multi:softprob'))
}
#
# Low-level functions for boosting --------------------------------------------
@ -235,33 +240,43 @@ convert.labels <- function(labels, objective_name) {
}
# Generates random (stratified if needed) CV folds
generate.cv.folds <- function(nfold, nrows, stratified, label, params) {
generate.cv.folds <- function(nfold, nrows, stratified, label, group, params) {
if (NROW(group)) {
if (stratified) {
warning(
paste0(
"Stratified splitting is not supported when using 'group' attribute.",
" Will use unstratified splitting."
)
)
}
return(generate.group.folds(nfold, group))
}
objective <- params$objective
if (!is.character(objective)) {
warning("Will use unstratified splitting (custom objective used)")
stratified <- FALSE
}
# cannot stratify if label is NULL
if (stratified && is.null(label)) {
warning("Will use unstratified splitting (no 'labels' available)")
stratified <- FALSE
}
# cannot do it for rank
objective <- params$objective
if (is.character(objective) && strtrim(objective, 5) == 'rank:') {
stop("\n\tAutomatic generation of CV-folds is not implemented for ranking!\n",
stop("\n\tAutomatic generation of CV-folds is not implemented for ranking without 'group' field!\n",
"\tConsider providing pre-computed CV-folds through the 'folds=' parameter.\n")
}
# shuffle
rnd_idx <- sample.int(nrows)
if (stratified &&
length(label) == length(rnd_idx)) {
if (stratified && length(label) == length(rnd_idx)) {
y <- label[rnd_idx]
# WARNING: some heuristic logic is employed to identify classification setting!
# - For classification, need to convert y labels to factor before making the folds,
# and then do stratification by factor levels.
# - For regression, leave y numeric and do stratification by quantiles.
if (is.character(objective)) {
y <- convert.labels(y, params$objective)
} else {
# If no 'objective' given in params, it means that user either wants to
# use the default 'reg:squarederror' objective or has provided a custom
# obj function. Here, assume classification setting when y has 5 or less
# unique values:
if (length(unique(y)) <= 5) {
y <- factor(y)
}
y <- convert.labels(y, objective)
}
folds <- xgb.createFolds(y = y, k = nfold)
} else {
@ -277,6 +292,29 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, params) {
return(folds)
}
generate.group.folds <- function(nfold, group) {
ngroups <- length(group) - 1
if (ngroups < nfold) {
stop("DMatrix has fewer groups than folds.")
}
seq_groups <- seq_len(ngroups)
indices <- lapply(seq_groups, function(gr) seq(group[gr] + 1, group[gr + 1]))
assignments <- base::split(seq_groups, as.integer(seq_groups %% nfold))
assignments <- unname(assignments)
out <- vector("list", nfold)
randomized_groups <- sample(ngroups)
for (idx in seq_len(nfold)) {
groups_idx_test <- randomized_groups[assignments[[idx]]]
groups_test <- indices[groups_idx_test]
idx_test <- unlist(groups_test)
attributes(idx_test)$group_test <- lengths(groups_test)
attributes(idx_test)$group_train <- lengths(indices[-groups_idx_test])
out[[idx]] <- idx_test
}
return(out)
}
# Creates CV folds stratified by the values of y.
# It was borrowed from caret::createFolds and simplified
# by always returning an unnamed list of fold indices.

View File

@ -1259,8 +1259,11 @@ xgb.get.DMatrix.data <- function(dmat) {
#' Get a new DMatrix containing the specified rows of
#' original xgb.DMatrix object
#'
#' @param object Object of class "xgb.DMatrix"
#' @param idxset a integer vector of indices of rows needed
#' @param object Object of class "xgb.DMatrix".
#' @param idxset An integer vector of indices of rows needed (base-1 indexing).
#' @param allow_groups Whether to allow slicing an `xgb.DMatrix` with `group` (or
#' equivalently `qid`) field. Note that in such case, the result will not have
#' the groups anymore - they need to be set manually through `setinfo`.
#' @param colset currently not used (columns subsetting is not available)
#'
#' @examples
@ -1275,11 +1278,11 @@ xgb.get.DMatrix.data <- function(dmat) {
#'
#' @rdname xgb.slice.DMatrix
#' @export
xgb.slice.DMatrix <- function(object, idxset) {
xgb.slice.DMatrix <- function(object, idxset, allow_groups = FALSE) {
if (!inherits(object, "xgb.DMatrix")) {
stop("object must be xgb.DMatrix")
}
ret <- .Call(XGDMatrixSliceDMatrix_R, object, idxset)
ret <- .Call(XGDMatrixSliceDMatrix_R, object, idxset, allow_groups)
attr_list <- attributes(object)
nr <- nrow(object)
@ -1296,7 +1299,15 @@ xgb.slice.DMatrix <- function(object, idxset) {
}
}
}
return(structure(ret, class = "xgb.DMatrix"))
out <- structure(ret, class = "xgb.DMatrix")
parent_fields <- as.list(attributes(object)$fields)
if (NROW(parent_fields)) {
child_fields <- parent_fields[!(names(parent_fields) %in% c("group", "qid"))]
child_fields <- as.environment(child_fields)
attributes(out)$fields <- child_fields
}
return(out)
}
#' @rdname xgb.slice.DMatrix
@ -1340,11 +1351,11 @@ print.xgb.DMatrix <- function(x, verbose = FALSE, ...) {
}
cat(class_print, ' dim:', nrow(x), 'x', ncol(x), ' info: ')
infos <- character(0)
if (xgb.DMatrix.hasinfo(x, 'label')) infos <- 'label'
if (xgb.DMatrix.hasinfo(x, 'weight')) infos <- c(infos, 'weight')
if (xgb.DMatrix.hasinfo(x, 'base_margin')) infos <- c(infos, 'base_margin')
if (length(infos) == 0) infos <- 'NA'
infos <- names(attributes(x)$fields)
infos <- infos[infos != "feature_name"]
if (!NROW(infos)) infos <- "NA"
infos <- infos[order(infos)]
infos <- paste(infos, collapse = ", ")
cat(infos)
cnames <- colnames(x)
cat(' colnames:')

View File

@ -1,6 +1,6 @@
#' Cross Validation
#'
#' The cross validation function of xgboost
#' The cross validation function of xgboost.
#'
#' @param params the list of parameters. The complete list of parameters is
#' available in the \href{http://xgboost.readthedocs.io/en/latest/parameter.html}{online documentation}. Below
@ -19,13 +19,17 @@
#'
#' See \code{\link{xgb.train}} for further details.
#' See also demo/ for walkthrough example in R.
#' @param data takes an \code{xgb.DMatrix}, \code{matrix}, or \code{dgCMatrix} as the input.
#'
#' Note that, while `params` accepts a `seed` entry and will use such parameter for model training if
#' supplied, this seed is not used for creation of train-test splits, which instead rely on R's own RNG
#' system - thus, for reproducible results, one needs to call the `set.seed` function beforehand.
#' @param data An `xgb.DMatrix` object, with corresponding fields like `label` or bounds as required
#' for model training by the objective.
#'
#' Note that only the basic `xgb.DMatrix` class is supported - variants such as `xgb.QuantileDMatrix`
#' or `xgb.ExternalDMatrix` are not supported here.
#' @param nrounds the max number of iterations
#' @param nfold the original dataset is randomly partitioned into \code{nfold} equal size subsamples.
#' @param label vector of response values. Should be provided only when data is an R-matrix.
#' @param missing is only used when input is a dense matrix. By default is set to NA, which means
#' that NA values should be considered as 'missing' by the algorithm.
#' Sometimes, 0 or other extreme value might be used to represent missing values.
#' @param prediction A logical value indicating whether to return the test fold predictions
#' from each CV model. This parameter engages the \code{\link{xgb.cb.cv.predict}} callback.
#' @param showsd \code{boolean}, whether to show standard deviation of cross validation
@ -47,13 +51,30 @@
#' @param feval customized evaluation function. Returns
#' \code{list(metric='metric-name', value='metric-value')} with given
#' prediction and dtrain.
#' @param stratified a \code{boolean} indicating whether sampling of folds should be stratified
#' by the values of outcome labels.
#' @param stratified A \code{boolean} indicating whether sampling of folds should be stratified
#' by the values of outcome labels. For real-valued labels in regression objectives,
#' stratification will be done by discretizing the labels into up to 5 buckets beforehand.
#'
#' If passing "auto", will be set to `TRUE` if the objective in `params` is a classification
#' objective (from XGBoost's built-in objectives, doesn't apply to custom ones), and to
#' `FALSE` otherwise.
#'
#' This parameter is ignored when `data` has a `group` field - in such case, the splitting
#' will be based on whole groups (note that this might make the folds have different sizes).
#'
#' Value `TRUE` here is \bold{not} supported for custom objectives.
#' @param folds \code{list} provides a possibility to use a list of pre-defined CV folds
#' (each element must be a vector of test fold's indices). When folds are supplied,
#' the \code{nfold} and \code{stratified} parameters are ignored.
#'
#' If `data` has a `group` field and the objective requires this field, each fold (list element)
#' must additionally have two attributes (retrievable through \link{attributes}) named `group_test`
#' and `group_train`, which should hold the `group` to assign through \link{setinfo.xgb.DMatrix} to
#' the resulting DMatrices.
#' @param train_folds \code{list} list specifying which indicies to use for training. If \code{NULL}
#' (the default) all indices not specified in \code{folds} will be used for training.
#'
#' This is not supported when `data` has `group` field.
#' @param verbose \code{boolean}, print the statistics during the process
#' @param print_every_n Print each n-th iteration evaluation messages when \code{verbose>0}.
#' Default is 1 which means all messages are printed. This parameter is passed to the
@ -118,13 +139,14 @@
#' print(cv, verbose=TRUE)
#'
#' @export
xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing = NA,
xgb.cv <- function(params = list(), data, nrounds, nfold,
prediction = FALSE, showsd = TRUE, metrics = list(),
obj = NULL, feval = NULL, stratified = TRUE, folds = NULL, train_folds = NULL,
obj = NULL, feval = NULL, stratified = "auto", folds = NULL, train_folds = NULL,
verbose = TRUE, print_every_n = 1L,
early_stopping_rounds = NULL, maximize = NULL, callbacks = list(), ...) {
check.deprecation(...)
stopifnot(inherits(data, "xgb.DMatrix"))
if (inherits(data, "xgb.DMatrix") && .Call(XGCheckNullPtr_R, data)) {
stop("'data' is an invalid 'xgb.DMatrix' object. Must be constructed again.")
}
@ -137,16 +159,22 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
check.custom.obj()
check.custom.eval()
# Check the labels
if ((inherits(data, 'xgb.DMatrix') && !xgb.DMatrix.hasinfo(data, 'label')) ||
(!inherits(data, 'xgb.DMatrix') && is.null(label))) {
stop("Labels must be provided for CV either through xgb.DMatrix, or through 'label=' when 'data' is matrix")
} else if (inherits(data, 'xgb.DMatrix')) {
if (!is.null(label))
warning("xgb.cv: label will be ignored, since data is of type xgb.DMatrix")
cv_label <- getinfo(data, 'label')
} else {
cv_label <- label
if (stratified == "auto") {
if (is.character(params$objective)) {
stratified <- (
(params$objective %in% .CLASSIFICATION_OBJECTIVES())
&& !(params$objective %in% .RANKING_OBJECTIVES())
)
} else {
stratified <- FALSE
}
}
# Check the labels and groups
cv_label <- getinfo(data, "label")
cv_group <- getinfo(data, "group")
if (!is.null(train_folds) && NROW(cv_group)) {
stop("'train_folds' is not supported for DMatrix object with 'group' field.")
}
# CV folds
@ -157,7 +185,7 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
} else {
if (nfold <= 1)
stop("'nfold' must be > 1")
folds <- generate.cv.folds(nfold, nrow(data), stratified, cv_label, params)
folds <- generate.cv.folds(nfold, nrow(data), stratified, cv_label, cv_group, params)
}
# Callbacks
@ -195,20 +223,18 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
# create the booster-folds
# train_folds
dall <- xgb.get.DMatrix(
data = data,
label = label,
missing = missing,
weight = NULL,
nthread = params$nthread
)
dall <- data
bst_folds <- lapply(seq_along(folds), function(k) {
dtest <- xgb.slice.DMatrix(dall, folds[[k]])
dtest <- xgb.slice.DMatrix(dall, folds[[k]], allow_groups = TRUE)
# code originally contributed by @RolandASc on stackoverflow
if (is.null(train_folds))
dtrain <- xgb.slice.DMatrix(dall, unlist(folds[-k]))
dtrain <- xgb.slice.DMatrix(dall, unlist(folds[-k]), allow_groups = TRUE)
else
dtrain <- xgb.slice.DMatrix(dall, train_folds[[k]])
dtrain <- xgb.slice.DMatrix(dall, train_folds[[k]], allow_groups = TRUE)
if (!is.null(attributes(folds[[k]])$group_test)) {
setinfo(dtest, "group", attributes(folds[[k]])$group_test)
setinfo(dtrain, "group", attributes(folds[[k]])$group_train)
}
bst <- xgb.Booster(
params = params,
cachelist = list(dtrain, dtest),
@ -312,8 +338,8 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
#' @examples
#' data(agaricus.train, package='xgboost')
#' train <- agaricus.train
#' cv <- xgb.cv(data = train$data, label = train$label, nfold = 5, max_depth = 2,
#' eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic")
#' cv <- xgb.cv(data = xgb.DMatrix(train$data, label = train$label), nfold = 5, max_depth = 2,
#' eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic")
#' print(cv)
#' print(cv, verbose=TRUE)
#'

View File

@ -23,8 +23,8 @@ including the best iteration (when available).
\examples{
data(agaricus.train, package='xgboost')
train <- agaricus.train
cv <- xgb.cv(data = train$data, label = train$label, nfold = 5, max_depth = 2,
eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic")
cv <- xgb.cv(data = xgb.DMatrix(train$data, label = train$label), nfold = 5, max_depth = 2,
eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic")
print(cv)
print(cv, verbose=TRUE)

View File

@ -9,14 +9,12 @@ xgb.cv(
data,
nrounds,
nfold,
label = NULL,
missing = NA,
prediction = FALSE,
showsd = TRUE,
metrics = list(),
obj = NULL,
feval = NULL,
stratified = TRUE,
stratified = "auto",
folds = NULL,
train_folds = NULL,
verbose = TRUE,
@ -44,20 +42,23 @@ is a shorter summary:
}
See \code{\link{xgb.train}} for further details.
See also demo/ for walkthrough example in R.}
See also demo/ for walkthrough example in R.
\item{data}{takes an \code{xgb.DMatrix}, \code{matrix}, or \code{dgCMatrix} as the input.}
Note that, while \code{params} accepts a \code{seed} entry and will use such parameter for model training if
supplied, this seed is not used for creation of train-test splits, which instead rely on R's own RNG
system - thus, for reproducible results, one needs to call the \code{set.seed} function beforehand.}
\item{data}{An \code{xgb.DMatrix} object, with corresponding fields like \code{label} or bounds as required
for model training by the objective.
\if{html}{\out{<div class="sourceCode">}}\preformatted{ Note that only the basic `xgb.DMatrix` class is supported - variants such as `xgb.QuantileDMatrix`
or `xgb.ExternalDMatrix` are not supported here.
}\if{html}{\out{</div>}}}
\item{nrounds}{the max number of iterations}
\item{nfold}{the original dataset is randomly partitioned into \code{nfold} equal size subsamples.}
\item{label}{vector of response values. Should be provided only when data is an R-matrix.}
\item{missing}{is only used when input is a dense matrix. By default is set to NA, which means
that NA values should be considered as 'missing' by the algorithm.
Sometimes, 0 or other extreme value might be used to represent missing values.}
\item{prediction}{A logical value indicating whether to return the test fold predictions
from each CV model. This parameter engages the \code{\link{xgb.cb.cv.predict}} callback.}
@ -84,15 +85,35 @@ gradient with given prediction and dtrain.}
\code{list(metric='metric-name', value='metric-value')} with given
prediction and dtrain.}
\item{stratified}{a \code{boolean} indicating whether sampling of folds should be stratified
by the values of outcome labels.}
\item{stratified}{A \code{boolean} indicating whether sampling of folds should be stratified
by the values of outcome labels. For real-valued labels in regression objectives,
stratification will be done by discretizing the labels into up to 5 buckets beforehand.
\if{html}{\out{<div class="sourceCode">}}\preformatted{ If passing "auto", will be set to `TRUE` if the objective in `params` is a classification
objective (from XGBoost's built-in objectives, doesn't apply to custom ones), and to
`FALSE` otherwise.
This parameter is ignored when `data` has a `group` field - in such case, the splitting
will be based on whole groups (note that this might make the folds have different sizes).
Value `TRUE` here is \\bold\{not\} supported for custom objectives.
}\if{html}{\out{</div>}}}
\item{folds}{\code{list} provides a possibility to use a list of pre-defined CV folds
(each element must be a vector of test fold's indices). When folds are supplied,
the \code{nfold} and \code{stratified} parameters are ignored.}
the \code{nfold} and \code{stratified} parameters are ignored.
\if{html}{\out{<div class="sourceCode">}}\preformatted{ If `data` has a `group` field and the objective requires this field, each fold (list element)
must additionally have two attributes (retrievable through \link{attributes}) named `group_test`
and `group_train`, which should hold the `group` to assign through \link{setinfo.xgb.DMatrix} to
the resulting DMatrices.
}\if{html}{\out{</div>}}}
\item{train_folds}{\code{list} list specifying which indicies to use for training. If \code{NULL}
(the default) all indices not specified in \code{folds} will be used for training.}
(the default) all indices not specified in \code{folds} will be used for training.
\if{html}{\out{<div class="sourceCode">}}\preformatted{ This is not supported when `data` has `group` field.
}\if{html}{\out{</div>}}}
\item{verbose}{\code{boolean}, print the statistics during the process}
@ -142,7 +163,7 @@ such as saving also the models created during cross validation); or a list \code
will contain elements such as \code{best_iteration} when using the early stopping callback (\link{xgb.cb.early.stop}).
}
\description{
The cross validation function of xgboost
The cross validation function of xgboost.
}
\details{
The original sample is randomly partitioned into \code{nfold} equal size subsamples.

View File

@ -6,14 +6,18 @@
\title{Get a new DMatrix containing the specified rows of
original xgb.DMatrix object}
\usage{
xgb.slice.DMatrix(object, idxset)
xgb.slice.DMatrix(object, idxset, allow_groups = FALSE)
\method{[}{xgb.DMatrix}(object, idxset, colset = NULL)
}
\arguments{
\item{object}{Object of class "xgb.DMatrix"}
\item{object}{Object of class "xgb.DMatrix".}
\item{idxset}{a integer vector of indices of rows needed}
\item{idxset}{An integer vector of indices of rows needed (base-1 indexing).}
\item{allow_groups}{Whether to allow slicing an \code{xgb.DMatrix} with \code{group} (or
equivalently \code{qid}) field. Note that in such case, the result will not have
the groups anymore - they need to be set manually through \code{setinfo}.}
\item{colset}{currently not used (columns subsetting is not available)}
}

View File

@ -99,10 +99,12 @@ OBJECTS= \
$(PKGROOT)/src/context.o \
$(PKGROOT)/src/logging.o \
$(PKGROOT)/src/global_config.o \
$(PKGROOT)/src/collective/result.o \
$(PKGROOT)/src/collective/allgather.o \
$(PKGROOT)/src/collective/allreduce.o \
$(PKGROOT)/src/collective/broadcast.o \
$(PKGROOT)/src/collective/comm.o \
$(PKGROOT)/src/collective/comm_group.o \
$(PKGROOT)/src/collective/coll.o \
$(PKGROOT)/src/collective/communicator-inl.o \
$(PKGROOT)/src/collective/tracker.o \

View File

@ -99,10 +99,12 @@ OBJECTS= \
$(PKGROOT)/src/context.o \
$(PKGROOT)/src/logging.o \
$(PKGROOT)/src/global_config.o \
$(PKGROOT)/src/collective/result.o \
$(PKGROOT)/src/collective/allgather.o \
$(PKGROOT)/src/collective/allreduce.o \
$(PKGROOT)/src/collective/broadcast.o \
$(PKGROOT)/src/collective/comm.o \
$(PKGROOT)/src/collective/comm_group.o \
$(PKGROOT)/src/collective/coll.o \
$(PKGROOT)/src/collective/communicator-inl.o \
$(PKGROOT)/src/collective/tracker.o \

View File

@ -71,7 +71,7 @@ extern SEXP XGDMatrixGetDataAsCSR_R(SEXP);
extern SEXP XGDMatrixSaveBinary_R(SEXP, SEXP, SEXP);
extern SEXP XGDMatrixSetInfo_R(SEXP, SEXP, SEXP);
extern SEXP XGDMatrixSetStrFeatureInfo_R(SEXP, SEXP, SEXP);
extern SEXP XGDMatrixSliceDMatrix_R(SEXP, SEXP);
extern SEXP XGDMatrixSliceDMatrix_R(SEXP, SEXP, SEXP);
extern SEXP XGBSetGlobalConfig_R(SEXP);
extern SEXP XGBGetGlobalConfig_R(void);
extern SEXP XGBoosterFeatureScore_R(SEXP, SEXP);
@ -134,7 +134,7 @@ static const R_CallMethodDef CallEntries[] = {
{"XGDMatrixSaveBinary_R", (DL_FUNC) &XGDMatrixSaveBinary_R, 3},
{"XGDMatrixSetInfo_R", (DL_FUNC) &XGDMatrixSetInfo_R, 3},
{"XGDMatrixSetStrFeatureInfo_R", (DL_FUNC) &XGDMatrixSetStrFeatureInfo_R, 3},
{"XGDMatrixSliceDMatrix_R", (DL_FUNC) &XGDMatrixSliceDMatrix_R, 2},
{"XGDMatrixSliceDMatrix_R", (DL_FUNC) &XGDMatrixSliceDMatrix_R, 3},
{"XGBSetGlobalConfig_R", (DL_FUNC) &XGBSetGlobalConfig_R, 1},
{"XGBGetGlobalConfig_R", (DL_FUNC) &XGBGetGlobalConfig_R, 0},
{"XGBoosterFeatureScore_R", (DL_FUNC) &XGBoosterFeatureScore_R, 2},

View File

@ -512,7 +512,7 @@ XGB_DLL SEXP XGDMatrixCreateFromCSR_R(SEXP indptr, SEXP indices, SEXP data, SEXP
return ret;
}
XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset) {
XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset, SEXP allow_groups) {
SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
R_API_BEGIN();
R_xlen_t len = Rf_xlength(idxset);
@ -531,7 +531,7 @@ XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset) {
res_code = XGDMatrixSliceDMatrixEx(R_ExternalPtrAddr(handle),
BeginPtr(idxvec), len,
&res,
0);
Rf_asLogical(allow_groups));
}
CHECK_CALL(res_code);
R_SetExternalPtrAddr(ret, res);

View File

@ -112,9 +112,10 @@ XGB_DLL SEXP XGDMatrixCreateFromCSR_R(SEXP indptr, SEXP indices, SEXP data, SEXP
* \brief create a new dmatrix from sliced content of existing matrix
* \param handle instance of data matrix to be sliced
* \param idxset index set
* \param allow_groups Whether to allow slicing the DMatrix if it has a 'group' field
* \return a sliced new matrix
*/
XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset);
XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset, SEXP allow_groups);
/*!
* \brief load a data matrix into binary file

View File

@ -334,7 +334,7 @@ test_that("xgb.cv works", {
set.seed(11)
expect_output(
cv <- xgb.cv(
data = train$data, label = train$label, max_depth = 2, nfold = 5,
data = xgb.DMatrix(train$data, label = train$label), max_depth = 2, nfold = 5,
eta = 1., nthread = n_threads, nrounds = 2, objective = "binary:logistic",
eval_metric = "error", verbose = TRUE
),
@ -357,13 +357,13 @@ test_that("xgb.cv works with stratified folds", {
cv <- xgb.cv(
data = dtrain, max_depth = 2, nfold = 5,
eta = 1., nthread = n_threads, nrounds = 2, objective = "binary:logistic",
verbose = TRUE, stratified = FALSE
verbose = FALSE, stratified = FALSE
)
set.seed(314159)
cv2 <- xgb.cv(
data = dtrain, max_depth = 2, nfold = 5,
eta = 1., nthread = n_threads, nrounds = 2, objective = "binary:logistic",
verbose = TRUE, stratified = TRUE
verbose = FALSE, stratified = TRUE
)
# Stratified folds should result in a different evaluation logs
expect_true(all(cv$evaluation_log[, test_logloss_mean] != cv2$evaluation_log[, test_logloss_mean]))
@ -885,3 +885,57 @@ test_that("Seed in params override PRNG from R", {
)
)
})
test_that("xgb.cv works for AFT", {
X <- matrix(c(1, -1, -1, 1, 0, 1, 1, 0), nrow = 4, byrow = TRUE) # 4x2 matrix
dtrain <- xgb.DMatrix(X, nthread = n_threads)
params <- list(objective = 'survival:aft', learning_rate = 0.2, max_depth = 2L)
# data must have bounds
expect_error(
xgb.cv(
params = params,
data = dtrain,
nround = 5L,
nfold = 4L,
nthread = n_threads
)
)
setinfo(dtrain, 'label_lower_bound', c(2, 3, 0, 4))
setinfo(dtrain, 'label_upper_bound', c(2, Inf, 4, 5))
# automatic stratified splitting is turned off
expect_warning(
xgb.cv(
params = params, data = dtrain, nround = 5L, nfold = 4L,
nthread = n_threads, stratified = TRUE, verbose = FALSE
)
)
# this works without any issue
expect_no_warning(
xgb.cv(params = params, data = dtrain, nround = 5L, nfold = 4L, verbose = FALSE)
)
})
test_that("xgb.cv works for ranking", {
data(iris)
x <- iris[, -(4:5)]
y <- as.integer(iris$Petal.Width)
group <- rep(50, 3)
dm <- xgb.DMatrix(x, label = y, group = group)
res <- xgb.cv(
data = dm,
params = list(
objective = "rank:pairwise",
max_depth = 3
),
nrounds = 3,
nfold = 2,
verbose = FALSE,
stratified = FALSE
)
expect_equal(length(res$folds), 2L)
})

View File

@ -367,7 +367,7 @@ test_that("prediction in early-stopping xgb.cv works", {
expect_output(
cv <- xgb.cv(param, dtrain, nfold = 5, eta = 0.1, nrounds = 20,
early_stopping_rounds = 5, maximize = FALSE, stratified = FALSE,
prediction = TRUE, base_score = 0.5)
prediction = TRUE, base_score = 0.5, verbose = TRUE)
, "Stopping. Best iteration")
expect_false(is.null(cv$early_stop$best_iteration))
@ -387,7 +387,7 @@ test_that("prediction in xgb.cv for softprob works", {
lb <- as.numeric(iris$Species) - 1
set.seed(11)
expect_warning(
cv <- xgb.cv(data = as.matrix(iris[, -5]), label = lb, nfold = 4,
cv <- xgb.cv(data = xgb.DMatrix(as.matrix(iris[, -5]), label = lb), nfold = 4,
eta = 0.5, nrounds = 5, max_depth = 3, nthread = n_threads,
subsample = 0.8, gamma = 2, verbose = 0,
prediction = TRUE, objective = "multi:softprob", num_class = 3)

View File

@ -243,7 +243,7 @@ test_that("xgb.DMatrix: print", {
txt <- capture.output({
print(dtrain)
})
expect_equal(txt, "xgb.DMatrix dim: 6513 x 126 info: label weight base_margin colnames: yes")
expect_equal(txt, "xgb.DMatrix dim: 6513 x 126 info: base_margin, label, weight colnames: yes")
# DMatrix with just features
dtrain <- xgb.DMatrix(
@ -724,6 +724,44 @@ test_that("xgb.DMatrix: quantile cuts look correct", {
)
})
test_that("xgb.DMatrix: slicing keeps field indicators", {
data(mtcars)
x <- as.matrix(mtcars[, -1])
y <- mtcars[, 1]
dm <- xgb.DMatrix(
data = x,
label_lower_bound = -y,
label_upper_bound = y,
nthread = 1
)
idx_take <- seq(1, 5)
dm_slice <- xgb.slice.DMatrix(dm, idx_take)
expect_true(xgb.DMatrix.hasinfo(dm_slice, "label_lower_bound"))
expect_true(xgb.DMatrix.hasinfo(dm_slice, "label_upper_bound"))
expect_false(xgb.DMatrix.hasinfo(dm_slice, "label"))
expect_equal(getinfo(dm_slice, "label_lower_bound"), -y[idx_take], tolerance = 1e-6)
expect_equal(getinfo(dm_slice, "label_upper_bound"), y[idx_take], tolerance = 1e-6)
})
test_that("xgb.DMatrix: can slice with groups", {
data(iris)
x <- as.matrix(iris[, -5])
set.seed(123)
y <- sample(3, size = nrow(x), replace = TRUE)
group <- c(50, 50, 50)
dm <- xgb.DMatrix(x, label = y, group = group, nthread = 1)
idx_take <- seq(1, 50)
dm_slice <- xgb.slice.DMatrix(dm, idx_take, allow_groups = TRUE)
expect_true(xgb.DMatrix.hasinfo(dm_slice, "label"))
expect_false(xgb.DMatrix.hasinfo(dm_slice, "group"))
expect_false(xgb.DMatrix.hasinfo(dm_slice, "qid"))
expect_null(getinfo(dm_slice, "group"))
expect_equal(getinfo(dm_slice, "label"), y[idx_take], tolerance = 1e-6)
})
test_that("xgb.DMatrix: can read CSV", {
txt <- paste(
"1,2,3",

View File

@ -40,7 +40,7 @@ def main(client):
# you can pass output directly into `predict` too.
prediction = dxgb.predict(client, bst, dtrain)
print("Evaluation history:", history)
return prediction
print("Error:", da.sqrt((prediction - y) ** 2).mean().compute())
if __name__ == "__main__":

View File

@ -144,6 +144,14 @@ which provides higher flexibility. For example:
ctest --verbose
If you need to debug errors on Windows using the debugger from VS, you can append the gtest flags in `test_main.cc`:
.. code-block::
::testing::GTEST_FLAG(filter) = "Suite.Test";
::testing::GTEST_FLAG(repeat) = 10;
***********************************************
Sanitizers: Detect memory errors and data races
***********************************************

View File

@ -28,7 +28,7 @@ Contents
Python Package <python/index>
R Package <R-package/index>
JVM Package <jvm/index>
Ruby Package <https://github.com/ankane/xgb>
Ruby Package <https://github.com/ankane/xgboost-ruby>
Swift Package <https://github.com/kongzii/SwiftXGBoost>
Julia Package <julia>
C Package <c>

View File

@ -52,7 +52,7 @@ Notice that the samples are sorted based on their query index in a non-decreasin
X, y = make_classification(random_state=seed)
rng = np.random.default_rng(seed)
n_query_groups = 3
qid = rng.integers(0, 3, size=X.shape[0])
qid = rng.integers(0, n_query_groups, size=X.shape[0])
# Sort the inputs based on query index
sorted_idx = np.argsort(qid)
@ -65,14 +65,14 @@ The simplest way to train a ranking model is by using the scikit-learn estimator
.. code-block:: python
ranker = xgb.XGBRanker(tree_method="hist", lambdarank_num_pair_per_sample=8, objective="rank:ndcg", lambdarank_pair_method="topk")
ranker.fit(X, y, qid=qid)
ranker.fit(X, y, qid=qid[sorted_idx])
Please note that, as of writing, there's no learning-to-rank interface in scikit-learn. As a result, the :py:class:`xgboost.XGBRanker` class does not fully conform the scikit-learn estimator guideline and can not be directly used with some of its utility functions. For instances, the ``auc_score`` and ``ndcg_score`` in scikit-learn don't consider query group information nor the pairwise loss. Most of the metrics are implemented as part of XGBoost, but to use scikit-learn utilities like :py:func:`sklearn.model_selection.cross_validation`, we need to make some adjustments in order to pass the ``qid`` as an additional parameter for :py:meth:`xgboost.XGBRanker.score`. Given a data frame ``X`` (either pandas or cuDF), add the column ``qid`` as follows:
.. code-block:: python
df = pd.DataFrame(X, columns=[str(i) for i in range(X.shape[1])])
df["qid"] = qid
df["qid"] = qid[sorted_idx]
ranker.fit(df, y) # No need to pass qid as a separate argument
from sklearn.model_selection import StratifiedGroupKFold, cross_val_score

View File

@ -1,5 +1,5 @@
/**
* Copyright 2015~2023 by XGBoost Contributors
* Copyright 2015-2024, XGBoost Contributors
* \file c_api.h
* \author Tianqi Chen
* \brief C API of XGBoost, used for interfacing to other languages.
@ -639,21 +639,14 @@ XGB_DLL int XGDMatrixSetInfoFromInterface(DMatrixHandle handle,
* \param len length of array
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle,
const char *field,
const float *array,
XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle, const char *field, const float *array,
bst_ulong len);
/*!
* \brief set uint32 vector to a content in info
* \param handle a instance of data matrix
* \param field field name
* \param array pointer to unsigned int vector
* \param len length of array
* \return 0 when success, -1 when failure happens
/**
* @deprecated since 2.1.0
*
* Use @ref XGDMatrixSetInfoFromInterface instead.
*/
XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle,
const char *field,
const unsigned *array,
XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle, const char *field, const unsigned *array,
bst_ulong len);
/*!
@ -725,42 +718,13 @@ XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field,
bst_ulong *size,
const char ***out_features);
/*!
* \brief Set meta info from dense matrix. Valid field names are:
/**
* @deprecated since 2.1.0
*
* - label
* - weight
* - base_margin
* - group
* - label_lower_bound
* - label_upper_bound
* - feature_weights
*
* \param handle An instance of data matrix
* \param field Field name
* \param data Pointer to consecutive memory storing data.
* \param size Size of the data, this is relative to size of type. (Meaning NOT number
* of bytes.)
* \param type Indicator of data type. This is defined in xgboost::DataType enum class.
* - float = 1
* - double = 2
* - uint32_t = 3
* - uint64_t = 4
* \return 0 when success, -1 when failure happens
* Use @ref XGDMatrixSetInfoFromInterface instead.
*/
XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field,
void const *data, bst_ulong size, int type);
/*!
* \brief (deprecated) Use XGDMatrixSetUIntInfo instead. Set group of the training matrix
* \param handle a instance of data matrix
* \param group pointer to group size
* \param len length of array
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle,
const unsigned *group,
bst_ulong len);
XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field, void const *data,
bst_ulong size, int type);
/*!
* \brief get float info vector from matrix.
@ -1591,7 +1555,7 @@ XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle);
/**
* @brief Get the arguments needed for running workers. This should be called after
* XGTrackerRun() and XGTrackerWait()
* XGTrackerRun().
*
* @param handle The handle to the tracker.
* @param args The arguments returned as a JSON document.
@ -1601,16 +1565,19 @@ XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle);
XGB_DLL int XGTrackerWorkerArgs(TrackerHandle handle, char const **args);
/**
* @brief Run the tracker.
* @brief Start the tracker. The tracker runs in the background and this function returns
* once the tracker is started.
*
* @param handle The handle to the tracker.
* @param config Unused at the moment, preserved for the future.
*
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGTrackerRun(TrackerHandle handle);
XGB_DLL int XGTrackerRun(TrackerHandle handle, char const *config);
/**
* @brief Wait for the tracker to finish, should be called after XGTrackerRun().
* @brief Wait for the tracker to finish, should be called after XGTrackerRun(). This
* function will block until the tracker task is finished or timeout is reached.
*
* @param handle The handle to the tracker.
* @param config JSON encoded configuration. No argument is required yet, preserved for
@ -1618,11 +1585,12 @@ XGB_DLL int XGTrackerRun(TrackerHandle handle);
*
* @return 0 for success, -1 for failure.
*/
XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config);
XGB_DLL int XGTrackerWaitFor(TrackerHandle handle, char const *config);
/**
* @brief Free a tracker instance. XGTrackerWait() is called internally. If the tracker
* cannot close properly, manual interruption is required.
* @brief Free a tracker instance. This should be called after XGTrackerWaitFor(). If the
* tracker is not properly waited, this function will shutdown all connections with
* the tracker, potentially leading to undefined behavior.
*
* @param handle The handle to the tracker.
*

View File

@ -3,13 +3,11 @@
*/
#pragma once
#include <xgboost/logging.h>
#include <memory> // for unique_ptr
#include <sstream> // for stringstream
#include <stack> // for stack
#include <string> // for string
#include <utility> // for move
#include <cstdint> // for int32_t
#include <memory> // for unique_ptr
#include <string> // for string
#include <system_error> // for error_code
#include <utility> // for move
namespace xgboost::collective {
namespace detail {
@ -48,48 +46,19 @@ struct ResultImpl {
return cur_eq;
}
[[nodiscard]] std::string Report() {
std::stringstream ss;
ss << "\n- " << this->message;
if (this->errc != std::error_code{}) {
ss << " system error:" << this->errc.message();
}
[[nodiscard]] std::string Report() const;
[[nodiscard]] std::error_code Code() const;
auto ptr = prev.get();
while (ptr) {
ss << "\n- ";
ss << ptr->message;
if (ptr->errc != std::error_code{}) {
ss << " " << ptr->errc.message();
}
ptr = ptr->prev.get();
}
return ss.str();
}
[[nodiscard]] auto Code() const {
// Find the root error.
std::stack<ResultImpl const*> stack;
auto ptr = this;
while (ptr) {
stack.push(ptr);
if (ptr->prev) {
ptr = ptr->prev.get();
} else {
break;
}
}
while (!stack.empty()) {
auto frame = stack.top();
stack.pop();
if (frame->errc != std::error_code{}) {
return frame->errc;
}
}
return std::error_code{};
}
void Concat(std::unique_ptr<ResultImpl> rhs);
};
#if (!defined(__GNUC__) && !defined(__clang__)) || defined(__MINGW32__)
#define __builtin_FILE() nullptr
#define __builtin_LINE() (-1)
std::string MakeMsg(std::string&& msg, char const*, std::int32_t);
#else
std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line);
#endif
} // namespace detail
/**
@ -131,8 +100,21 @@ struct Result {
}
return *impl_ == *that.impl_;
}
friend Result operator+(Result&& lhs, Result&& rhs);
};
[[nodiscard]] inline Result operator+(Result&& lhs, Result&& rhs) {
if (lhs.OK()) {
return std::forward<Result>(rhs);
}
if (rhs.OK()) {
return std::forward<Result>(lhs);
}
lhs.impl_->Concat(std::move(rhs.impl_));
return std::forward<Result>(lhs);
}
/**
* @brief Return success.
*/
@ -140,38 +122,43 @@ struct Result {
/**
* @brief Return failure.
*/
[[nodiscard]] inline auto Fail(std::string msg) { return Result{std::move(msg)}; }
[[nodiscard]] inline auto Fail(std::string msg, char const* file = __builtin_FILE(),
std::int32_t line = __builtin_LINE()) {
return Result{detail::MakeMsg(std::move(msg), file, line)};
}
/**
* @brief Return failure with `errno`.
*/
[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc) {
return Result{std::move(msg), std::move(errc)};
[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc,
char const* file = __builtin_FILE(),
std::int32_t line = __builtin_LINE()) {
return Result{detail::MakeMsg(std::move(msg), file, line), std::move(errc)};
}
/**
* @brief Return failure with a previous error.
*/
[[nodiscard]] inline auto Fail(std::string msg, Result&& prev) {
return Result{std::move(msg), std::forward<Result>(prev)};
[[nodiscard]] inline auto Fail(std::string msg, Result&& prev, char const* file = __builtin_FILE(),
std::int32_t line = __builtin_LINE()) {
return Result{detail::MakeMsg(std::move(msg), file, line), std::forward<Result>(prev)};
}
/**
* @brief Return failure with a previous error and a new `errno`.
*/
[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc, Result&& prev) {
return Result{std::move(msg), std::move(errc), std::forward<Result>(prev)};
[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc, Result&& prev,
char const* file = __builtin_FILE(),
std::int32_t line = __builtin_LINE()) {
return Result{detail::MakeMsg(std::move(msg), file, line), std::move(errc),
std::forward<Result>(prev)};
}
// We don't have monad, a simple helper would do.
template <typename Fn>
[[nodiscard]] Result operator<<(Result&& r, Fn&& fn) {
[[nodiscard]] std::enable_if_t<std::is_invocable_v<Fn>, Result> operator<<(Result&& r, Fn&& fn) {
if (!r.OK()) {
return std::forward<Result>(r);
}
return fn();
}
inline void SafeColl(Result const& rc) {
if (!rc.OK()) {
LOG(FATAL) << rc.Report();
}
}
void SafeColl(Result const& rc);
} // namespace xgboost::collective

View File

@ -1,5 +1,5 @@
/**
* Copyright (c) 2022-2023, XGBoost Contributors
* Copyright (c) 2022-2024, XGBoost Contributors
*/
#pragma once
@ -12,7 +12,6 @@
#include <cstddef> // std::size_t
#include <cstdint> // std::int32_t, std::uint16_t
#include <cstring> // memset
#include <limits> // std::numeric_limits
#include <string> // std::string
#include <system_error> // std::error_code, std::system_category
#include <utility> // std::swap
@ -125,6 +124,21 @@ inline std::int32_t CloseSocket(SocketT fd) {
#endif
}
inline std::int32_t ShutdownSocket(SocketT fd) {
#if defined(_WIN32)
auto rc = shutdown(fd, SD_BOTH);
if (rc != 0 && LastError() == WSANOTINITIALISED) {
return 0;
}
#else
auto rc = shutdown(fd, SHUT_RDWR);
if (rc != 0 && LastError() == ENOTCONN) {
return 0;
}
#endif
return rc;
}
inline bool ErrorWouldBlock(std::int32_t errsv) noexcept(true) {
#ifdef _WIN32
return errsv == WSAEWOULDBLOCK;
@ -468,19 +482,30 @@ class TCPSocket {
*addr = SockAddress{SockAddrV6{caddr}};
*out = TCPSocket{newfd};
}
// On MacOS, this is automatically set to async socket if the parent socket is async
// We make sure all socket are blocking by default.
//
// On Windows, a closed socket is returned during shutdown. We guard against it when
// setting non-blocking.
if (!out->IsClosed()) {
return out->NonBlocking(false);
}
return Success();
}
~TCPSocket() {
if (!IsClosed()) {
Close();
auto rc = this->Close();
if (!rc.OK()) {
LOG(WARNING) << rc.Report();
}
}
}
TCPSocket(TCPSocket const &that) = delete;
TCPSocket(TCPSocket &&that) noexcept(true) { std::swap(this->handle_, that.handle_); }
TCPSocket &operator=(TCPSocket const &that) = delete;
TCPSocket &operator=(TCPSocket &&that) {
TCPSocket &operator=(TCPSocket &&that) noexcept(true) {
std::swap(this->handle_, that.handle_);
return *this;
}
@ -489,36 +514,49 @@ class TCPSocket {
*/
[[nodiscard]] HandleT const &Handle() const { return handle_; }
/**
* \brief Listen to incoming requests. Should be called after bind.
* @brief Listen to incoming requests. Should be called after bind.
*/
void Listen(std::int32_t backlog = 16) { xgboost_CHECK_SYS_CALL(listen(handle_, backlog), 0); }
[[nodiscard]] Result Listen(std::int32_t backlog = 16) {
if (listen(handle_, backlog) != 0) {
return system::FailWithCode("Failed to listen.");
}
return Success();
}
/**
* \brief Bind socket to INADDR_ANY, return the port selected by the OS.
* @brief Bind socket to INADDR_ANY, return the port selected by the OS.
*/
[[nodiscard]] in_port_t BindHost() {
[[nodiscard]] Result BindHost(std::int32_t* p_out) {
// Use int32 instead of in_port_t for consistency. We take port as parameter from
// users using other languages, the port is usually stored and passed around as int.
if (Domain() == SockDomain::kV6) {
auto addr = SockAddrV6::InaddrAny();
auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
xgboost_CHECK_SYS_CALL(
bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)), 0);
if (bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)) != 0) {
return system::FailWithCode("bind failed.");
}
sockaddr_in6 res_addr;
socklen_t addrlen = sizeof(res_addr);
xgboost_CHECK_SYS_CALL(
getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen), 0);
return ntohs(res_addr.sin6_port);
if (getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen) != 0) {
return system::FailWithCode("getsockname failed.");
}
*p_out = ntohs(res_addr.sin6_port);
} else {
auto addr = SockAddrV4::InaddrAny();
auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
xgboost_CHECK_SYS_CALL(
bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)), 0);
if (bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)) != 0) {
return system::FailWithCode("bind failed.");
}
sockaddr_in res_addr;
socklen_t addrlen = sizeof(res_addr);
xgboost_CHECK_SYS_CALL(
getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen), 0);
return ntohs(res_addr.sin_port);
if (getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen) != 0) {
return system::FailWithCode("getsockname failed.");
}
*p_out = ntohs(res_addr.sin_port);
}
return Success();
}
[[nodiscard]] auto Port() const {
@ -631,26 +669,49 @@ class TCPSocket {
*/
std::size_t Send(StringView str);
/**
* \brief Receive string, format is matched with the Python socket wrapper in RABIT.
* @brief Receive string, format is matched with the Python socket wrapper in RABIT.
*/
std::size_t Recv(std::string *p_str);
[[nodiscard]] Result Recv(std::string *p_str);
/**
* \brief Close the socket, called automatically in destructor if the socket is not closed.
* @brief Close the socket, called automatically in destructor if the socket is not closed.
*/
void Close() {
[[nodiscard]] Result Close() {
if (InvalidSocket() != handle_) {
#if defined(_WIN32)
auto rc = system::CloseSocket(handle_);
#if defined(_WIN32)
// it's possible that we close TCP sockets after finalizing WSA due to detached thread.
if (rc != 0 && system::LastError() != WSANOTINITIALISED) {
system::ThrowAtError("close", rc);
return system::FailWithCode("Failed to close the socket.");
}
#else
xgboost_CHECK_SYS_CALL(system::CloseSocket(handle_), 0);
if (rc != 0) {
return system::FailWithCode("Failed to close the socket.");
}
#endif
handle_ = InvalidSocket();
}
return Success();
}
/**
* @brief Call shutdown on the socket.
*/
[[nodiscard]] Result Shutdown() {
if (this->IsClosed()) {
return Success();
}
auto rc = system::ShutdownSocket(this->Handle());
#if defined(_WIN32)
// Windows cannot shutdown a socket if it's not connected.
if (rc == -1 && system::LastError() == WSAENOTCONN) {
return Success();
}
#endif
if (rc != 0) {
return system::FailWithCode("Failed to shutdown socket.");
}
return Success();
}
/**
* \brief Create a TCP socket on specified domain.
*/

View File

@ -19,7 +19,6 @@
#include <algorithm>
#include <limits>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>
@ -137,14 +136,6 @@ class MetaInfo {
* \param fo The output stream.
*/
void SaveBinary(dmlc::Stream* fo) const;
/*!
* \brief Set information in the meta info.
* \param key The key of the information.
* \param dptr The data pointer of the source array.
* \param dtype The type of the source data.
* \param num Number of elements in the source array.
*/
void SetInfo(Context const& ctx, const char* key, const void* dptr, DataType dtype, size_t num);
/*!
* \brief Set information in the meta info with array interface.
* \param key The key of the information.
@ -517,10 +508,6 @@ class DMatrix {
DMatrix() = default;
/*! \brief meta information of the dataset */
virtual MetaInfo& Info() = 0;
virtual void SetInfo(const char* key, const void* dptr, DataType dtype, size_t num) {
auto const& ctx = *this->Ctx();
this->Info().SetInfo(ctx, key, dptr, dtype, num);
}
virtual void SetInfo(const char* key, std::string const& interface_str) {
auto const& ctx = *this->Ctx();
this->Info().SetInfo(ctx, key, StringView{interface_str});

View File

@ -190,13 +190,14 @@ constexpr auto ArrToTuple(T (&arr)[N]) {
// uint division optimization inspired by the CIndexer in cupy. Division operation is
// slow on both CPU and GPU, especially 64 bit integer. So here we first try to avoid 64
// bit when the index is smaller, then try to avoid division when it's exp of 2.
template <typename I, int32_t D>
template <typename I, std::int32_t D>
LINALG_HD auto UnravelImpl(I idx, common::Span<size_t const, D> shape) {
size_t index[D]{0};
std::size_t index[D]{0};
static_assert(std::is_signed<decltype(D)>::value,
"Don't change the type without changing the for loop.");
auto const sptr = shape.data();
for (int32_t dim = D; --dim > 0;) {
auto s = static_cast<std::remove_const_t<std::remove_reference_t<I>>>(shape[dim]);
auto s = static_cast<std::remove_const_t<std::remove_reference_t<I>>>(sptr[dim]);
if (s & (s - 1)) {
auto t = idx / s;
index[dim] = idx - t * s;
@ -745,6 +746,14 @@ auto ArrayInterfaceStr(TensorView<T, D> const &t) {
return str;
}
template <typename T>
auto Make1dInterface(T const *vec, std::size_t len) {
Context ctx;
auto t = linalg::MakeTensorView(&ctx, common::Span{vec, len}, len);
auto str = linalg::ArrayInterfaceStr(t);
return str;
}
/**
* \brief A tensor storage. To use it for other functionality like slicing one needs to
* obtain a view first. This way we can use it on both host and device.

View File

@ -30,9 +30,8 @@
#define XGBOOST_SPAN_H_
#include <xgboost/base.h>
#include <xgboost/logging.h>
#include <cinttypes> // size_t
#include <cstddef> // size_t
#include <cstdio>
#include <iterator>
#include <limits> // numeric_limits
@ -75,8 +74,7 @@
#endif // defined(_MSC_VER) && _MSC_VER < 1910
namespace xgboost {
namespace common {
namespace xgboost::common {
#if defined(__CUDA_ARCH__)
// Usual logging facility is not available inside device code.
@ -744,8 +742,8 @@ class IterSpan {
return it_ + size();
}
};
} // namespace common
} // namespace xgboost
} // namespace xgboost::common
#if defined(_MSC_VER) &&_MSC_VER < 1910
#undef constexpr

View File

@ -33,7 +33,7 @@
<project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
<flink.version>1.18.0</flink.version>
<flink.version>1.19.0</flink.version>
<junit.version>4.13.2</junit.version>
<spark.version>3.4.1</spark.version>
<spark.version.gpu>3.4.1</spark.version.gpu>
@ -46,9 +46,9 @@
<cudf.version>23.12.1</cudf.version>
<spark.rapids.version>23.12.1</spark.rapids.version>
<cudf.classifier>cuda12</cudf.classifier>
<scalatest.version>3.2.18</scalatest.version>
<scala-collection-compat.version>2.12.0</scala-collection-compat.version>
<use.hip>OFF</use.hip>
<scalatest.version>3.2.17</scalatest.version>
<scala-collection-compat.version>2.11.0</scala-collection-compat.version>
<!-- SPARK-36796 for JDK-17 test-->
<extraJavaTestArgs>
@ -124,7 +124,7 @@
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<version>3.3.0</version>
<version>3.4.0</version>
<executions>
<execution>
<id>empty-javadoc-jar</id>
@ -153,7 +153,7 @@
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-gpg-plugin</artifactId>
<version>3.1.0</version>
<version>3.2.3</version>
<executions>
<execution>
<id>sign-artifacts</id>
@ -167,7 +167,7 @@
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-source-plugin</artifactId>
<version>3.3.0</version>
<version>3.3.1</version>
<executions>
<execution>
<id>attach-sources</id>
@ -205,7 +205,7 @@
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-assembly-plugin</artifactId>
<version>3.6.0</version>
<version>3.7.1</version>
<configuration>
<descriptorRefs>
<descriptorRef>jar-with-dependencies</descriptorRef>
@ -446,7 +446,7 @@
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<version>3.2.2</version>
<version>3.2.5</version>
<configuration>
<skipTests>false</skipTests>
<useSystemClassLoader>false</useSystemClassLoader>
@ -488,12 +488,12 @@
<dependency>
<groupId>com.esotericsoftware</groupId>
<artifactId>kryo</artifactId>
<version>5.5.0</version>
<version>5.6.0</version>
</dependency>
<dependency>
<groupId>commons-logging</groupId>
<artifactId>commons-logging</artifactId>
<version>1.3.0</version>
<version>1.3.1</version>
</dependency>
<dependency>
<groupId>org.scalatest</groupId>

View File

@ -72,7 +72,7 @@
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-javadoc-plugin</artifactId>
<version>3.6.2</version>
<version>3.6.3</version>
<configuration>
<show>protected</show>
<nohelp>true</nohelp>
@ -88,7 +88,7 @@
<plugin>
<artifactId>exec-maven-plugin</artifactId>
<groupId>org.codehaus.mojo</groupId>
<version>3.1.0</version>
<version>3.2.0</version>
<executions>
<execution>
<id>native</id>
@ -115,7 +115,7 @@
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<version>3.3.0</version>
<version>3.4.0</version>
<executions>
<execution>
<goals>

View File

@ -22,7 +22,7 @@ pom_template = """
<scala.version>{scala_version}</scala.version>
<scalatest.version>3.2.15</scalatest.version>
<scala.binary.version>{scala_binary_version}</scala.binary.version>
<kryo.version>5.5.0</kryo.version>
<kryo.version>5.6.0</kryo.version>
</properties>
<dependencies>

View File

@ -60,7 +60,7 @@
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-javadoc-plugin</artifactId>
<version>3.6.2</version>
<version>3.6.3</version>
<configuration>
<show>protected</show>
<nohelp>true</nohelp>
@ -76,7 +76,7 @@
<plugin>
<artifactId>exec-maven-plugin</artifactId>
<groupId>org.codehaus.mojo</groupId>
<version>3.1.0</version>
<version>3.2.0</version>
<executions>
<execution>
<id>native</id>
@ -99,7 +99,7 @@
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<version>3.3.0</version>
<version>3.4.0</version>
<executions>
<execution>
<goals>

View File

@ -408,7 +408,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetFloatI
jfloat* array = jenv->GetFloatArrayElements(jarray, NULL);
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray);
int ret = XGDMatrixSetFloatInfo(handle, field, (float const *)array, len);
auto str = xgboost::linalg::Make1dInterface(array, len);
int ret = XGDMatrixSetInfoFromInterface(handle, field, str.c_str());
JVM_CHECK_CALL(ret);
//release
if (field) jenv->ReleaseStringUTFChars(jfield, field);
@ -427,7 +428,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetUIntIn
const char* field = jenv->GetStringUTFChars(jfield, 0);
jint* array = jenv->GetIntArrayElements(jarray, NULL);
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray);
int ret = XGDMatrixSetUIntInfo(handle, (char const *)field, (unsigned int const *)array, len);
auto str = xgboost::linalg::Make1dInterface(array, len);
int ret = XGDMatrixSetInfoFromInterface(handle, field, str.c_str());
JVM_CHECK_CALL(ret);
//release
if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field);
@ -730,8 +732,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredictFr
if (jmargin) {
margin = jenv->GetFloatArrayElements(jmargin, nullptr);
JVM_CHECK_CALL(XGProxyDMatrixCreate(&proxy));
JVM_CHECK_CALL(
XGDMatrixSetFloatInfo(proxy, "base_margin", margin, jenv->GetArrayLength(jmargin)));
auto str = xgboost::linalg::Make1dInterface(margin, jenv->GetArrayLength(jmargin));
JVM_CHECK_CALL(XGDMatrixSetInfoFromInterface(proxy, "base_margin", str.c_str()));
}
bst_ulong const *out_shape;

View File

@ -89,19 +89,15 @@ Coll *FederatedColl::MakeCUDAVar() {
[[nodiscard]] Result FederatedColl::Broadcast(Comm const &comm, common::Span<std::int8_t> data,
std::int32_t root) {
if (comm.Rank() == root) {
return BroadcastImpl(comm, &this->sequence_number_, data, root);
} else {
return BroadcastImpl(comm, &this->sequence_number_, data, root);
}
return BroadcastImpl(comm, &this->sequence_number_, data, root);
}
[[nodiscard]] Result FederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data,
std::int64_t size) {
[[nodiscard]] Result FederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data) {
using namespace federated; // NOLINT
auto fed = dynamic_cast<FederatedComm const *>(&comm);
CHECK(fed);
auto stub = fed->Handle();
auto size = data.size_bytes() / comm.World();
auto offset = comm.Rank() * size;
auto segment = data.subspan(offset, size);

View File

@ -53,8 +53,7 @@ Coll *FederatedColl::MakeCUDAVar() {
};
}
[[nodiscard]] Result CUDAFederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data,
std::int64_t size) {
[[nodiscard]] Result CUDAFederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data) {
auto cufed = dynamic_cast<CUDAFederatedComm const *>(&comm);
CHECK(cufed);
std::vector<std::int8_t> h_data(data.size());
@ -63,7 +62,7 @@ Coll *FederatedColl::MakeCUDAVar() {
return GetCUDAResult(
cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost));
} << [&] {
return p_impl_->Allgather(comm, common::Span{h_data.data(), h_data.size()}, size);
return p_impl_->Allgather(comm, common::Span{h_data.data(), h_data.size()});
} << [&] {
return GetCUDAResult(cudaMemcpyAsync(data.data(), h_data.data(), data.size(),
cudaMemcpyHostToDevice, cufed->Stream()));

View File

@ -1,5 +1,5 @@
/**
* Copyright 2023, XGBoost contributors
* Copyright 2023-2024, XGBoost contributors
*/
#include "../../src/collective/comm.h" // for Comm, Coll
#include "federated_coll.h" // for FederatedColl
@ -16,8 +16,7 @@ class CUDAFederatedColl : public Coll {
ArrayInterfaceHandler::Type type, Op op) override;
[[nodiscard]] Result Broadcast(Comm const &comm, common::Span<std::int8_t> data,
std::int32_t root) override;
[[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data,
std::int64_t size) override;
[[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data) override;
[[nodiscard]] Result AllgatherV(Comm const &comm, common::Span<std::int8_t const> data,
common::Span<std::int64_t const> sizes,
common::Span<std::int64_t> recv_segments,

View File

@ -1,12 +1,9 @@
/**
* Copyright 2023, XGBoost contributors
* Copyright 2023-2024, XGBoost contributors
*/
#pragma once
#include "../../src/collective/coll.h" // for Coll
#include "../../src/collective/comm.h" // for Comm
#include "../../src/common/io.h" // for ReadAll
#include "../../src/common/json_utils.h" // for OptionalArg
#include "xgboost/json.h" // for Json
namespace xgboost::collective {
class FederatedColl : public Coll {
@ -20,8 +17,7 @@ class FederatedColl : public Coll {
ArrayInterfaceHandler::Type type, Op op) override;
[[nodiscard]] Result Broadcast(Comm const &comm, common::Span<std::int8_t> data,
std::int32_t root) override;
[[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data,
std::int64_t) override;
[[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data) override;
[[nodiscard]] Result AllgatherV(Comm const &comm, common::Span<std::int8_t const> data,
common::Span<std::int64_t const> sizes,
common::Span<std::int64_t> recv_segments,

View File

@ -1,5 +1,5 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#pragma once
@ -9,7 +9,6 @@
#include "../../src/common/device_helpers.cuh" // for CUDAStreamView
#include "federated_comm.h" // for FederatedComm
#include "xgboost/context.h" // for Context
#include "xgboost/logging.h"
namespace xgboost::collective {
class CUDAFederatedComm : public FederatedComm {

View File

@ -1,5 +1,5 @@
/**
* Copyright 2023, XGBoost contributors
* Copyright 2023-2024, XGBoost contributors
*/
#pragma once
@ -11,7 +11,6 @@
#include <string> // for string
#include "../../src/collective/comm.h" // for HostComm
#include "../../src/common/json_utils.h" // for OptionalArg
#include "xgboost/json.h"
namespace xgboost::collective {
@ -51,6 +50,10 @@ class FederatedComm : public HostComm {
std::int32_t rank) {
this->Init(host, port, world, rank, {}, {}, {});
}
[[nodiscard]] Result Shutdown() final {
this->ResetState();
return Success();
}
~FederatedComm() override { stub_.reset(); }
[[nodiscard]] std::shared_ptr<Channel> Chan(std::int32_t) const override {
@ -65,5 +68,13 @@ class FederatedComm : public HostComm {
[[nodiscard]] federated::Federated::Stub* Handle() const { return stub_.get(); }
[[nodiscard]] Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const override;
/**
* @brief Get a string ID for the current process.
*/
[[nodiscard]] Result ProcessorName(std::string* out) const final {
auto rank = this->Rank();
*out = "rank:" + std::to_string(rank);
return Success();
};
};
} // namespace xgboost::collective

View File

@ -1,22 +1,18 @@
/**
* Copyright 2022-2023, XGBoost contributors
* Copyright 2022-2024, XGBoost contributors
*/
#pragma once
#include <federated.old.grpc.pb.h>
#include <cstdint> // for int32_t
#include <future> // for future
#include "../../src/collective/in_memory_handler.h"
#include "../../src/collective/tracker.h" // for Tracker
#include "xgboost/collective/result.h" // for Result
namespace xgboost::federated {
class FederatedService final : public Federated::Service {
public:
explicit FederatedService(std::int32_t world_size)
: handler_{static_cast<std::size_t>(world_size)} {}
explicit FederatedService(std::int32_t world_size) : handler_{world_size} {}
grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
AllgatherReply* reply) override;

View File

@ -125,14 +125,14 @@ Result FederatedTracker::Shutdown() {
[[nodiscard]] Json FederatedTracker::WorkerArgs() const {
auto rc = this->WaitUntilReady();
CHECK(rc.OK()) << rc.Report();
SafeColl(rc);
std::string host;
rc = GetHostAddress(&host);
CHECK(rc.OK());
Json args{Object{}};
args["DMLC_TRACKER_URI"] = String{host};
args["DMLC_TRACKER_PORT"] = this->Port();
args["dmlc_tracker_uri"] = String{host};
args["dmlc_tracker_port"] = this->Port();
return args;
}
} // namespace xgboost::collective

View File

@ -17,8 +17,7 @@ namespace xgboost::collective {
namespace federated {
class FederatedService final : public Federated::Service {
public:
explicit FederatedService(std::int32_t world_size)
: handler_{static_cast<std::size_t>(world_size)} {}
explicit FederatedService(std::int32_t world_size) : handler_{world_size} {}
grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
AllgatherReply* reply) override;

View File

@ -0,0 +1,334 @@
/*!
* Copyright 2017-2023 by Contributors
* \file hist_util.cc
*/
#include <vector>
#include <limits>
#include <algorithm>
#include "../data/gradient_index.h"
#include "hist_util.h"
#include <CL/sycl.hpp>
namespace xgboost {
namespace sycl {
namespace common {
/*!
 * \brief Fill histogram with zeroes.
 *
 * Asynchronously fills `size` gradient-pair entries of the on-device
 * histogram with default-constructed (zero) pairs.
 *
 * \param qu    SYCL queue the fill is submitted to.
 * \param hist  Destination histogram rows (device memory).
 * \param size  Number of gradient-pair entries to clear.
 * \param event In: event the fill depends on; out: event of the fill itself.
 */
template<typename GradientSumT>
void InitHist(::sycl::queue qu, GHistRow<GradientSumT, MemoryType::on_device>* hist,
              size_t size, ::sycl::event* event) {
  // qu.fill both waits on *event and replaces it with the new fill event,
  // so callers can keep chaining on the same event object.
  *event = qu.fill(hist->Begin(),
                   xgboost::detail::GradientPairInternal<GradientSumT>(), size, *event);
}

// Explicit instantiations for the two supported accumulation precisions.
template void InitHist(::sycl::queue qu,
                       GHistRow<float, MemoryType::on_device>* hist,
                       size_t size, ::sycl::event* event);
template void InitHist(::sycl::queue qu,
                       GHistRow<double, MemoryType::on_device>* hist,
                       size_t size, ::sycl::event* event);
/*!
 * \brief Compute Subtraction: dst = src1 - src2
 *
 * Element-wise subtraction over the raw (grad, hess) scalars of the
 * histograms; one work-item handles one scalar, hence the 2 * size range.
 *
 * \param qu         SYCL queue the kernel is submitted to.
 * \param dst        Destination histogram (device memory).
 * \param src1       Minuend histogram.
 * \param src2       Subtrahend histogram.
 * \param size       Number of gradient-pair entries (bins) to process.
 * \param event_priv Event the kernel waits on before running.
 * \return Event of the submitted subtraction kernel.
 */
template<typename GradientSumT>
::sycl::event SubtractionHist(::sycl::queue qu,
                              GHistRow<GradientSumT, MemoryType::on_device>* dst,
                              const GHistRow<GradientSumT, MemoryType::on_device>& src1,
                              const GHistRow<GradientSumT, MemoryType::on_device>& src2,
                              size_t size, ::sycl::event event_priv) {
  // View gradient-pair storage as flat scalar arrays (grad, hess interleaved).
  GradientSumT* pdst = reinterpret_cast<GradientSumT*>(dst->Data());
  const GradientSumT* psrc1 = reinterpret_cast<const GradientSumT*>(src1.DataConst());
  const GradientSumT* psrc2 = reinterpret_cast<const GradientSumT*>(src2.DataConst());

  auto event_final = qu.submit([&](::sycl::handler& cgh) {
    cgh.depends_on(event_priv);
    cgh.parallel_for<>(::sycl::range<1>(2 * size), [pdst, psrc1, psrc2](::sycl::item<1> pid) {
      const size_t i = pid.get_id(0);
      pdst[i] = psrc1[i] - psrc2[i];
    });
  });
  return event_final;
}

// Explicit instantiations for float and double accumulation.
template ::sycl::event SubtractionHist(::sycl::queue qu,
                                       GHistRow<float, MemoryType::on_device>* dst,
                                       const GHistRow<float, MemoryType::on_device>& src1,
                                       const GHistRow<float, MemoryType::on_device>& src2,
                                       size_t size, ::sycl::event event_priv);
template ::sycl::event SubtractionHist(::sycl::queue qu,
                                       GHistRow<double, MemoryType::on_device>* dst,
                                       const GHistRow<double, MemoryType::on_device>& src1,
                                       const GHistRow<double, MemoryType::on_device>& src2,
                                       size_t size, ::sycl::event event_priv);
// Kernel with buffer using.
/*!
 * \brief Build a histogram using per-block privatized buffers (no atomics).
 *
 * Rows are split into up to `max_nblocks` blocks; each block accumulates
 * (grad, hess) sums into its own slice of `hist_buffer`, then a second
 * kernel reduces the per-block partials into `hist`.
 *
 * \tparam isDense     Dense gradient index: bin ids are stored relative to
 *                     per-feature `offsets` and must be rebased.
 * \param gpair_device Gradient pairs for all rows (device memory).
 * \param row_indices  Subset of rows to accumulate.
 * \param gmat         Quantized gradient index matrix.
 * \param hist         Output histogram (device memory).
 * \param hist_buffer  Scratch holding nblocks * nbins gradient pairs.
 * \param event_priv   Event to wait on before starting.
 * \return Event of the final reduction kernel.
 */
template<typename FPType, typename BinIdxType, bool isDense>
::sycl::event BuildHistKernel(::sycl::queue qu,
                              const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
                              const RowSetCollection::Elem& row_indices,
                              const GHistIndexMatrix& gmat,
                              GHistRow<FPType, MemoryType::on_device>* hist,
                              GHistRow<FPType, MemoryType::on_device>* hist_buffer,
                              ::sycl::event event_priv) {
  const size_t size = row_indices.Size();
  const size_t* rid = row_indices.begin;
  const size_t n_columns = isDense ? gmat.nfeatures : gmat.row_stride;
  // Gradient pairs viewed as a flat array of interleaved (grad, hess) scalars.
  const GradientPair::ValueT* pgh =
      reinterpret_cast<const GradientPair::ValueT*>(gpair_device.DataConst());
  const BinIdxType* gradient_index = gmat.index.data<BinIdxType>();
  const uint32_t* offsets = gmat.index.Offset();
  FPType* hist_data = reinterpret_cast<FPType*>(hist->Data());
  const size_t nbins = gmat.nbins;

  const size_t max_work_group_size =
      qu.get_device().get_info<::sycl::info::device::max_work_group_size>();
  const size_t work_group_size = n_columns < max_work_group_size ? n_columns : max_work_group_size;

  // Each block needs nbins pairs (2 scalars per bin) of scratch space.
  const size_t max_nblocks = hist_buffer->Size() / (nbins * 2);
  const size_t min_block_size = 128;
  // ceil(size / min_block_size), capped by the available scratch.
  size_t nblocks = std::min(max_nblocks, size / min_block_size + !!(size % min_block_size));
  const size_t block_size = size / nblocks + !!(size % nblocks);
  FPType* hist_buffer_data = reinterpret_cast<FPType*>(hist_buffer->Data());

  // Zero the scratch buffer before accumulating into it.
  auto event_fill = qu.fill(hist_buffer_data, FPType(0), nblocks * nbins * 2, event_priv);
  auto event_main = qu.submit([&](::sycl::handler& cgh) {
    cgh.depends_on(event_fill);
    cgh.parallel_for<>(::sycl::nd_range<2>(::sycl::range<2>(nblocks, work_group_size),
                                           ::sycl::range<2>(1, work_group_size)),
                       [=](::sycl::nd_item<2> pid) {
      size_t block = pid.get_global_id(0);
      size_t feat = pid.get_global_id(1);

      // Private slice of the scratch buffer owned by this block.
      FPType* hist_local = hist_buffer_data + block * nbins * 2;
      for (size_t idx = 0; idx < block_size; ++idx) {
        size_t i = block * block_size + idx;
        if (i < size) {
          const size_t icol_start = n_columns * rid[i];
          const size_t idx_gh = rid[i];

          pid.barrier(::sycl::access::fence_space::local_space);
          const BinIdxType* gr_index_local = gradient_index + icol_start;

          // Work-items of a group stride across the features of the row.
          for (size_t j = feat; j < n_columns; j += work_group_size) {
            uint32_t idx_bin = static_cast<uint32_t>(gr_index_local[j]);
            if constexpr (isDense) {
              idx_bin += offsets[j];
            }
            if (idx_bin < nbins) {
              hist_local[2 * idx_bin] += pgh[2 * idx_gh];
              hist_local[2 * idx_bin+1] += pgh[2 * idx_gh+1];
            }
          }
        }
      }
    });
  });

  // Reduce per-block partial histograms into the output histogram.
  auto event_save = qu.submit([&](::sycl::handler& cgh) {
    cgh.depends_on(event_main);
    cgh.parallel_for<>(::sycl::range<1>(nbins), [=](::sycl::item<1> pid) {
      size_t idx_bin = pid.get_id(0);

      FPType gsum = 0.0f;
      FPType hsum = 0.0f;

      for (size_t j = 0; j < nblocks; ++j) {
        gsum += hist_buffer_data[j * nbins * 2 + 2 * idx_bin];
        hsum += hist_buffer_data[j * nbins * 2 + 2 * idx_bin + 1];
      }

      hist_data[2 * idx_bin] = gsum;
      hist_data[2 * idx_bin + 1] = hsum;
    });
  });
  return event_save;
}
// Kernel with atomic using.
/*!
 * \brief Build a histogram by accumulating directly into `hist` with atomics.
 *
 * One work-item per (row, feature-lane) pair; lanes stride over the row's
 * features and add into the global histogram via AtomicRef fetch_add.
 * Preferred when rows are few relative to bins (see BuildHistDispatchKernel).
 *
 * \tparam isDense   Dense gradient index: bin ids rebased by per-feature offsets.
 * \param event_priv Event to wait on before starting.
 * \return Event of the accumulation kernel.
 */
template<typename FPType, typename BinIdxType, bool isDense>
::sycl::event BuildHistKernel(::sycl::queue qu,
                              const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
                              const RowSetCollection::Elem& row_indices,
                              const GHistIndexMatrix& gmat,
                              GHistRow<FPType, MemoryType::on_device>* hist,
                              ::sycl::event event_priv) {
  const size_t size = row_indices.Size();
  const size_t* rid = row_indices.begin;
  const size_t n_columns = isDense ? gmat.nfeatures : gmat.row_stride;
  // Gradient pairs viewed as a flat array of interleaved (grad, hess) scalars.
  const GradientPair::ValueT* pgh =
      reinterpret_cast<const GradientPair::ValueT*>(gpair_device.DataConst());
  const BinIdxType* gradient_index = gmat.index.data<BinIdxType>();
  const uint32_t* offsets = gmat.index.Offset();
  FPType* hist_data = reinterpret_cast<FPType*>(hist->Data());
  const size_t nbins = gmat.nbins;

  const size_t max_work_group_size =
      qu.get_device().get_info<::sycl::info::device::max_work_group_size>();
  // Number of feature lanes per row; lanes stride over the columns below.
  const size_t feat_local = n_columns < max_work_group_size ? n_columns : max_work_group_size;

  // Zero the output histogram before accumulation.
  auto event_fill = qu.fill(hist_data, FPType(0), nbins * 2, event_priv);
  auto event_main = qu.submit([&](::sycl::handler& cgh) {
    cgh.depends_on(event_fill);
    cgh.parallel_for<>(::sycl::range<2>(size, feat_local),
                       [=](::sycl::item<2> pid) {
      size_t i = pid.get_id(0);
      size_t feat = pid.get_id(1);

      const size_t icol_start = n_columns * rid[i];
      const size_t idx_gh = rid[i];

      const BinIdxType* gr_index_local = gradient_index + icol_start;

      for (size_t j = feat; j < n_columns; j += feat_local) {
        uint32_t idx_bin = static_cast<uint32_t>(gr_index_local[j]);
        if constexpr (isDense) {
          idx_bin += offsets[j];
        }
        if (idx_bin < nbins) {
          // Atomic adds: multiple rows may hit the same bin concurrently.
          AtomicRef<FPType> gsum(hist_data[2 * idx_bin]);
          AtomicRef<FPType> hsum(hist_data[2 * idx_bin + 1]);
          gsum.fetch_add(pgh[2 * idx_gh]);
          hsum.fetch_add(pgh[2 * idx_gh + 1]);
        }
      }
    });
  });
  return event_main;
}
/*!
 * \brief Dispatch histogram construction to the buffered or the atomic kernel.
 *
 * The atomic kernel is chosen when the row count is small relative to the
 * bin count, or when bins are spread evenly over the features; otherwise the
 * per-block buffer kernel plus reduction is used.
 *
 * \param isDense          Whether the gradient index is dense.
 * \param hist_buffer      Scratch buffer; only used by the buffered kernel.
 * \param force_atomic_use Force the atomic kernel; used only for testing.
 * \return Event of the submitted kernel chain.
 */
template<typename FPType, typename BinIdxType>
::sycl::event BuildHistDispatchKernel(
                 ::sycl::queue qu,
                 const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
                 const RowSetCollection::Elem& row_indices,
                 const GHistIndexMatrix& gmat,
                 GHistRow<FPType, MemoryType::on_device>* hist,
                 bool isDense,
                 GHistRow<FPType, MemoryType::on_device>* hist_buffer,
                 ::sycl::event events_priv,
                 bool force_atomic_use) {
  const size_t size = row_indices.Size();
  const size_t n_columns = isDense ? gmat.nfeatures : gmat.row_stride;
  const size_t nbins = gmat.nbins;

  // TODO(razdoburdin): replace the ad-hoc dispatching criteria with a more suitable one.
  bool use_atomic = (size < nbins) || (gmat.max_num_bins == gmat.nbins / n_columns);

  // force_atomic_use flag is used only for testing.
  use_atomic = use_atomic || force_atomic_use;
  if (!use_atomic) {
    if (isDense) {
      return BuildHistKernel<FPType, BinIdxType, true>(qu, gpair_device, row_indices,
                                                       gmat, hist, hist_buffer,
                                                       events_priv);
    } else {
      // The sparse index is always stored as uint32_t, regardless of BinIdxType.
      return BuildHistKernel<FPType, uint32_t, false>(qu, gpair_device, row_indices,
                                                      gmat, hist, hist_buffer,
                                                      events_priv);
    }
  } else {
    if (isDense) {
      return BuildHistKernel<FPType, BinIdxType, true>(qu, gpair_device, row_indices,
                                                       gmat, hist, events_priv);
    } else {
      return BuildHistKernel<FPType, uint32_t, false>(qu, gpair_device, row_indices,
                                                      gmat, hist, events_priv);
    }
  }
}
/*!
 * \brief Build a histogram, dispatching on the stored bin-index width.
 *
 * Selects the BinIdxType template instantiation matching
 * `gmat.index.GetBinTypeSize()` and forwards to BuildHistDispatchKernel.
 *
 * \param isDense          Whether the gradient index is dense.
 * \param force_atomic_use Force the atomic kernel; used only for testing.
 * \return Event of the submitted kernel chain.
 */
template<typename FPType>
::sycl::event BuildHistKernel(::sycl::queue qu,
                              const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
                              const RowSetCollection::Elem& row_indices,
                              const GHistIndexMatrix& gmat, const bool isDense,
                              GHistRow<FPType, MemoryType::on_device>* hist,
                              GHistRow<FPType, MemoryType::on_device>* hist_buffer,
                              ::sycl::event event_priv,
                              bool force_atomic_use) {
  switch (gmat.index.GetBinTypeSize()) {
    case BinTypeSize::kUint8BinsTypeSize:
      return BuildHistDispatchKernel<FPType, uint8_t>(qu, gpair_device, row_indices,
                                                      gmat, hist, isDense, hist_buffer,
                                                      event_priv, force_atomic_use);
    case BinTypeSize::kUint16BinsTypeSize:
      return BuildHistDispatchKernel<FPType, uint16_t>(qu, gpair_device, row_indices,
                                                       gmat, hist, isDense, hist_buffer,
                                                       event_priv, force_atomic_use);
    case BinTypeSize::kUint32BinsTypeSize:
      return BuildHistDispatchKernel<FPType, uint32_t>(qu, gpair_device, row_indices,
                                                       gmat, hist, isDense, hist_buffer,
                                                       event_priv, force_atomic_use);
    default:
      CHECK(false);  // no default behavior
      // CHECK(false) throws, but the compiler cannot prove it; return a
      // default event so every control path is well-defined (-Wreturn-type).
      return ::sycl::event();
  }
}
/*!
 * \brief Build a histogram for the rows in `row_indices`.
 *
 * Thin wrapper over the precision-dispatching BuildHistKernel; all kernel
 * selection (buffered vs. atomic, bin-index width) happens there.
 *
 * \param hist             Output histogram (device memory).
 * \param isDense          Whether the gradient index is dense.
 * \param hist_buffer      Scratch buffer for the buffered kernel.
 * \param event_priv       Event the kernels wait on before running.
 * \param force_atomic_use Force the atomic kernel; used only for testing.
 * \return Event of the submitted kernel chain.
 */
template <typename GradientSumT>
::sycl::event GHistBuilder<GradientSumT>::BuildHist(
      const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
      const RowSetCollection::Elem& row_indices,
      const GHistIndexMatrix &gmat,
      GHistRowT<MemoryType::on_device>* hist,
      bool isDense,
      GHistRowT<MemoryType::on_device>* hist_buffer,
      ::sycl::event event_priv,
      bool force_atomic_use) {
  return BuildHistKernel<GradientSumT>(qu_, gpair_device, row_indices, gmat,
                                       isDense, hist, hist_buffer, event_priv,
                                       force_atomic_use);
}

// Explicit instantiations for float and double accumulation.
template
::sycl::event GHistBuilder<float>::BuildHist(
      const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
      const RowSetCollection::Elem& row_indices,
      const GHistIndexMatrix& gmat,
      GHistRow<float, MemoryType::on_device>* hist,
      bool isDense,
      GHistRow<float, MemoryType::on_device>* hist_buffer,
      ::sycl::event event_priv,
      bool force_atomic_use);
template
::sycl::event GHistBuilder<double>::BuildHist(
      const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
      const RowSetCollection::Elem& row_indices,
      const GHistIndexMatrix& gmat,
      GHistRow<double, MemoryType::on_device>* hist,
      bool isDense,
      GHistRow<double, MemoryType::on_device>* hist_buffer,
      ::sycl::event event_priv,
      bool force_atomic_use);
/*!
 * \brief Compute a node's histogram as parent - sibling (subtraction trick).
 *
 * All three histograms must have the same number of bins.
 *
 * \param self    Output histogram, overwritten with parent - sibling.
 * \param sibling Histogram of the sibling node (subtrahend).
 * \param parent  Histogram of the parent node (minuend).
 */
template<typename GradientSumT>
void GHistBuilder<GradientSumT>::SubtractionTrick(GHistRowT<MemoryType::on_device>* self,
                                                  const GHistRowT<MemoryType::on_device>& sibling,
                                                  const GHistRowT<MemoryType::on_device>& parent) {
  const size_t size = self->Size();
  CHECK_EQ(sibling.Size(), size);
  CHECK_EQ(parent.Size(), size);

  // Launched with a default (empty) dependency; the returned event is
  // discarded here — NOTE(review): completion presumably relies on in-order
  // queue semantics or a later sync; confirm against callers.
  SubtractionHist(qu_, self, parent, sibling, size, ::sycl::event());
}

// Explicit instantiations for float and double accumulation.
template
void GHistBuilder<float>::SubtractionTrick(GHistRow<float, MemoryType::on_device>* self,
                                           const GHistRow<float, MemoryType::on_device>& sibling,
                                           const GHistRow<float, MemoryType::on_device>& parent);
template
void GHistBuilder<double>::SubtractionTrick(GHistRow<double, MemoryType::on_device>* self,
                                            const GHistRow<double, MemoryType::on_device>& sibling,
                                            const GHistRow<double, MemoryType::on_device>& parent);
} // namespace common
} // namespace sycl
} // namespace xgboost

View File

@ -0,0 +1,89 @@
/*!
* Copyright 2017-2023 by Contributors
* \file hist_util.h
*/
#ifndef PLUGIN_SYCL_COMMON_HIST_UTIL_H_
#define PLUGIN_SYCL_COMMON_HIST_UTIL_H_
#include <vector>
#include <unordered_map>
#include <memory>
#include "../data.h"
#include "row_set.h"
#include "../../src/common/hist_util.h"
#include "../data/gradient_index.h"
#include <CL/sycl.hpp>
namespace xgboost {
namespace sycl {
namespace common {
template<typename GradientSumT, MemoryType memory_type = MemoryType::shared>
using GHistRow = USMVector<xgboost::detail::GradientPairInternal<GradientSumT>, memory_type>;
using BinTypeSize = ::xgboost::common::BinTypeSize;
class ColumnMatrix;
/*!
* \brief Fill histogram with zeroes
*/
template<typename GradientSumT>
void InitHist(::sycl::queue qu,
GHistRow<GradientSumT, MemoryType::on_device>* hist,
size_t size, ::sycl::event* event);
/*!
* \brief Compute subtraction: dst = src1 - src2
*/
template<typename GradientSumT>
::sycl::event SubtractionHist(::sycl::queue qu,
GHistRow<GradientSumT, MemoryType::on_device>* dst,
const GHistRow<GradientSumT, MemoryType::on_device>& src1,
const GHistRow<GradientSumT, MemoryType::on_device>& src2,
size_t size, ::sycl::event event_priv);
/*!
 * \brief Builder for histograms of gradient statistics.
 *
 * Holds the SYCL queue and the total bin count; kernels are submitted
 * asynchronously and synchronized through the returned/passed events.
 */
template<typename GradientSumT>
class GHistBuilder {
 public:
  template<MemoryType memory_type = MemoryType::shared>
  using GHistRowT = GHistRow<GradientSumT, memory_type>;

  GHistBuilder() = default;
  /*!
   * \param qu    SYCL queue all kernels are submitted to.
   * \param nbins Total number of bins over all features.
   */
  GHistBuilder(::sycl::queue qu, uint32_t nbins) : qu_{qu}, nbins_{nbins} {}

  // Construct a histogram via histogram aggregation.
  // Returns the event of the submitted kernel chain; `event` is the
  // dependency the kernels wait on. `force_atomic_use` is for testing only.
  ::sycl::event BuildHist(const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
                          const RowSetCollection::Elem& row_indices,
                          const GHistIndexMatrix& gmat,
                          GHistRowT<MemoryType::on_device>* HistCollection,
                          bool isDense,
                          GHistRowT<MemoryType::on_device>* hist_buffer,
                          ::sycl::event event,
                          bool force_atomic_use = false);

  // Construct a histogram via subtraction trick: *self = parent - sibling.
  void SubtractionTrick(GHistRowT<MemoryType::on_device>* self,
                        const GHistRowT<MemoryType::on_device>& sibling,
                        const GHistRowT<MemoryType::on_device>& parent);

  // Total number of bins over all features.
  uint32_t GetNumBins() const {
    return nbins_;
  }

 private:
  /*! \brief Number of all bins over all features */
  uint32_t nbins_ { 0 };
  /*! \brief Queue used for all kernel submissions */
  ::sycl::queue qu_;
};
} // namespace common
} // namespace sycl
} // namespace xgboost
#endif // PLUGIN_SYCL_COMMON_HIST_UTIL_H_

View File

@ -0,0 +1,55 @@
/*!
* Copyright 2017-2024 by Contributors
* \file updater_quantile_hist.cc
*/
#include <vector>
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
#pragma GCC diagnostic ignored "-W#pragma-messages"
#include "xgboost/tree_updater.h"
#pragma GCC diagnostic pop
#include "xgboost/logging.h"
#include "updater_quantile_hist.h"
#include "../data.h"
namespace xgboost {
namespace sycl {
namespace tree {
DMLC_REGISTRY_FILE_TAG(updater_quantile_hist_sycl);
DMLC_REGISTER_PARAMETER(HistMakerTrainParam);
// Configure the updater: resolve the SYCL queue for the context's device,
// then parse training and hist-maker parameters (unknown keys tolerated).
void QuantileHistMaker::Configure(const Args& args) {
  const DeviceOrd device_spec = ctx_->Device();
  qu_ = device_manager.GetQueue(device_spec);

  param_.UpdateAllowUnknown(args);
  hist_maker_param_.UpdateAllowUnknown(args);
}
// Tree-growth entry point. The SYCL implementation is a stub in this
// revision and aborts unconditionally via LOG(FATAL).
void QuantileHistMaker::Update(xgboost::tree::TrainParam const *param,
                               linalg::Matrix<GradientPair>* gpair,
                               DMatrix *dmat,
                               xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
                               const std::vector<RegTree *> &trees) {
  LOG(FATAL) << "Not Implemented yet";
}
// Prediction-cache fast path. The SYCL implementation is a stub in this
// revision and aborts unconditionally via LOG(FATAL).
bool QuantileHistMaker::UpdatePredictionCache(const DMatrix* data,
                                              linalg::MatrixView<float> out_preds) {
  LOG(FATAL) << "Not Implemented yet";
  // LOG(FATAL) throws and never returns, but the compiler cannot see that;
  // return explicitly so every control path of this bool function is
  // well-defined (avoids -Wreturn-type / UB on fall-through).
  return false;
}
// Register this updater under the name used in the `updater` configuration
// ("grow_quantile_histmaker_sycl"); the factory allocates a new instance.
XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker_sycl")
.describe("Grow tree using quantized histogram with SYCL.")
.set_body(
    [](Context const* ctx, ObjInfo const * task) {
      return new QuantileHistMaker(ctx, task);
    });
} // namespace tree
} // namespace sycl
} // namespace xgboost

View File

@ -0,0 +1,91 @@
/*!
* Copyright 2017-2024 by Contributors
* \file updater_quantile_hist.h
*/
#ifndef PLUGIN_SYCL_TREE_UPDATER_QUANTILE_HIST_H_
#define PLUGIN_SYCL_TREE_UPDATER_QUANTILE_HIST_H_
#include <dmlc/timer.h>
#include <xgboost/tree_updater.h>
#include <vector>
#include "../data/gradient_index.h"
#include "../common/hist_util.h"
#include "../common/row_set.h"
#include "../common/partition_builder.h"
#include "split_evaluator.h"
#include "../device_manager.h"
#include "xgboost/data.h"
#include "xgboost/json.h"
#include "../../src/tree/constraints.h"
#include "../../src/common/random.h"
namespace xgboost {
namespace sycl {
namespace tree {
// Training parameters specific to this algorithm.
struct HistMakerTrainParam
    : public XGBoostParameter<HistMakerTrainParam> {
  // Build histograms with float instead of double accumulation
  // (faster, less precise). Defaults to false.
  bool single_precision_histogram = false;
  // Declare parameters for dmlc parameter parsing.
  DMLC_DECLARE_PARAMETER(HistMakerTrainParam) {
    DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe(
        "Use single precision to build histograms.");
  }
};
/*! \brief Construct a tree using quantized feature values with SYCL backend. */
class QuantileHistMaker: public TreeUpdater {
 public:
  QuantileHistMaker(Context const* ctx, ObjInfo const * task) :
                    TreeUpdater(ctx), task_{task} {
    updater_monitor_.Init("SYCLQuantileHistMaker");
  }
  // Resolve the SYCL queue and parse parameters; see the .cc file.
  void Configure(const Args& args) override;

  // Grow trees for one boosting iteration (stub in this revision).
  void Update(xgboost::tree::TrainParam const *param,
              linalg::Matrix<GradientPair>* gpair,
              DMatrix* dmat,
              xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
              const std::vector<RegTree*>& trees) override;

  // Prediction-cache fast path (stub in this revision).
  bool UpdatePredictionCache(const DMatrix* data,
                             linalg::MatrixView<float> out_preds) override;

  // Restore parameters from the JSON keys written by SaveConfig.
  void LoadConfig(Json const& in) override {
    auto const& config = get<Object const>(in);
    FromJson(config.at("train_param"), &this->param_);
    FromJson(config.at("sycl_hist_train_param"), &this->hist_maker_param_);
  }
  // Serialize parameters under "train_param" / "sycl_hist_train_param".
  void SaveConfig(Json* p_out) const override {
    auto& out = *p_out;
    out["train_param"] = ToJson(param_);
    out["sycl_hist_train_param"] = ToJson(hist_maker_param_);
  }

  // Name this updater was registered under.
  char const* Name() const override {
    return "grow_quantile_histmaker_sycl";
  }

 protected:
  // SYCL-specific histogram parameters (e.g. precision).
  HistMakerTrainParam hist_maker_param_;
  // Training parameter shared with the CPU/GPU hist makers.
  xgboost::tree::TrainParam param_;
  // Timing monitor, initialized in the constructor.
  xgboost::common::Monitor updater_monitor_;
  // Queue bound in Configure() for the context's device.
  ::sycl::queue qu_;
  DeviceManager device_manager;
  ObjInfo const *task_{nullptr};
};
} // namespace tree
} // namespace sycl
} // namespace xgboost
#endif // PLUGIN_SYCL_TREE_UPDATER_QUANTILE_HIST_H_

View File

@ -909,9 +909,19 @@ def _transform_cudf_df(
enable_categorical: bool,
) -> Tuple[ctypes.c_void_p, list, Optional[FeatureNames], Optional[FeatureTypes]]:
try:
from cudf.api.types import is_categorical_dtype
from cudf.api.types import is_bool_dtype, is_categorical_dtype
except ImportError:
from cudf.utils.dtypes import is_categorical_dtype
from pandas.api.types import is_bool_dtype
# Work around https://github.com/dmlc/xgboost/issues/10181
if _is_cudf_ser(data):
if is_bool_dtype(data.dtype):
data = data.astype(np.uint8)
else:
data = data.astype(
{col: np.uint8 for col in data.select_dtypes(include="bool")}
)
if _is_cudf_ser(data):
dtypes = [data.dtype]

View File

@ -347,15 +347,14 @@ class _SparkXGBParams(
predict_params[param.name] = self.getOrDefault(param)
return predict_params
def _validate_gpu_params(self) -> None:
def _validate_gpu_params(
self, spark_version: str, conf: SparkConf, is_local: bool = False
) -> None:
"""Validate the gpu parameters and gpu configurations"""
if self._run_on_gpu():
ss = _get_spark_session()
sc = ss.sparkContext
if _is_local(sc):
# Support GPU training in Spark local mode is just for debugging
if is_local:
# Supporting GPU training in Spark local mode is just for debugging
# purposes, so it's okay for printing the below warning instead of
# checking the real gpu numbers and raising the exception.
get_logger(self.__class__.__name__).warning(
@ -364,33 +363,41 @@ class _SparkXGBParams(
self.getOrDefault(self.num_workers),
)
else:
executor_gpus = sc.getConf().get("spark.executor.resource.gpu.amount")
executor_gpus = conf.get("spark.executor.resource.gpu.amount")
if executor_gpus is None:
raise ValueError(
"The `spark.executor.resource.gpu.amount` is required for training"
" on GPU."
)
if not (
ss.version >= "3.4.0"
and _is_standalone_or_localcluster(sc.getConf())
gpu_per_task = conf.get("spark.task.resource.gpu.amount")
if gpu_per_task is not None and float(gpu_per_task) > 1.0:
get_logger(self.__class__.__name__).warning(
"The configuration assigns %s GPUs to each Spark task, but each "
"XGBoost training task only utilizes 1 GPU, which will lead to "
"unnecessary GPU waste",
gpu_per_task,
)
# For 3.5.1+, Spark supports task stage-level scheduling for
# Yarn/K8s/Standalone/Local cluster
# From 3.4.0 ~ 3.5.0, Spark only supports task stage-level scheduing for
# Standalone/Local cluster
# For spark below 3.4.0, Task stage-level scheduling is not supported.
#
# With stage-level scheduling, spark.task.resource.gpu.amount is not required
# to be set explicitly. Or else, spark.task.resource.gpu.amount is a must-have and
# must be set to 1.0
if spark_version < "3.4.0" or (
"3.4.0" <= spark_version < "3.5.1"
and not _is_standalone_or_localcluster(conf)
):
# We will enable stage-level scheduling in spark 3.4.0+ which doesn't
# require spark.task.resource.gpu.amount to be set explicitly
gpu_per_task = sc.getConf().get("spark.task.resource.gpu.amount")
if gpu_per_task is not None:
if float(gpu_per_task) < 1.0:
raise ValueError(
"XGBoost doesn't support GPU fractional configurations. "
"Please set `spark.task.resource.gpu.amount=spark.executor"
".resource.gpu.amount`"
)
if float(gpu_per_task) > 1.0:
get_logger(self.__class__.__name__).warning(
"%s GPUs for each Spark task is configured, but each "
"XGBoost training task uses only 1 GPU.",
gpu_per_task,
"XGBoost doesn't support GPU fractional configurations. Please set "
"`spark.task.resource.gpu.amount=spark.executor.resource.gpu."
"amount`. To enable GPU fractional configurations, you can try "
"standalone/localcluster with spark 3.4.0+ and"
"YARN/K8S with spark 3.5.1+"
)
else:
raise ValueError(
@ -475,7 +482,9 @@ class _SparkXGBParams(
"`pyspark.ml.linalg.Vector` type."
)
self._validate_gpu_params()
ss = _get_spark_session()
sc = ss.sparkContext
self._validate_gpu_params(ss.version, sc.getConf(), _is_local(sc))
def _run_on_gpu(self) -> bool:
"""If train or transform on the gpu according to the parameters"""
@ -925,10 +934,14 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
)
return True
if not _is_standalone_or_localcluster(conf):
if (
"3.4.0" <= spark_version < "3.5.1"
and not _is_standalone_or_localcluster(conf)
):
self.logger.info(
"Stage-level scheduling in xgboost requires spark standalone or "
"local-cluster mode"
"For %s, Stage-level scheduling in xgboost requires spark standalone "
"or local-cluster mode",
spark_version,
)
return True
@ -980,7 +993,9 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
"""Try to enable stage-level scheduling"""
ss = _get_spark_session()
conf = ss.sparkContext.getConf()
if self._skip_stage_level_scheduling(ss.version, conf):
if _is_local(ss.sparkContext) or self._skip_stage_level_scheduling(
ss.version, conf
):
return rdd
# executor_cores will not be None
@ -1052,6 +1067,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
dev_ordinal = None
use_qdm = _can_use_qdm(booster_params.get("tree_method", None))
verbosity = booster_params.get("verbosity", 1)
msg = "Training on CPUs"
if run_on_gpu:
dev_ordinal = (
@ -1089,15 +1105,16 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
evals_result: Dict[str, Any] = {}
with CommunicatorContext(context, **_rabit_args):
dtrain, dvalid = create_dmatrix_from_partitions(
pandas_df_iter,
feature_prop.features_cols_names,
dev_ordinal,
use_qdm,
dmatrix_kwargs,
enable_sparse_data_optim=feature_prop.enable_sparse_data_optim,
has_validation_col=feature_prop.has_validation_col,
)
with xgboost.config_context(verbosity=verbosity):
dtrain, dvalid = create_dmatrix_from_partitions(
pandas_df_iter,
feature_prop.features_cols_names,
dev_ordinal,
use_qdm,
dmatrix_kwargs,
enable_sparse_data_optim=feature_prop.enable_sparse_data_optim,
has_validation_col=feature_prop.has_validation_col,
)
if dvalid is not None:
dval = [(dtrain, "training"), (dvalid, "validation")]
else:

View File

@ -14,7 +14,8 @@ import pyspark
from pyspark import BarrierTaskContext, SparkConf, SparkContext, SparkFiles, TaskContext
from pyspark.sql.session import SparkSession
from xgboost import Booster, XGBModel, collective
from xgboost import Booster, XGBModel
from xgboost.collective import CommunicatorContext as CCtx
from xgboost.tracker import RabitTracker
@ -42,22 +43,12 @@ def _get_default_params_from_func(
return filtered_params_dict
class CommunicatorContext:
"""A context controlling collective communicator initialization and finalization.
This isn't specificially necessary (note Part 3), but it is more understandable
coding-wise.
"""
class CommunicatorContext(CCtx): # pylint: disable=too-few-public-methods
"""Context with PySpark specific task ID."""
def __init__(self, context: BarrierTaskContext, **args: Any) -> None:
self.args = args
self.args["DMLC_TASK_ID"] = str(context.partitionId())
def __enter__(self) -> None:
collective.init(**self.args)
def __exit__(self, *args: Any) -> None:
collective.finalize()
args["DMLC_TASK_ID"] = str(context.partitionId())
super().__init__(**args)
def _start_tracker(context: BarrierTaskContext, n_workers: int) -> Dict[str, Any]:

View File

@ -429,8 +429,8 @@ def make_categorical(
categories = np.arange(0, n_categories)
for col in df.columns:
if rng.binomial(1, cat_ratio, size=1)[0] == 1:
df.loc[:, col] = df[col].astype("category")
df.loc[:, col] = df[col].cat.set_categories(categories)
df[col] = df[col].astype("category")
df[col] = df[col].cat.set_categories(categories)
if sparsity > 0.0:
for i in range(n_features):

View File

@ -100,6 +100,24 @@ std::enable_if_t<std::is_integral_v<E>, xgboost::collective::Result> PollError(E
if ((revents & POLLNVAL) != 0) {
return xgboost::system::FailWithCode("Invalid polling request.");
}
if ((revents & POLLHUP) != 0) {
// Excerpt from the Linux manual:
//
// Note that when reading from a channel such as a pipe or a stream socket, this event
// merely indicates that the peer closed its end of the channel.Subsequent reads from
// the channel will return 0 (end of file) only after all outstanding data in the
// channel has been consumed.
//
// We don't usually have a barrier for exiting workers, it's normal to have one end
// exit while the other still reading data.
return xgboost::collective::Success();
}
#if defined(POLLRDHUP)
// Linux only flag
if ((revents & POLLRDHUP) != 0) {
return xgboost::system::FailWithCode("Poll hung up on the other end.");
}
#endif // defined(POLLRDHUP)
return xgboost::collective::Success();
}
@ -179,9 +197,11 @@ struct PollHelper {
}
std::int32_t ret = PollImpl(fdset.data(), fdset.size(), timeout);
if (ret == 0) {
return xgboost::collective::Fail("Poll timeout.", std::make_error_code(std::errc::timed_out));
return xgboost::collective::Fail(
"Poll timeout:" + std::to_string(timeout.count()) + " seconds.",
std::make_error_code(std::errc::timed_out));
} else if (ret < 0) {
return xgboost::system::FailWithCode("Poll failed.");
return xgboost::system::FailWithCode("Poll failed, nfds:" + std::to_string(fdset.size()));
}
for (auto& pfd : fdset) {

View File

@ -132,7 +132,7 @@ bool AllreduceBase::Shutdown() {
try {
for (auto &all_link : all_links) {
if (!all_link.sock.IsClosed()) {
all_link.sock.Close();
SafeColl(all_link.sock.Close());
}
}
all_links.clear();
@ -146,7 +146,7 @@ bool AllreduceBase::Shutdown() {
LOG(FATAL) << rc.Report();
}
tracker.Send(xgboost::StringView{"shutdown"});
tracker.Close();
SafeColl(tracker.Close());
xgboost::system::SocketFinalize();
return true;
} catch (std::exception const &e) {
@ -167,7 +167,7 @@ void AllreduceBase::TrackerPrint(const std::string &msg) {
tracker.Send(xgboost::StringView{"print"});
tracker.Send(xgboost::StringView{msg});
tracker.Close();
SafeColl(tracker.Close());
}
// util to parse data with unit suffix
@ -332,15 +332,15 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
auto sock_listen{xgboost::collective::TCPSocket::Create(tracker.Domain())};
// create listening socket
int port = sock_listen.BindHost();
utils::Check(port != -1, "ReConnectLink fail to bind the ports specified");
sock_listen.Listen();
std::int32_t port{0};
SafeColl(sock_listen.BindHost(&port));
SafeColl(sock_listen.Listen());
// get number of to connect and number of to accept nodes from tracker
int num_conn, num_accept, num_error = 1;
do {
for (auto & all_link : all_links) {
all_link.sock.Close();
SafeColl(all_link.sock.Close());
}
// tracker construct goodset
Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
@ -352,7 +352,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
LinkRecord r;
int hport, hrank;
std::string hname;
tracker.Recv(&hname);
SafeColl(tracker.Recv(&hname));
Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "ReConnectLink failure 9");
Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), "ReConnectLink failure 10");
// connect to peer
@ -360,7 +360,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
timeout_sec, &r.sock)
.OK()) {
num_error += 1;
r.sock.Close();
SafeColl(r.sock.Close());
continue;
}
Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
@ -386,7 +386,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
// send back socket listening port to tracker
Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port), "ReConnectLink failure 14");
// close connection to tracker
tracker.Close();
SafeColl(tracker.Close());
// listen to incoming links
for (int i = 0; i < num_accept; ++i) {
@ -408,7 +408,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
}
if (!match) all_links.emplace_back(std::move(r));
}
sock_listen.Close();
SafeColl(sock_listen.Close());
this->parent_index = -1;
// setup tree links and ring structure
@ -635,7 +635,7 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
Recv(sendrecvbuf + size_down_in, total_size - size_down_in);
if (len == 0) {
links[parent_index].sock.Close();
SafeColl(links[parent_index].sock.Close());
return ReportError(&links[parent_index], kRecvZeroLen);
}
if (len != -1) {

View File

@ -270,7 +270,7 @@ class AllreduceBase : public IEngine {
ssize_t len = sock.Recv(buffer_head + offset, nmax);
// length equals 0, remote disconnected
if (len == 0) {
sock.Close(); return kRecvZeroLen;
SafeColl(sock.Close()); return kRecvZeroLen;
}
if (len == -1) return Errno2Return();
size_read += static_cast<size_t>(len);
@ -289,7 +289,7 @@ class AllreduceBase : public IEngine {
ssize_t len = sock.Recv(p + size_read, max_size - size_read);
// length equals 0, remote disconnected
if (len == 0) {
sock.Close(); return kRecvZeroLen;
SafeColl(sock.Close()); return kRecvZeroLen;
}
if (len == -1) return Errno2Return();
size_read += static_cast<size_t>(len);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2014-2024 by XGBoost Contributors
* Copyright 2014-2024, XGBoost Contributors
*/
#include "xgboost/c_api.h"
@ -617,8 +617,8 @@ XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle, const char *field, const
API_BEGIN();
CHECK_HANDLE();
xgboost_CHECK_C_ARG_PTR(field);
auto const& p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
p_fmat->SetInfo(field, info, xgboost::DataType::kFloat32, len);
auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
p_fmat->SetInfo(field, linalg::Make1dInterface(info, len));
API_END();
}
@ -637,8 +637,9 @@ XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle, const char *field, const
API_BEGIN();
CHECK_HANDLE();
xgboost_CHECK_C_ARG_PTR(field);
LOG(WARNING) << error::DeprecatedFunc(__func__, "2.1.0", "XGDMatrixSetInfoFromInterface");
auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
p_fmat->SetInfo(field, info, xgboost::DataType::kUInt32, len);
p_fmat->SetInfo(field, linalg::Make1dInterface(info, len));
API_END();
}
@ -682,19 +683,52 @@ XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field, void
xgboost::bst_ulong size, int type) {
API_BEGIN();
CHECK_HANDLE();
LOG(WARNING) << error::DeprecatedFunc(__func__, "2.1.0", "XGDMatrixSetInfoFromInterface");
auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
CHECK(type >= 1 && type <= 4);
xgboost_CHECK_C_ARG_PTR(field);
p_fmat->SetInfo(field, data, static_cast<DataType>(type), size);
API_END();
}
XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle, const unsigned *group, xgboost::bst_ulong len) {
API_BEGIN();
CHECK_HANDLE();
LOG(WARNING) << "XGDMatrixSetGroup is deprecated, use `XGDMatrixSetUIntInfo` instead.";
auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
p_fmat->SetInfo("group", group, xgboost::DataType::kUInt32, len);
Context ctx;
auto dtype = static_cast<DataType>(type);
std::string str;
auto proc = [&](auto cast_d_ptr) {
using T = std::remove_pointer_t<decltype(cast_d_ptr)>;
auto t = linalg::TensorView<T, 1>(
common::Span<T>{cast_d_ptr, static_cast<typename common::Span<T>::index_type>(size)},
{size}, DeviceOrd::CPU());
CHECK(t.CContiguous());
Json iface{linalg::ArrayInterface(t)};
CHECK(ArrayInterface<1>{iface}.is_contiguous);
str = Json::Dump(iface);
return str;
};
// Legacy code using XGBoost dtype, which is a small subset of array interface types.
switch (dtype) {
case xgboost::DataType::kFloat32: {
auto cast_ptr = reinterpret_cast<const float *>(data);
p_fmat->Info().SetInfo(ctx, field, proc(cast_ptr));
break;
}
case xgboost::DataType::kDouble: {
auto cast_ptr = reinterpret_cast<const double *>(data);
p_fmat->Info().SetInfo(ctx, field, proc(cast_ptr));
break;
}
case xgboost::DataType::kUInt32: {
auto cast_ptr = reinterpret_cast<const uint32_t *>(data);
p_fmat->Info().SetInfo(ctx, field, proc(cast_ptr));
break;
}
case xgboost::DataType::kUInt64: {
auto cast_ptr = reinterpret_cast<const uint64_t *>(data);
p_fmat->Info().SetInfo(ctx, field, proc(cast_ptr));
break;
}
default:
LOG(FATAL) << "Unknown data type" << static_cast<uint8_t>(dtype);
}
API_END();
}
@ -990,7 +1024,7 @@ XGB_DLL int XGBoosterBoostOneIter(BoosterHandle handle, DMatrixHandle dtrain, bs
bst_float *hess, xgboost::bst_ulong len) {
API_BEGIN();
CHECK_HANDLE();
error::DeprecatedFunc(__func__, "2.1.0", "XGBoosterTrainOneIter");
LOG(WARNING) << error::DeprecatedFunc(__func__, "2.1.0", "XGBoosterTrainOneIter");
auto *learner = static_cast<Learner *>(handle);
auto ctx = learner->Ctx()->MakeCPU();

View File

@ -1,17 +1,18 @@
/**
* Copyright 2021-2023, XGBoost Contributors
* Copyright 2021-2024, XGBoost Contributors
*/
#ifndef XGBOOST_C_API_C_API_UTILS_H_
#define XGBOOST_C_API_C_API_UTILS_H_
#include <algorithm>
#include <cstddef>
#include <functional>
#include <memory> // for shared_ptr
#include <string> // for string
#include <tuple> // for make_tuple
#include <utility> // for move
#include <vector>
#include <algorithm> // for min
#include <cstddef> // for size_t
#include <functional> // for multiplies
#include <memory> // for shared_ptr
#include <numeric> // for accumulate
#include <string> // for string
#include <tuple> // for make_tuple
#include <utility> // for move
#include <vector> // for vector
#include "../common/json_utils.h" // for TypeCheck
#include "xgboost/c_api.h"

View File

@ -1,15 +1,17 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#include <chrono> // for seconds
#include <cstddef> // for size_t
#include <future> // for future
#include <memory> // for unique_ptr
#include <string> // for string
#include <thread> // for sleep_for
#include <type_traits> // for is_same_v, remove_pointer_t
#include <utility> // for pair
#include "../collective/comm.h" // for DefaultTimeoutSec
#include "../collective/tracker.h" // for RabitTracker
#include "../common/timer.h" // for Timer
#include "c_api_error.h" // for API_BEGIN
#include "xgboost/c_api.h"
#include "xgboost/collective/result.h" // for Result
@ -26,7 +28,7 @@ using namespace xgboost; // NOLINT
namespace {
using TrackerHandleT =
std::pair<std::unique_ptr<collective::Tracker>, std::shared_future<collective::Result>>;
std::pair<std::shared_ptr<collective::Tracker>, std::shared_future<collective::Result>>;
TrackerHandleT *GetTrackerHandle(TrackerHandle handle) {
xgboost_CHECK_C_ARG_PTR(handle);
@ -40,17 +42,29 @@ struct CollAPIEntry {
};
using CollAPIThreadLocalStore = dmlc::ThreadLocalStore<CollAPIEntry>;
void WaitImpl(TrackerHandleT *ptr) {
std::chrono::seconds wait_for{100};
void WaitImpl(TrackerHandleT *ptr, std::chrono::seconds timeout) {
constexpr std::int64_t kDft{collective::DefaultTimeoutSec()};
std::chrono::seconds wait_for{timeout.count() != 0 ? std::min(kDft, timeout.count()) : kDft};
common::Timer timer;
timer.Start();
auto ref = ptr->first; // hold a reference to that free don't delete it while waiting.
auto fut = ptr->second;
while (fut.valid()) {
auto res = fut.wait_for(wait_for);
CHECK(res != std::future_status::deferred);
if (res == std::future_status::ready) {
auto const &rc = ptr->second.get();
CHECK(rc.OK()) << rc.Report();
collective::SafeColl(rc);
break;
}
if (timer.Duration() > timeout && timeout.count() != 0) {
collective::SafeColl(collective::Fail("Timeout waiting for the tracker."));
}
}
}
} // namespace
@ -62,15 +76,15 @@ XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle) {
Json jconfig = Json::Load(config);
auto type = RequiredArg<String>(jconfig, "dmlc_communicator", __func__);
std::unique_ptr<collective::Tracker> tptr;
std::shared_ptr<collective::Tracker> tptr;
if (type == "federated") {
#if defined(XGBOOST_USE_FEDERATED)
tptr = std::make_unique<collective::FederatedTracker>(jconfig);
tptr = std::make_shared<collective::FederatedTracker>(jconfig);
#else
LOG(FATAL) << error::NoFederated();
#endif // defined(XGBOOST_USE_FEDERATED)
} else if (type == "rabit") {
tptr = std::make_unique<collective::RabitTracker>(jconfig);
tptr = std::make_shared<collective::RabitTracker>(jconfig);
} else {
LOG(FATAL) << "Unknown communicator:" << type;
}
@ -93,7 +107,7 @@ XGB_DLL int XGTrackerWorkerArgs(TrackerHandle handle, char const **args) {
API_END();
}
XGB_DLL int XGTrackerRun(TrackerHandle handle) {
XGB_DLL int XGTrackerRun(TrackerHandle handle, char const *) {
API_BEGIN();
auto *ptr = GetTrackerHandle(handle);
CHECK(!ptr->second.valid()) << "Tracker is already running.";
@ -101,19 +115,39 @@ XGB_DLL int XGTrackerRun(TrackerHandle handle) {
API_END();
}
XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config) {
XGB_DLL int XGTrackerWaitFor(TrackerHandle handle, char const *config) {
API_BEGIN();
auto *ptr = GetTrackerHandle(handle);
xgboost_CHECK_C_ARG_PTR(config);
auto jconfig = Json::Load(StringView{config});
WaitImpl(ptr);
// Internally, 0 indicates no timeout, which is the default since we don't want to
// interrupt the model training.
xgboost_CHECK_C_ARG_PTR(config);
auto timeout = OptionalArg<Integer>(jconfig, "timeout", std::int64_t{0});
WaitImpl(ptr, std::chrono::seconds{timeout});
API_END();
}
XGB_DLL int XGTrackerFree(TrackerHandle handle) {
API_BEGIN();
using namespace std::chrono_literals; // NOLINT
auto *ptr = GetTrackerHandle(handle);
WaitImpl(ptr);
ptr->first->Stop();
// The wait is not necessary since we just called stop, just reusing the function to do
// any potential cleanups.
WaitImpl(ptr, ptr->first->Timeout());
common::Timer timer;
timer.Start();
// Make sure no one else is waiting on the tracker.
while (!ptr->first.unique()) {
auto ela = timer.Duration().count();
if (ela > ptr->first->Timeout().count()) {
LOG(WARNING) << "Time out " << ptr->first->Timeout().count()
<< " seconds reached for TrackerFree, killing the tracker.";
break;
}
std::this_thread::sleep_for(64ms);
}
delete ptr;
API_END();
}

View File

@ -165,7 +165,7 @@ template <typename T>
T GlobalRatio(Context const* ctx, MetaInfo const& info, T dividend, T divisor) {
std::array<T, 2> results{dividend, divisor};
auto rc = GlobalSum(ctx, info, linalg::MakeVec(results.data(), results.size()));
collective::SafeColl(rc);
SafeColl(rc);
std::tie(dividend, divisor) = std::tuple_cat(results);
if (divisor <= 0) {
return std::numeric_limits<T>::quiet_NaN();

View File

@ -1,5 +1,5 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#include "allgather.h"
@ -7,6 +7,7 @@
#include <cstddef> // for size_t
#include <cstdint> // for int8_t, int32_t, int64_t
#include <memory> // for shared_ptr
#include <utility> // for move
#include "broadcast.h"
#include "comm.h" // for Comm, Channel
@ -29,16 +30,22 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
auto rc = Success() << [&] {
auto send_rank = (rank + world - r + worker_off) % world;
auto send_off = send_rank * segment_size;
send_off = std::min(send_off, data.size_bytes());
auto send_seg = data.subspan(send_off, std::min(segment_size, data.size_bytes() - send_off));
bool is_last_segment = send_rank == (world - 1);
auto send_nbytes = is_last_segment ? (data.size_bytes() - send_off) : segment_size;
auto send_seg = data.subspan(send_off, send_nbytes);
CHECK_NE(send_seg.size(), 0);
return next_ch->SendAll(send_seg.data(), send_seg.size_bytes());
} << [&] {
auto recv_rank = (rank + world - r - 1 + worker_off) % world;
auto recv_off = recv_rank * segment_size;
recv_off = std::min(recv_off, data.size_bytes());
auto recv_seg = data.subspan(recv_off, std::min(segment_size, data.size_bytes() - recv_off));
bool is_last_segment = recv_rank == (world - 1);
auto recv_nbytes = is_last_segment ? (data.size_bytes() - recv_off) : segment_size;
auto recv_seg = data.subspan(recv_off, recv_nbytes);
CHECK_NE(recv_seg.size(), 0);
return prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
} << [&] { return prev_ch->Block(); };
} << [&] {
return comm.Block();
};
if (!rc.OK()) {
return rc;
}
@ -91,7 +98,9 @@ namespace detail {
auto recv_size = sizes[recv_rank];
auto recv_seg = erased_result.subspan(recv_off, recv_size);
return prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
} << [&] { return prev_ch->Block(); };
} << [&] {
return prev_ch->Block();
};
if (!rc.OK()) {
return rc;
}
@ -99,4 +108,47 @@ namespace detail {
return comm.Block();
}
} // namespace detail
[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
Context const* ctx, CommGroup const& comm, std::vector<std::vector<char>> const& input) {
auto n_inputs = input.size();
std::vector<std::int64_t> sizes(n_inputs);
std::transform(input.cbegin(), input.cend(), sizes.begin(),
[](auto const& vec) { return vec.size(); });
std::vector<std::int64_t> recv_segments(comm.World() + 1, 0);
HostDeviceVector<std::int8_t> recv;
auto rc =
AllgatherV(ctx, comm, linalg::MakeVec(sizes.data(), sizes.size()), &recv_segments, &recv);
SafeColl(rc);
auto global_sizes = common::RestoreType<std::int64_t const>(recv.ConstHostSpan());
std::vector<std::int64_t> offset(global_sizes.size() + 1);
offset[0] = 0;
for (std::size_t i = 1; i < offset.size(); i++) {
offset[i] = offset[i - 1] + global_sizes[i - 1];
}
std::vector<char> collected;
for (auto const& vec : input) {
collected.insert(collected.end(), vec.cbegin(), vec.cend());
}
rc = AllgatherV(ctx, comm, linalg::MakeVec(collected.data(), collected.size()), &recv_segments,
&recv);
SafeColl(rc);
auto out = common::RestoreType<char const>(recv.ConstHostSpan());
std::vector<std::vector<char>> result;
for (std::size_t i = 1; i < offset.size(); ++i) {
std::vector<char> local(out.cbegin() + offset[i - 1], out.cbegin() + offset[i]);
result.emplace_back(std::move(local));
}
return result;
}
[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
Context const* ctx, std::vector<std::vector<char>> const& input) {
return VectorAllgatherV(ctx, *GlobalCommGroup(), input);
}
} // namespace xgboost::collective

View File

@ -1,25 +1,27 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#pragma once
#include <cstddef> // for size_t
#include <cstdint> // for int32_t
#include <memory> // for shared_ptr
#include <numeric> // for accumulate
#include <string> // for string
#include <type_traits> // for remove_cv_t
#include <vector> // for vector
#include "../common/type.h" // for EraseType
#include "../common/type.h" // for EraseType
#include "comm.h" // for Comm, Channel
#include "comm_group.h" // for CommGroup
#include "xgboost/collective/result.h" // for Result
#include "xgboost/linalg.h"
#include "xgboost/span.h" // for Span
#include "xgboost/linalg.h" // for MakeVec
#include "xgboost/span.h" // for Span
namespace xgboost::collective {
namespace cpu_impl {
/**
* @param worker_off Segment offset. For example, if the rank 2 worker specifies
* worker_off = 1, then it owns the third segment.
* worker_off = 1, then it owns the third segment (2 + 1).
*/
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data,
std::size_t segment_size, std::int32_t worker_off,
@ -51,8 +53,10 @@ inline void AllgatherVOffset(common::Span<std::int64_t const> sizes,
} // namespace detail
template <typename T>
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<T> data, std::size_t size) {
auto n_bytes = sizeof(T) * size;
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<T> data) {
// This function is also used for ring allreduce, hence we allow the last segment to be
// larger due to round-down.
auto n_bytes_per_segment = data.size_bytes() / comm.World();
auto erased = common::EraseType(data);
auto rank = comm.Rank();
@ -61,7 +65,7 @@ template <typename T>
auto prev_ch = comm.Chan(prev);
auto next_ch = comm.Chan(next);
auto rc = cpu_impl::RingAllgather(comm, erased, n_bytes, 0, prev_ch, next_ch);
auto rc = cpu_impl::RingAllgather(comm, erased, n_bytes_per_segment, 0, prev_ch, next_ch);
if (!rc.OK()) {
return rc;
}
@ -76,7 +80,7 @@ template <typename T>
std::vector<std::int64_t> sizes(world, 0);
sizes[rank] = data.size_bytes();
auto rc = RingAllgather(comm, common::Span{sizes.data(), sizes.size()}, 1);
auto rc = RingAllgather(comm, common::Span{sizes.data(), sizes.size()});
if (!rc.OK()) {
return rc;
}
@ -98,4 +102,115 @@ template <typename T>
return detail::RingAllgatherV(comm, sizes, s_segments, erased_result);
}
template <typename T>
[[nodiscard]] Result Allgather(Context const* ctx, CommGroup const& comm,
linalg::VectorView<T> data) {
if (!comm.IsDistributed()) {
return Success();
}
CHECK(data.Contiguous());
auto erased = common::EraseType(data.Values());
auto const& cctx = comm.Ctx(ctx, data.Device());
auto backend = comm.Backend(data.Device());
return backend->Allgather(cctx, erased);
}
/**
* @brief Gather all data from all workers.
*
* @param data The input and output buffer, needs to be pre-allocated by the caller.
*/
template <typename T>
[[nodiscard]] Result Allgather(Context const* ctx, linalg::VectorView<T> data) {
auto const& cg = *GlobalCommGroup();
if (data.Size() % cg.World() != 0) {
return Fail("The total number of elements should be multiple of the number of workers.");
}
return Allgather(ctx, cg, data);
}
template <typename T>
[[nodiscard]] Result AllgatherV(Context const* ctx, CommGroup const& comm,
linalg::VectorView<T> data,
std::vector<std::int64_t>* recv_segments,
HostDeviceVector<std::int8_t>* recv) {
if (!comm.IsDistributed()) {
return Success();
}
std::vector<std::int64_t> sizes(comm.World(), 0);
sizes[comm.Rank()] = data.Values().size_bytes();
auto erased_sizes = common::EraseType(common::Span{sizes.data(), sizes.size()});
auto rc = comm.Backend(DeviceOrd::CPU())
->Allgather(comm.Ctx(ctx, DeviceOrd::CPU()), erased_sizes);
if (!rc.OK()) {
return rc;
}
recv_segments->resize(sizes.size() + 1);
detail::AllgatherVOffset(sizes, common::Span{recv_segments->data(), recv_segments->size()});
auto total_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), 0LL);
recv->SetDevice(data.Device());
recv->Resize(total_bytes);
auto s_segments = common::Span{recv_segments->data(), recv_segments->size()};
auto backend = comm.Backend(data.Device());
auto erased = common::EraseType(data.Values());
return backend->AllgatherV(
comm.Ctx(ctx, data.Device()), erased, common::Span{sizes.data(), sizes.size()}, s_segments,
data.Device().IsCUDA() ? recv->DeviceSpan() : recv->HostSpan(), AllgatherVAlgo::kBcast);
}
/**
* @brief Allgather with variable length data.
*
* @param data The input data.
* @param recv_segments segment size for each worker. [0, 2, 5] means [0, 2) elements are
* from the first worker, [2, 5) elements are from the second one.
* @param recv The buffer storing the result.
*/
template <typename T>
[[nodiscard]] Result AllgatherV(Context const* ctx, linalg::VectorView<T> data,
std::vector<std::int64_t>* recv_segments,
HostDeviceVector<std::int8_t>* recv) {
return AllgatherV(ctx, *GlobalCommGroup(), data, recv_segments, recv);
}
[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
Context const* ctx, CommGroup const& comm, std::vector<std::vector<char>> const& input);
/**
* @brief Gathers variable-length data from all processes and distributes it to all processes.
*
* @param inputs All the inputs from the local worker. The number of inputs can vary
* across different workers. Along with which, the size of each vector in
* the input can also vary.
*
* @return The AllgatherV result, containing vectors from all workers.
*/
[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
Context const* ctx, std::vector<std::vector<char>> const& input);
/**
* @brief Gathers variable-length strings from all processes and distributes them to all processes.
* @param input Variable-length list of variable-length strings.
*/
[[nodiscard]] inline Result AllgatherStrings(std::vector<std::string> const& input,
std::vector<std::string>* p_result) {
std::vector<std::vector<char>> inputs(input.size());
for (std::size_t i = 0; i < input.size(); ++i) {
inputs[i] = {input[i].cbegin(), input[i].cend()};
}
Context ctx;
auto out = VectorAllgatherV(&ctx, *GlobalCommGroup(), inputs);
auto& result = *p_result;
result.resize(out.size());
for (std::size_t i = 0; i < out.size(); ++i) {
result[i] = {out[i].cbegin(), out[i].cend()};
}
return Success();
}
} // namespace xgboost::collective

View File

@ -1,5 +1,5 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#include "allreduce.h"
@ -16,7 +16,44 @@
#include "xgboost/span.h" // for Span
namespace xgboost::collective::cpu_impl {
namespace {
template <typename T>
Result RingAllreduceSmall(Comm const& comm, common::Span<std::int8_t> data, Func const& op) {
auto rank = comm.Rank();
auto world = comm.World();
auto next_ch = comm.Chan(BootstrapNext(rank, world));
auto prev_ch = comm.Chan(BootstrapPrev(rank, world));
std::vector<std::int8_t> buffer(data.size_bytes() * world, 0);
auto s_buffer = common::Span{buffer.data(), buffer.size()};
auto offset = data.size_bytes() * rank;
auto self = s_buffer.subspan(offset, data.size_bytes());
std::copy_n(data.data(), data.size_bytes(), self.data());
auto typed = common::RestoreType<T>(s_buffer);
auto rc = RingAllgather(comm, typed);
if (!rc.OK()) {
return rc;
}
auto first = s_buffer.subspan(0, data.size_bytes());
CHECK_EQ(first.size(), data.size());
for (std::int32_t r = 1; r < world; ++r) {
auto offset = data.size_bytes() * r;
auto buf = s_buffer.subspan(offset, data.size_bytes());
op(buf, first);
}
std::copy_n(first.data(), first.size(), data.data());
return Success();
}
} // namespace
template <typename T>
// note that n_bytes_in_seg is calculated with round-down.
Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
std::size_t n_bytes_in_seg, Func const& op) {
auto rank = comm.Rank();
@ -27,33 +64,39 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
auto next_ch = comm.Chan(dst_rank);
auto prev_ch = comm.Chan(src_rank);
std::vector<std::int8_t> buffer(n_bytes_in_seg, 0);
std::vector<std::int8_t> buffer(data.size_bytes() - (world - 1) * n_bytes_in_seg, 0);
auto s_buf = common::Span{buffer.data(), buffer.size()};
for (std::int32_t r = 0; r < world - 1; ++r) {
// send to ring next
auto send_off = ((rank + world - r) % world) * n_bytes_in_seg;
send_off = std::min(send_off, data.size_bytes());
auto seg_nbytes = std::min(data.size_bytes() - send_off, n_bytes_in_seg);
auto send_seg = data.subspan(send_off, seg_nbytes);
common::Span<std::int8_t> seg, recv_seg;
auto rc = Success() << [&] {
// send to ring next
auto send_rank = (rank + world - r) % world;
auto send_off = send_rank * n_bytes_in_seg;
auto rc = next_ch->SendAll(send_seg);
if (!rc.OK()) {
return rc;
}
bool is_last_segment = send_rank == (world - 1);
// receive from ring prev
auto recv_off = ((rank + world - r - 1) % world) * n_bytes_in_seg;
recv_off = std::min(recv_off, data.size_bytes());
seg_nbytes = std::min(data.size_bytes() - recv_off, n_bytes_in_seg);
CHECK_EQ(seg_nbytes % sizeof(T), 0);
auto recv_seg = data.subspan(recv_off, seg_nbytes);
auto seg = s_buf.subspan(0, recv_seg.size());
auto seg_nbytes = is_last_segment ? data.size_bytes() - send_off : n_bytes_in_seg;
CHECK_EQ(seg_nbytes % sizeof(T), 0);
rc = std::move(rc) << [&] { return prev_ch->RecvAll(seg); } << [&] { return comm.Block(); };
if (!rc.OK()) {
return rc;
}
auto send_seg = data.subspan(send_off, seg_nbytes);
return next_ch->SendAll(send_seg);
} << [&] {
// receive from ring prev
auto recv_rank = (rank + world - r - 1) % world;
auto recv_off = recv_rank * n_bytes_in_seg;
bool is_last_segment = recv_rank == (world - 1);
auto seg_nbytes = is_last_segment ? (data.size_bytes() - recv_off) : n_bytes_in_seg;
CHECK_EQ(seg_nbytes % sizeof(T), 0);
recv_seg = data.subspan(recv_off, seg_nbytes);
seg = s_buf.subspan(0, recv_seg.size());
return prev_ch->RecvAll(seg);
} << [&] {
return comm.Block();
};
// accumulate to recv_seg
CHECK_EQ(seg.size(), recv_seg.size());
@ -68,6 +111,9 @@ Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func cons
if (comm.World() == 1) {
return Success();
}
if (data.size_bytes() == 0) {
return Success();
}
return DispatchDType(type, [&](auto t) {
using T = decltype(t);
// Divide the data into segments according to the number of workers.
@ -75,7 +121,11 @@ Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func cons
CHECK_EQ(data.size_bytes() % n_bytes_elem, 0);
auto n = data.size_bytes() / n_bytes_elem;
auto world = comm.World();
auto n_bytes_in_seg = common::DivRoundUp(n, world) * sizeof(T);
if (n < static_cast<decltype(n)>(world)) {
return RingAllreduceSmall<T>(comm, data, op);
}
auto n_bytes_in_seg = (n / world) * sizeof(T);
auto rc = RingScatterReduceTyped<T>(comm, data, n_bytes_in_seg, op);
if (!rc.OK()) {
return rc;
@ -88,7 +138,9 @@ Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func cons
return std::move(rc) << [&] {
return RingAllgather(comm, data, n_bytes_in_seg, 1, prev_ch, next_ch);
} << [&] { return comm.Block(); };
} << [&] {
return comm.Block();
};
});
}
} // namespace xgboost::collective::cpu_impl

View File

@ -1,15 +1,18 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#pragma once
#include <cstdint> // for int8_t
#include <functional> // for function
#include <type_traits> // for is_invocable_v, enable_if_t
#include <vector> // for vector
#include "../common/type.h" // for EraseType, RestoreType
#include "../data/array_interface.h" // for ArrayInterfaceHandler
#include "../data/array_interface.h" // for ToDType, ArrayInterfaceHandler
#include "comm.h" // for Comm, RestoreType
#include "comm_group.h" // for GlobalCommGroup
#include "xgboost/collective/result.h" // for Result
#include "xgboost/context.h" // for Context
#include "xgboost/span.h" // for Span
namespace xgboost::collective {
@ -27,8 +30,7 @@ std::enable_if_t<std::is_invocable_v<Fn, common::Span<T const>, common::Span<T>>
auto erased = common::EraseType(data);
auto type = ToDType<T>::kType;
auto erased_fn = [type, redop](common::Span<std::int8_t const> lhs,
common::Span<std::int8_t> out) {
auto erased_fn = [redop](common::Span<std::int8_t const> lhs, common::Span<std::int8_t> out) {
CHECK_EQ(lhs.size(), out.size()) << "Invalid input for reduction.";
auto lhs_t = common::RestoreType<T const>(lhs);
auto rhs_t = common::RestoreType<T>(out);
@ -37,4 +39,40 @@ std::enable_if_t<std::is_invocable_v<Fn, common::Span<T const>, common::Span<T>>
return cpu_impl::RingAllreduce(comm, erased, erased_fn, type);
}
template <typename T, std::int32_t kDim>
[[nodiscard]] Result Allreduce(Context const* ctx, CommGroup const& comm,
linalg::TensorView<T, kDim> data, Op op) {
if (!comm.IsDistributed()) {
return Success();
}
CHECK(data.Contiguous());
auto erased = common::EraseType(data.Values());
auto type = ToDType<T>::kType;
auto backend = comm.Backend(data.Device());
return backend->Allreduce(comm.Ctx(ctx, data.Device()), erased, type, op);
}
template <typename T, std::int32_t kDim>
[[nodiscard]] Result Allreduce(Context const* ctx, linalg::TensorView<T, kDim> data, Op op) {
return Allreduce(ctx, *GlobalCommGroup(), data, op);
}
/**
* @brief Specialization for std::vector.
*/
template <typename T, typename Alloc>
[[nodiscard]] Result Allreduce(Context const* ctx, std::vector<T, Alloc>* data, Op op) {
return Allreduce(ctx, linalg::MakeVec(data->data(), data->size()), op);
}
/**
* @brief Specialization for scalar value.
*/
template <typename T>
[[nodiscard]] std::enable_if_t<std::is_standard_layout_v<T> && std::is_trivial_v<T>, Result>
Allreduce(Context const* ctx, T* data, Op op) {
return Allreduce(ctx, linalg::MakeVec(data, 1), op);
}
} // namespace xgboost::collective

View File

@ -1,11 +1,15 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#pragma once
#include <cstdint> // for int32_t, int8_t
#include "comm.h" // for Comm
#include "xgboost/collective/result.h" // for
#include "../common/type.h"
#include "comm.h" // for Comm, EraseType
#include "comm_group.h" // for CommGroup
#include "xgboost/collective/result.h" // for Result
#include "xgboost/context.h" // for Context
#include "xgboost/linalg.h" // for VectorView
#include "xgboost/span.h" // for Span
namespace xgboost::collective {
@ -23,4 +27,21 @@ template <typename T>
common::Span<std::int8_t>{reinterpret_cast<std::int8_t*>(data.data()), n_total_bytes};
return cpu_impl::Broadcast(comm, erased, root);
}
template <typename T>
[[nodiscard]] Result Broadcast(Context const* ctx, CommGroup const& comm,
linalg::VectorView<T> data, std::int32_t root) {
if (!comm.IsDistributed()) {
return Success();
}
CHECK(data.Contiguous());
auto erased = common::EraseType(data.Values());
auto backend = comm.Backend(data.Device());
return backend->Broadcast(comm.Ctx(ctx, data.Device()), erased, root);
}
template <typename T>
[[nodiscard]] Result Broadcast(Context const* ctx, linalg::VectorView<T> data, std::int32_t root) {
return Broadcast(ctx, *GlobalCommGroup(), data, root);
}
} // namespace xgboost::collective

View File

@ -42,6 +42,10 @@ bool constexpr IsFloatingPointV() {
auto redop_fn = [](auto lhs, auto out, auto elem_op) {
auto p_lhs = lhs.data();
auto p_out = out.data();
#if defined(__GNUC__) || defined(__clang__)
// For the sum op, one can verify the simd by: addps %xmm15, %xmm14
#pragma omp simd
#endif
for (std::size_t i = 0; i < lhs.size(); ++i) {
p_out[i] = elem_op(p_lhs[i], p_out[i]);
}
@ -108,9 +112,8 @@ bool constexpr IsFloatingPointV() {
return cpu_impl::Broadcast(comm, data, root);
}
[[nodiscard]] Result Coll::Allgather(Comm const& comm, common::Span<std::int8_t> data,
std::int64_t size) {
return RingAllgather(comm, data, size);
[[nodiscard]] Result Coll::Allgather(Comm const& comm, common::Span<std::int8_t> data) {
return RingAllgather(comm, data);
}
[[nodiscard]] Result Coll::AllgatherV(Comm const& comm, common::Span<std::int8_t const> data,

View File

@ -1,10 +1,9 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#if defined(XGBOOST_USE_NCCL) || defined(XGBOOST_USE_RCCL)
#include <cstdint> // for int8_t, int64_t
#include "../common/cuda_context.cuh"
#include "../common/device_helpers.cuh"
#include "../data/array_interface.h"
#include "allgather.h" // for AllgatherVOffset
@ -166,14 +165,14 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) {
} << [&] { return nccl->Block(); };
}
[[nodiscard]] Result NCCLColl::Allgather(Comm const& comm, common::Span<std::int8_t> data,
std::int64_t size) {
[[nodiscard]] Result NCCLColl::Allgather(Comm const& comm, common::Span<std::int8_t> data) {
if (!comm.IsDistributed()) {
return Success();
}
auto nccl = dynamic_cast<NCCLComm const*>(&comm);
CHECK(nccl);
auto stub = nccl->Stub();
auto size = data.size_bytes() / comm.World();
auto send = data.subspan(comm.Rank() * size, size);
return Success() << [&] {

View File

@ -1,5 +1,5 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#pragma once
@ -8,8 +8,7 @@
#include "../data/array_interface.h" // for ArrayInterfaceHandler
#include "coll.h" // for Coll
#include "comm.h" // for Comm
#include "nccl_stub.h"
#include "xgboost/span.h" // for Span
#include "xgboost/span.h" // for Span
namespace xgboost::collective {
class NCCLColl : public Coll {
@ -20,8 +19,7 @@ class NCCLColl : public Coll {
ArrayInterfaceHandler::Type type, Op op) override;
[[nodiscard]] Result Broadcast(Comm const& comm, common::Span<std::int8_t> data,
std::int32_t root) override;
[[nodiscard]] Result Allgather(Comm const& comm, common::Span<std::int8_t> data,
std::int64_t size) override;
[[nodiscard]] Result Allgather(Comm const& comm, common::Span<std::int8_t> data) override;
[[nodiscard]] Result AllgatherV(Comm const& comm, common::Span<std::int8_t const> data,
common::Span<std::int64_t const> sizes,
common::Span<std::int64_t> recv_segments,

View File

@ -48,10 +48,8 @@ class Coll : public std::enable_shared_from_this<Coll> {
* @brief Allgather
*
* @param [in,out] data Data buffer for input and output.
* @param [in] size Size of data for each worker.
*/
[[nodiscard]] virtual Result Allgather(Comm const& comm, common::Span<std::int8_t> data,
std::int64_t size);
[[nodiscard]] virtual Result Allgather(Comm const& comm, common::Span<std::int8_t> data);
/**
* @brief Allgather with variable length.
*

View File

@ -1,16 +1,19 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#include "comm.h"
#include <algorithm> // for copy
#include <chrono> // for seconds
#include <cstdint> // for int32_t
#include <cstdlib> // for exit
#include <memory> // for shared_ptr
#include <string> // for string
#include <thread> // for thread
#include <utility> // for move, forward
#include "../common/common.h" // for AssertGPUSupport
#if !defined(XGBOOST_USE_NCCL)
#include "../common/common.h" // for AssertNCCLSupport
#endif // !defined(XGBOOST_USE_NCCL)
#include "allgather.h" // for RingAllgather
#include "protocol.h" // for kMagic
#include "xgboost/base.h" // for XGBOOST_STRICT_R_MODE
@ -21,11 +24,7 @@
namespace xgboost::collective {
Comm::Comm(std::string const& host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t retry, std::string task_id)
: timeout_{timeout},
retry_{retry},
tracker_{host, port, -1},
task_id_{std::move(task_id)},
loop_{std::shared_ptr<Loop>{new Loop{timeout}}} {}
: timeout_{timeout}, retry_{retry}, tracker_{host, port, -1}, task_id_{std::move(task_id)} {}
Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, std::int32_t retry,
std::string const& task_id, TCPSocket* out, std::int32_t rank,
@ -187,12 +186,30 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
return Success();
}
RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t retry, std::string task_id, StringView nccl_path)
: HostComm{std::move(host), port, timeout, retry, std::move(task_id)},
namespace {
std::string InitLog(std::string task_id, std::int32_t rank) {
if (task_id.empty()) {
return "Rank " + std::to_string(rank);
}
return "Task " + task_id + " got rank " + std::to_string(rank);
}
} // namespace
RabitComm::RabitComm(std::string const& tracker_host, std::int32_t tracker_port,
std::chrono::seconds timeout, std::int32_t retry, std::string task_id,
StringView nccl_path)
: HostComm{tracker_host, tracker_port, timeout, retry, std::move(task_id)},
nccl_path_{std::move(nccl_path)} {
if (this->TrackerInfo().host.empty()) {
// Not in a distributed environment.
LOG(CONSOLE) << InitLog(task_id_, rank_);
return;
}
loop_.reset(new Loop{std::chrono::seconds{timeout_}}); // NOLINT
auto rc = this->Bootstrap(timeout_, retry_, task_id_);
if (!rc.OK()) {
this->ResetState();
SafeColl(Fail("Failed to bootstrap the communication group.", std::move(rc)));
}
}
@ -219,20 +236,54 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
// Start command
TCPSocket listener = TCPSocket::Create(tracker.Domain());
std::int32_t lport = listener.BindHost();
listener.Listen();
std::int32_t lport{0};
rc = std::move(rc) << [&] {
return listener.BindHost(&lport);
} << [&] {
return listener.Listen();
};
if (!rc.OK()) {
return rc;
}
// create worker for listening to error notice.
auto domain = tracker.Domain();
std::shared_ptr<TCPSocket> error_sock{TCPSocket::CreatePtr(domain)};
auto eport = error_sock->BindHost();
error_sock->Listen();
std::int32_t eport{0};
rc = std::move(rc) << [&] {
return error_sock->BindHost(&eport);
} << [&] {
return error_sock->Listen();
};
if (!rc.OK()) {
return rc;
}
error_port_ = eport;
error_worker_ = std::thread{[error_sock = std::move(error_sock)] {
auto conn = error_sock->Accept();
TCPSocket conn;
SockAddress addr;
auto rc = error_sock->Accept(&conn, &addr);
// On Linux, a shutdown causes an invalid argument error;
if (rc.Code() == std::errc::invalid_argument) {
return;
}
// On Windows, accept returns a closed socket after finalize.
if (conn.IsClosed()) {
return;
}
// The error signal is from the tracker, while shutdown signal is from the shutdown method
// of the RabitComm class (this).
bool is_error{false};
rc = proto::Error{}.RecvSignal(&conn, &is_error);
if (!rc.OK()) {
LOG(WARNING) << rc.Report();
return;
}
if (!is_error) {
return; // shutdown
}
LOG(WARNING) << "Another worker is running into error.";
#if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE == 0
// exit is nicer than abort as the former performs cleanups.
@ -241,6 +292,9 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
LOG(FATAL) << "abort";
#endif
}};
// The worker thread is detached here to avoid the need to handle it later during
// destruction. For C++, if a thread is not joined or detached, it will segfault during
// destruction.
error_worker_.detach();
proto::Start start;
@ -253,7 +307,7 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
// get ring neighbors
std::string snext;
tracker.Recv(&snext);
rc = tracker.Recv(&snext);
if (!rc.OK()) {
return Fail("Failed to receive the rank for the next worker.", std::move(rc));
}
@ -273,14 +327,21 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
CHECK(this->channels_.empty());
for (auto& w : workers) {
if (w) {
rc = std::move(rc) << [&] { return w->SetNoDelay(); } << [&] { return w->NonBlocking(true); }
<< [&] { return w->SetKeepAlive(); };
rc = std::move(rc) << [&] {
return w->SetNoDelay();
} << [&] {
return w->NonBlocking(true);
} << [&] {
return w->SetKeepAlive();
};
}
if (!rc.OK()) {
return rc;
}
this->channels_.emplace_back(std::make_shared<Channel>(*this, w));
}
LOG(CONSOLE) << InitLog(task_id_, rank_);
return rc;
}
@ -288,6 +349,8 @@ RabitComm::~RabitComm() noexcept(false) {
if (!this->IsDistributed()) {
return;
}
LOG(WARNING) << "The communicator is being destroyed without a call to shutdown first. This can "
"lead to undefined behaviour.";
auto rc = this->Shutdown();
if (!rc.OK()) {
LOG(WARNING) << rc.Report();
@ -295,24 +358,52 @@ RabitComm::~RabitComm() noexcept(false) {
}
[[nodiscard]] Result RabitComm::Shutdown() {
if (!this->IsDistributed()) {
return Success();
}
// Tell the tracker that this worker is shutting down.
TCPSocket tracker;
// Tell the error handling thread that we are shutting down.
TCPSocket err_client;
return Success() << [&] {
return ConnectTrackerImpl(tracker_, timeout_, retry_, task_id_, &tracker, Rank(), World());
} << [&] {
return this->Block();
} << [&] {
Json jcmd{Object{}};
jcmd["cmd"] = Integer{static_cast<std::int32_t>(proto::CMD::kShutdown)};
auto scmd = Json::Dump(jcmd);
auto n_bytes = tracker.Send(scmd);
if (n_bytes != scmd.size()) {
return Fail("Faled to send cmd.");
}
return proto::ShutdownCMD{}.Send(&tracker);
} << [&] {
this->channels_.clear();
return Success();
} << [&] {
// Use tracker address to determine whether we want to use IPv6.
auto taddr = MakeSockAddress(xgboost::StringView{this->tracker_.host}, this->tracker_.port);
// Shutdown the error handling thread. We signal the thread through socket,
// alternatively, we can get the native handle and use pthread_cancel. But using a
// socket seems to be clearer as we know what's happening.
auto const& addr = taddr.IsV4() ? SockAddrV4::Loopback().Addr() : SockAddrV6::Loopback().Addr();
// We use hardcoded 10 seconds and 1 retry here since we are just connecting to a
// local socket. For a normal OS, this should be enough time to schedule the
// connection.
auto rc = Connect(StringView{addr}, this->error_port_, 1,
std::min(std::chrono::seconds{10}, timeout_), &err_client);
this->ResetState();
if (!rc.OK()) {
return Fail("Failed to connect to the error socket.", std::move(rc));
}
return rc;
} << [&] {
// We put error thread shutdown at the end so that we have a better chance to finish
// the previous more important steps.
return proto::Error{}.SignalShutdown(&err_client);
};
}
[[nodiscard]] Result RabitComm::LogTracker(std::string msg) const {
if (!this->IsDistributed()) {
LOG(CONSOLE) << msg;
return Success();
}
TCPSocket out;
proto::Print print;
return Success() << [&] { return this->ConnectTracker(&out); }
@ -320,8 +411,11 @@ RabitComm::~RabitComm() noexcept(false) {
}
[[nodiscard]] Result RabitComm::SignalError(Result const& res) {
TCPSocket out;
return Success() << [&] { return this->ConnectTracker(&out); }
<< [&] { return proto::ErrorCMD{}.WorkerSend(&out, res); };
TCPSocket tracker;
return Success() << [&] {
return this->ConnectTracker(&tracker);
} << [&] {
return proto::ErrorCMD{}.WorkerSend(&tracker, res);
};
}
} // namespace xgboost::collective

View File

@ -27,7 +27,7 @@ Result GetUniqueId(Comm const& comm, std::shared_ptr<NcclStub> stub, std::shared
ncclUniqueId id;
if (comm.Rank() == kRootRank) {
auto rc = stub->GetUniqueId(&id);
CHECK(rc.OK()) << rc.Report();
SafeColl(rc);
}
auto rc = coll->Broadcast(
comm, common::Span{reinterpret_cast<std::int8_t*>(&id), sizeof(ncclUniqueId)}, kRootRank);
@ -90,9 +90,8 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> p
auto s_this_uuid = s_uuid.subspan(root.Rank() * kUuidLength, kUuidLength);
GetCudaUUID(s_this_uuid, ctx->Device());
auto rc = pimpl->Allgather(root, common::EraseType(s_uuid), s_this_uuid.size_bytes());
CHECK(rc.OK()) << rc.Report();
auto rc = pimpl->Allgather(root, common::EraseType(s_uuid));
SafeColl(rc);
std::vector<xgboost::common::Span<std::uint64_t, kUuidLength>> converted(root.World());
std::size_t j = 0;
@ -113,7 +112,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> p
[&] {
return this->stub_->CommInitRank(&nccl_comm_, root.World(), nccl_unique_id_, root.Rank());
};
CHECK(rc.OK()) << rc.Report();
SafeColl(rc);
for (std::int32_t r = 0; r < root.World(); ++r) {
this->channels_.emplace_back(
@ -124,7 +123,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> p
NCCLComm::~NCCLComm() {
if (nccl_comm_) {
auto rc = stub_->CommDestroy(nccl_comm_);
CHECK(rc.OK()) << rc.Report();
SafeColl(rc);
}
}
} // namespace xgboost::collective

View File

@ -53,6 +53,10 @@ class NCCLComm : public Comm {
auto rc = this->Stream().Sync(false);
return GetCUDAResult(rc);
}
// Shut down the host-side state of the NCCL communicator. The NCCL handle
// itself is released in the destructor (via CommDestroy), so only the
// bookkeeping is cleared here and this cannot fail.
[[nodiscard]] Result Shutdown() final {
  this->ResetState();
  return Success();
}
};
class NCCLChannel : public Channel {

View File

@ -1,10 +1,10 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#pragma once
#include <chrono> // for seconds
#include <cstddef> // for size_t
#include <cstdint> // for int32_t
#include <cstdint> // for int32_t, int64_t
#include <memory> // for shared_ptr
#include <string> // for string
#include <thread> // for thread
@ -14,13 +14,13 @@
#include "loop.h" // for Loop
#include "protocol.h" // for PeerInfo
#include "xgboost/collective/result.h" // for Result
#include "xgboost/collective/socket.h" // for TCPSocket
#include "xgboost/collective/socket.h" // for TCPSocket, GetHostName
#include "xgboost/context.h" // for Context
#include "xgboost/span.h" // for Span
namespace xgboost::collective {
inline constexpr std::int32_t DefaultTimeoutSec() { return 300; } // 5min
inline constexpr std::int64_t DefaultTimeoutSec() { return 300; } // 5min
inline constexpr std::int32_t DefaultRetry() { return 3; }
// indexing into the ring
@ -51,11 +51,25 @@ class Comm : public std::enable_shared_from_this<Comm> {
proto::PeerInfo tracker_;
SockDomain domain_{SockDomain::kV4};
std::thread error_worker_;
std::int32_t error_port_;
std::string task_id_;
std::vector<std::shared_ptr<Channel>> channels_;
std::shared_ptr<Loop> loop_{new Loop{std::chrono::seconds{
DefaultTimeoutSec()}}}; // fixme: require federated comm to have a timeout
std::shared_ptr<Loop> loop_{nullptr}; // fixme: require federated comm to have a timeout
void ResetState() {
this->world_ = -1;
this->rank_ = 0;
this->timeout_ = std::chrono::seconds{DefaultTimeoutSec()};
tracker_ = proto::PeerInfo{};
this->task_id_.clear();
channels_.clear();
loop_.reset();
}
public:
Comm() = default;
@ -75,10 +89,13 @@ class Comm : public std::enable_shared_from_this<Comm> {
[[nodiscard]] auto Retry() const { return retry_; }
[[nodiscard]] auto TaskID() const { return task_id_; }
[[nodiscard]] auto Rank() const { return rank_; }
[[nodiscard]] auto World() const { return IsDistributed() ? world_ : 1; }
[[nodiscard]] bool IsDistributed() const { return world_ != -1; }
void Submit(Loop::Op op) const { loop_->Submit(op); }
[[nodiscard]] auto Rank() const noexcept { return rank_; }
[[nodiscard]] auto World() const noexcept { return IsDistributed() ? world_ : 1; }
[[nodiscard]] bool IsDistributed() const noexcept { return world_ != -1; }
// Submit an I/O operation to the event loop for asynchronous execution.
// Valid only once the loop exists, i.e. in distributed mode.
void Submit(Loop::Op op) const {
  CHECK(loop_);
  loop_->Submit(op);
}
[[nodiscard]] virtual Result Block() const { return loop_->Block(); }
[[nodiscard]] virtual std::shared_ptr<Channel> Chan(std::int32_t rank) const {
@ -88,6 +105,14 @@ class Comm : public std::enable_shared_from_this<Comm> {
[[nodiscard]] virtual Result LogTracker(std::string msg) const = 0;
[[nodiscard]] virtual Result SignalError(Result const&) { return Success(); }
/**
* @brief Get a string ID for the current process.
*/
[[nodiscard]] virtual Result ProcessorName(std::string* out) const {
  // Default implementation: identify the process by its host name.
  return GetHostName(out);
}
[[nodiscard]] virtual Result Shutdown() = 0;
};
/**
@ -105,20 +130,20 @@ class RabitComm : public HostComm {
[[nodiscard]] Result Bootstrap(std::chrono::seconds timeout, std::int32_t retry,
std::string task_id);
[[nodiscard]] Result Shutdown();
public:
// bootstrapping construction.
RabitComm() = default;
// ctor for testing where environment is known.
RabitComm(std::string const& host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t retry, std::string task_id, StringView nccl_path);
RabitComm(std::string const& tracker_host, std::int32_t tracker_port,
std::chrono::seconds timeout, std::int32_t retry, std::string task_id,
StringView nccl_path);
~RabitComm() noexcept(false) override;
[[nodiscard]] bool IsFederated() const override { return false; }
[[nodiscard]] Result LogTracker(std::string msg) const override;
[[nodiscard]] Result SignalError(Result const&) override;
[[nodiscard]] Result Shutdown() final;
[[nodiscard]] Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const override;
};

View File

@ -1,22 +1,21 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#include "comm_group.h"
#include <algorithm> // for transform
#include <cctype> // for tolower
#include <chrono> // for seconds
#include <cstdint> // for int32_t
#include <iterator> // for back_inserter
#include <memory> // for shared_ptr, unique_ptr
#include <string> // for string
#include <vector> // for vector
#include "../common/json_utils.h" // for OptionalArg
#include "coll.h" // for Coll
#include "comm.h" // for Comm
#include "tracker.h" // for GetHostAddress
#include "xgboost/collective/result.h" // for Result
#include "xgboost/context.h" // for DeviceOrd
#include "xgboost/json.h" // for Json
#include "../common/json_utils.h" // for OptionalArg
#include "coll.h" // for Coll
#include "comm.h" // for Comm
#include "xgboost/context.h" // for DeviceOrd
#include "xgboost/json.h" // for Json
#if defined(XGBOOST_USE_FEDERATED)
#include "../../plugin/federated/federated_coll.h"
@ -65,6 +64,9 @@ CommGroup::CommGroup()
auto const& obj = get<Object const>(config);
auto it = obj.find(upper);
if (it != obj.cend() && obj.find(name) != obj.cend()) {
LOG(FATAL) << "Duplicated parameter:" << name;
}
if (it != obj.cend()) {
return OptionalArg<decltype(t)>(config, upper, dft);
} else {
@ -78,14 +80,14 @@ CommGroup::CommGroup()
auto task_id = get_param("dmlc_task_id", std::string{}, String{});
if (type == "rabit") {
auto host = get_param("dmlc_tracker_uri", std::string{}, String{});
auto port = get_param("dmlc_tracker_port", static_cast<std::int64_t>(0), Integer{});
auto tracker_host = get_param("dmlc_tracker_uri", std::string{}, String{});
auto tracker_port = get_param("dmlc_tracker_port", static_cast<std::int64_t>(0), Integer{});
auto nccl = get_param("dmlc_nccl_path", std::string{DefaultNcclName()}, String{});
auto ptr =
new CommGroup{std::shared_ptr<RabitComm>{new RabitComm{ // NOLINT
host, static_cast<std::int32_t>(port), std::chrono::seconds{timeout},
static_cast<std::int32_t>(retry), task_id, nccl}},
std::shared_ptr<Coll>(new Coll{})}; // NOLINT
auto ptr = new CommGroup{
std::shared_ptr<RabitComm>{new RabitComm{ // NOLINT
tracker_host, static_cast<std::int32_t>(tracker_port), std::chrono::seconds{timeout},
static_cast<std::int32_t>(retry), task_id, nccl}},
std::shared_ptr<Coll>(new Coll{})}; // NOLINT
return ptr;
} else if (type == "federated") {
#if defined(XGBOOST_USE_FEDERATED)
@ -117,6 +119,8 @@ void GlobalCommGroupInit(Json config) {
// Tear down the process-wide communicator group.
void GlobalCommGroupFinalize() {
  auto& sptr = GlobalCommGroup();
  auto rc = sptr->Finalize();
  // Release the global instance before checking the result, so the group is
  // cleared even when SafeColl aborts on a failed shutdown.
  sptr.reset();
  SafeColl(rc);
}
} // namespace xgboost::collective

View File

@ -9,7 +9,6 @@
#include "coll.h" // for Comm
#include "comm.h" // for Coll
#include "xgboost/collective/result.h" // for Result
#include "xgboost/collective/socket.h" // for GetHostName
namespace xgboost::collective {
/**
@ -31,19 +30,35 @@ class CommGroup {
public:
CommGroup();
[[nodiscard]] auto World() const { return comm_->World(); }
[[nodiscard]] auto Rank() const { return comm_->Rank(); }
[[nodiscard]] bool IsDistributed() const { return comm_->IsDistributed(); }
[[nodiscard]] auto World() const noexcept { return comm_->World(); }
[[nodiscard]] auto Rank() const noexcept { return comm_->Rank(); }
[[nodiscard]] bool IsDistributed() const noexcept { return comm_->IsDistributed(); }
// Shut down the communicators: GPU communicator first (when present), then
// the host communicator. The `<<` chaining runs the steps in order and, as
// used throughout this module, carries the first failure through as the
// returned result.
[[nodiscard]] Result Finalize() const {
  return Success() << [this] {
    if (gpu_comm_) {
      return gpu_comm_->Shutdown();
    }
    return Success();
  } << [&] {
    return comm_->Shutdown();
  };
}
[[nodiscard]] static CommGroup* Create(Json config);
[[nodiscard]] std::shared_ptr<Coll> Backend(DeviceOrd device) const;
/**
* @brief Decide the context to use for communication.
*
* @param ctx Global context, provides the CUDA stream and ordinal.
* @param device The device used by the data to be communicated.
*/
[[nodiscard]] Comm const& Ctx(Context const* ctx, DeviceOrd device) const;
[[nodiscard]] Result SignalError(Result const& res) { return comm_->SignalError(res); }
[[nodiscard]] Result ProcessorName(std::string* out) const {
auto rc = GetHostName(out);
return rc;
return this->comm_->ProcessorName(out);
}
};

View File

@ -32,7 +32,8 @@ class InMemoryHandler {
*
* This is used when the handler only needs to be initialized once with a known world size.
*/
explicit InMemoryHandler(std::size_t worldSize) : world_size_{worldSize} {}
explicit InMemoryHandler(std::int32_t worldSize)
: world_size_{static_cast<std::size_t>(worldSize)} {}
/**
* @brief Initialize the handler with the world size and rank.

View File

@ -18,9 +18,11 @@
#include "xgboost/logging.h" // for CHECK
namespace xgboost::collective {
Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
timer_.Start(__func__);
auto error = [this] { timer_.Stop(__func__); };
auto error = [this] {
timer_.Stop(__func__);
};
if (stop_) {
timer_.Stop(__func__);
@ -48,6 +50,9 @@ Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
poll.WatchWrite(*op.sock);
break;
}
case Op::kSleep: {
break;
}
default: {
error();
return Fail("Invalid socket operation.");
@ -59,12 +64,14 @@ Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
// poll, work on fds that are ready.
timer_.Start("poll");
auto rc = poll.Poll(timeout_);
timer_.Stop("poll");
if (!rc.OK()) {
error();
return rc;
if (!poll.fds.empty()) {
auto rc = poll.Poll(timeout_);
if (!rc.OK()) {
error();
return rc;
}
}
timer_.Stop("poll");
// we wouldn't be here if the queue is empty.
CHECK(!qcopy.empty());
@ -75,12 +82,20 @@ Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
qcopy.pop();
std::int32_t n_bytes_done{0};
CHECK(op.sock->NonBlocking());
if (!op.sock) {
CHECK(op.code == Op::kSleep);
} else {
CHECK(op.sock->NonBlocking());
}
switch (op.code) {
case Op::kRead: {
if (poll.CheckRead(*op.sock)) {
n_bytes_done = op.sock->Recv(op.ptr + op.off, op.n - op.off);
if (n_bytes_done == 0) {
error();
return Fail("Encountered EOF. The other end is likely closed.");
}
}
break;
}
@ -90,6 +105,12 @@ Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
}
break;
}
case Op::kSleep: {
// For testing only.
std::this_thread::sleep_for(std::chrono::seconds{op.n});
n_bytes_done = op.n;
break;
}
default: {
error();
return Fail("Invalid socket operation.");
@ -110,6 +131,10 @@ Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
qcopy.push(op);
}
}
if (!blocking) {
break;
}
}
timer_.Stop(__func__);
@ -128,6 +153,15 @@ void Loop::Process() {
while (true) {
try {
std::unique_lock lock{mu_};
// This can handle missed notification: wait(lock, predicate) is equivalent to:
//
// while (!predicate()) {
// cv.wait(lock);
// }
//
// As a result, if there's a missed notification, the queue wouldn't be empty, hence
// the predicate would be false and the actual wait wouldn't be invoked. Therefore,
// the blocking call can never go unanswered.
cv_.wait(lock, [this] { return !this->queue_.empty() || stop_; });
if (stop_) {
break; // only point where this loop can exit.
@ -142,26 +176,27 @@ void Loop::Process() {
queue_.pop();
if (op.code == Op::kBlock) {
is_blocking = true;
// Block must be the last op in the current batch since no further submit can be
// issued until the blocking call is finished.
CHECK(queue_.empty());
} else {
qcopy.push(op);
}
}
if (!is_blocking) {
// Unblock, we can write to the global queue again.
lock.unlock();
lock.unlock();
// Clear the local queue, if `is_blocking` is true, this is blocking the current
// worker thread (but not the client thread), wait until all operations are
// finished.
auto rc = this->ProcessQueue(&qcopy, is_blocking);
if (is_blocking && rc.OK()) {
CHECK(qcopy.empty());
}
// Clear the local queue, this is blocking the current worker thread (but not the
// client thread), wait until all operations are finished.
auto rc = this->EmptyQueue(&qcopy);
if (is_blocking) {
// The unlock is delayed if this is a blocking call
lock.unlock();
// Push back the remaining operations.
if (rc.OK()) {
std::unique_lock lock{mu_};
while (!qcopy.empty()) {
queue_.push(qcopy.front());
qcopy.pop();
}
}
// Notify the client thread who called block after all error conditions are set.
@ -228,7 +263,6 @@ Result Loop::Stop() {
}
this->Submit(Op{Op::kBlock});
{
// Wait for the block call to finish.
std::unique_lock lock{mu_};
@ -243,8 +277,20 @@ Result Loop::Stop() {
}
}
// Enqueue an operation for the worker thread and wake it up.
void Loop::Submit(Op op) {
  {
    std::lock_guard guard{mu_};
    // Every data-carrying op must have a non-zero size; only the
    // synchronization op (kBlock) may carry none.
    if (op.code != Op::kBlock) {
      CHECK_NE(op.n, 0);
    }
    queue_.push(op);
  }
  // Notify after releasing the mutex so the woken worker does not
  // immediately block on it.
  cv_.notify_one();
}
Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} {
timer_.Init(__func__);
worker_ = std::thread{[this] { this->Process(); }};
worker_ = std::thread{[this] {
this->Process();
}};
}
} // namespace xgboost::collective

View File

@ -19,20 +19,27 @@ namespace xgboost::collective {
class Loop {
public:
struct Op {
enum Code : std::int8_t { kRead = 0, kWrite = 1, kBlock = 2 } code;
// kSleep is only for testing
enum Code : std::int8_t { kRead = 0, kWrite = 1, kBlock = 2, kSleep = 4 } code;
std::int32_t rank{-1};
std::int8_t* ptr{nullptr};
std::size_t n{0};
TCPSocket* sock{nullptr};
std::size_t off{0};
explicit Op(Code c) : code{c} { CHECK(c == kBlock); }
explicit Op(Code c) : code{c} { CHECK(c == kBlock || c == kSleep); }
Op(Code c, std::int32_t rank, std::int8_t* ptr, std::size_t n, TCPSocket* sock, std::size_t off)
: code{c}, rank{rank}, ptr{ptr}, n{n}, sock{sock}, off{off} {}
Op(Op const&) = default;
Op& operator=(Op const&) = default;
Op(Op&&) = default;
Op& operator=(Op&&) = default;
// For testing purpose only
[[nodiscard]] static Op Sleep(std::size_t seconds) {
Op op{kSleep};
op.n = seconds;
return op;
}
};
private:
@ -54,7 +61,7 @@ class Loop {
std::exception_ptr curr_exce_{nullptr};
common::Monitor mutable timer_;
Result EmptyQueue(std::queue<Op>* p_queue) const;
Result ProcessQueue(std::queue<Op>* p_queue, bool blocking) const;
// The consumer function that runs inside a worker thread.
void Process();
@ -64,12 +71,7 @@ class Loop {
*/
Result Stop();
void Submit(Op op) {
std::unique_lock lock{mu_};
queue_.push(op);
lock.unlock();
cv_.notify_one();
}
void Submit(Op op);
/**
* @brief Block the event loop until all ops are finished. In the case of failure, this

View File

@ -2,6 +2,8 @@
* Copyright 2023 XGBoost contributors
*/
#if defined(XGBOOST_USE_NCCL) || defined(XGBOOST_USE_RCCL)
#include <numeric> // for accumulate
#include "comm.cuh"
#include "nccl_device_communicator.cuh"

View File

@ -1,5 +1,5 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#pragma once
#include <cstdint> // for int32_t
@ -58,6 +58,7 @@ struct Magic {
}
};
// Basic commands for communication between workers and the tracker.
enum class CMD : std::int32_t {
kInvalid = 0,
kStart = 1,
@ -84,7 +85,10 @@ struct Connect {
[[nodiscard]] Result TrackerRecv(TCPSocket* sock, std::int32_t* world, std::int32_t* rank,
std::string* task_id) const {
std::string init;
sock->Recv(&init);
auto rc = sock->Recv(&init);
if (!rc.OK()) {
return Fail("Connect protocol failed.", std::move(rc));
}
auto jinit = Json::Load(StringView{init});
*world = get<Integer const>(jinit["world_size"]);
*rank = get<Integer const>(jinit["rank"]);
@ -122,9 +126,9 @@ class Start {
}
[[nodiscard]] Result WorkerRecv(TCPSocket* tracker, std::int32_t* p_world) const {
std::string scmd;
auto n_bytes = tracker->Recv(&scmd);
if (n_bytes <= 0) {
return Fail("Failed to recv init command from tracker.");
auto rc = tracker->Recv(&scmd);
if (!rc.OK()) {
return Fail("Failed to recv init command from tracker.", std::move(rc));
}
auto jcmd = Json::Load(scmd);
auto world = get<Integer const>(jcmd["world_size"]);
@ -132,7 +136,7 @@ class Start {
return Fail("Invalid world size.");
}
*p_world = world;
return Success();
return rc;
}
[[nodiscard]] Result TrackerHandle(Json jcmd, std::int32_t* recv_world, std::int32_t world,
std::int32_t* p_port, TCPSocket* p_sock,
@ -150,6 +154,7 @@ class Start {
}
};
// Protocol for communicating with the tracker for printing message.
struct Print {
[[nodiscard]] Result WorkerSend(TCPSocket* tracker, std::string msg) const {
Json jcmd{Object{}};
@ -172,6 +177,7 @@ struct Print {
}
};
// Protocol for communicating with the tracker during error.
struct ErrorCMD {
[[nodiscard]] Result WorkerSend(TCPSocket* tracker, Result const& res) const {
auto msg = res.Report();
@ -199,6 +205,7 @@ struct ErrorCMD {
}
};
// Protocol for communicating with the tracker during shutdown.
struct ShutdownCMD {
[[nodiscard]] Result Send(TCPSocket* peer) const {
Json jcmd{Object{}};
@ -211,4 +218,40 @@ struct ShutdownCMD {
return Success();
}
};
// Protocol for communicating with the local error handler during error or shutdown. Only
// one protocol that doesn't have the tracker involved.
struct Error {
  // Sentinel payloads exchanged over the local error socket.
  constexpr static std::int32_t ShutdownSignal() { return 0; }
  constexpr static std::int32_t ErrorSignal() { return -1; }

  // Notify a worker's error-handling thread that an error has occurred.
  [[nodiscard]] Result SignalError(TCPSocket* worker) const {
    std::int32_t err{ErrorSignal()};
    auto n_sent = worker->SendAll(&err, sizeof(err));
    if (n_sent == sizeof(err)) {
      return Success();
    }
    return Fail("Failed to send error signal");
  }
  // self is localhost, we are sending the signal to the error handling thread for it to
  // close.
  [[nodiscard]] Result SignalShutdown(TCPSocket* self) const {
    std::int32_t err{ShutdownSignal()};
    auto n_sent = self->SendAll(&err, sizeof(err));
    if (n_sent == sizeof(err)) {
      return Success();
    }
    return Fail("Failed to send shutdown signal");
  }
  // get signal, either for error or for shutdown.
  [[nodiscard]] Result RecvSignal(TCPSocket* peer, bool* p_is_error) const {
    std::int32_t err{ShutdownSignal()};
    auto n_recv = peer->RecvAll(&err, sizeof(err));
    if (n_recv == sizeof(err)) {
      // Compare against the sentinel SignalError() actually sends (-1). The
      // previous comparison against the literal 1 could never match, so every
      // error signal was misinterpreted as a normal shutdown.
      *p_is_error = (err == ErrorSignal());
      return Success();
    }
    return Fail("Failed to receive error signal.");
  }
};
} // namespace xgboost::collective::proto

86
src/collective/result.cc Normal file
View File

@ -0,0 +1,86 @@
/**
* Copyright 2024, XGBoost Contributors
*/
#include "xgboost/collective/result.h"
#include <filesystem> // for path
#include <sstream> // for stringstream
#include <stack> // for stack
#include "xgboost/logging.h"
namespace xgboost::collective {
namespace detail {
// Render the error chain as a human-readable, multi-line report, starting
// from the most recent (head) message.
[[nodiscard]] std::string ResultImpl::Report() const {
  std::stringstream ss;
  // The head node carries the latest context and labels its system error.
  ss << "\n- " << this->message;
  if (this->errc != std::error_code{}) {
    ss << " system error:" << this->errc.message();
  }
  // Older causes follow, one bullet per node.
  for (auto ptr = prev.get(); ptr != nullptr; ptr = ptr->prev.get()) {
    ss << "\n- " << ptr->message;
    if (ptr->errc != std::error_code{}) {
      ss << " " << ptr->errc.message();
    }
  }
  return ss.str();
}
// Return the error code attached closest to the root (oldest) node of the
// chain, or a default-constructed code when no node carries one.
[[nodiscard]] std::error_code ResultImpl::Code() const {
  std::error_code root_errc{};
  // Walking head -> root in a single pass: the last non-empty code seen is
  // the one nearest the root, which is what we want to surface.
  for (auto ptr = this; ptr != nullptr; ptr = ptr->prev.get()) {
    if (ptr->errc != std::error_code{}) {
      root_errc = ptr->errc;
    }
  }
  return root_errc;
}
// Append `rhs` (and any chain hanging off it) to the tail of this chain.
void ResultImpl::Concat(std::unique_ptr<ResultImpl> rhs) {
  auto* tail = this;
  for (; tail->prev != nullptr; tail = tail->prev.get()) {
  }
  tail->prev = std::move(rhs);
}
#if (!defined(__GNUC__) && !defined(__clang__)) || defined(__MINGW32__)
// Fallback for compilers where the message is not decorated with the
// source location.
std::string MakeMsg(std::string&& msg, char const*, std::int32_t) {
  return std::forward<std::string>(msg);
}
#else
// Prefix `msg` with "[file:line]: " when a valid source location is given.
//
// @param msg  The diagnostic message; consumed.
// @param file Path of the originating source file; may be null.
// @param line Line number; -1 when unknown.
std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line) {
  if (file && line != -1) {
    // Build the path only after the null check: constructing a
    // std::filesystem::path from a null char pointer is undefined behaviour,
    // and the previous code did so before testing `file`.
    auto name = std::filesystem::path{file}.filename();
    return "[" + name.string() + ":" + std::to_string(line) +  // NOLINT
           "]: " + std::forward<std::string>(msg);
  }
  return std::forward<std::string>(msg);
}
#endif
} // namespace detail
// Abort the process (via LOG(FATAL)) when a collective call has failed;
// a no-op for a successful result.
void SafeColl(Result const& rc) {
  if (rc.OK()) {
    return;
  }
  LOG(FATAL) << rc.Report();
}
} // namespace xgboost::collective

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022-2023 by XGBoost Contributors
* Copyright 2022-2024, XGBoost Contributors
*/
#include "xgboost/collective/socket.h"
@ -8,7 +8,8 @@
#include <cstdint> // std::int32_t
#include <cstring> // std::memcpy, std::memset
#include <filesystem> // for path
#include <system_error> // std::error_code, std::system_category
#include <system_error> // for error_code, system_category
#include <thread> // for sleep_for
#include "rabit/internal/socket.h" // for PollHelper
#include "xgboost/collective/result.h" // for Result
@ -65,14 +66,18 @@ std::size_t TCPSocket::Send(StringView str) {
return bytes;
}
std::size_t TCPSocket::Recv(std::string *p_str) {
[[nodiscard]] Result TCPSocket::Recv(std::string *p_str) {
CHECK(!this->IsClosed());
std::int32_t len;
CHECK_EQ(this->RecvAll(&len, sizeof(len)), sizeof(len)) << "Failed to recv string length.";
if (this->RecvAll(&len, sizeof(len)) != sizeof(len)) {
return Fail("Failed to recv string length.");
}
p_str->resize(len);
auto bytes = this->RecvAll(&(*p_str)[0], len);
CHECK_EQ(bytes, len) << "Failed to recv string.";
return bytes;
if (static_cast<decltype(len)>(bytes) != len) {
return Fail("Failed to recv string.");
}
return Success();
}
[[nodiscard]] Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry,
@ -110,11 +115,7 @@ std::size_t TCPSocket::Recv(std::string *p_str) {
for (std::int32_t attempt = 0; attempt < std::max(retry, 1); ++attempt) {
if (attempt > 0) {
LOG(WARNING) << "Retrying connection to " << host << " for the " << attempt << " time.";
#if defined(_MSC_VER) || defined(__MINGW32__)
Sleep(attempt << 1);
#else
sleep(attempt << 1);
#endif
std::this_thread::sleep_for(std::chrono::seconds{attempt << 1});
}
auto rc = connect(conn.Handle(), addr_handle, addr_len);
@ -158,8 +159,8 @@ std::size_t TCPSocket::Recv(std::string *p_str) {
std::stringstream ss;
ss << "Failed to connect to " << host << ":" << port;
conn.Close();
return Fail(ss.str(), std::move(last_error));
auto close_rc = conn.Close();
return Fail(ss.str(), std::move(close_rc) + std::move(last_error));
}
[[nodiscard]] Result GetHostName(std::string *p_out) {

View File

@ -1,6 +1,7 @@
/**
* Copyright 2023-2024, XGBoost Contributors
*/
#include "rabit/internal/socket.h"
#if defined(__unix__) || defined(__APPLE__)
#include <netdb.h> // gethostbyname
#include <sys/socket.h> // socket, AF_INET6, AF_INET, connect, getsockname
@ -70,10 +71,13 @@ RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockA
return proto::Connect{}.TrackerRecv(&sock_, &world_, &rank, &task_id_);
} << [&] {
std::string cmd;
sock_.Recv(&cmd);
auto rc = sock_.Recv(&cmd);
if (!rc.OK()) {
return rc;
}
jcmd = Json::Load(StringView{cmd});
cmd_ = static_cast<proto::CMD>(get<Integer const>(jcmd["cmd"]));
return Success();
return rc;
} << [&] {
if (cmd_ == proto::CMD::kStart) {
proto::Start start;
@ -100,14 +104,18 @@ RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockA
RabitTracker::RabitTracker(Json const& config) : Tracker{config} {
std::string self;
auto rc = collective::GetHostAddress(&self);
host_ = OptionalArg<String>(config, "host", self);
auto rc = Success() << [&] {
return collective::GetHostAddress(&self);
} << [&] {
host_ = OptionalArg<String>(config, "host", self);
auto addr = MakeSockAddress(xgboost::StringView{host_}, 0);
listener_ = TCPSocket::Create(addr.IsV4() ? SockDomain::kV4 : SockDomain::kV6);
rc = listener_.Bind(host_, &this->port_);
auto addr = MakeSockAddress(xgboost::StringView{host_}, 0);
listener_ = TCPSocket::Create(addr.IsV4() ? SockDomain::kV4 : SockDomain::kV6);
return listener_.Bind(host_, &this->port_);
} << [&] {
return listener_.Listen();
};
SafeColl(rc);
listener_.Listen();
}
Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
@ -220,9 +228,13 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
//
// retry is set to 1, just let the worker timeout or error. Otherwise the
// tracker and the worker might be waiting for each other.
auto rc = Connect(w.first, w.second, 1, timeout_, &out);
auto rc = Success() << [&] {
return Connect(w.first, w.second, 1, timeout_, &out);
} << [&] {
return proto::Error{}.SignalError(&out);
};
if (!rc.OK()) {
return Fail("Failed to inform workers to stop.");
return Fail("Failed to inform worker:" + w.first + " for error.", std::move(rc));
}
}
return Success();
@ -231,13 +243,37 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
return std::async(std::launch::async, [this, handle_error] {
State state{this->n_workers_};
auto select_accept = [&](TCPSocket* sock, auto* addr) {
// accept with poll so that we can enable timeout and interruption.
rabit::utils::PollHelper poll;
auto rc = Success() << [&] {
std::lock_guard lock{listener_mu_};
return listener_.NonBlocking(true);
} << [&] {
std::lock_guard lock{listener_mu_};
poll.WatchRead(listener_);
if (state.running) {
// Don't timeout if the communicator group is up and running.
return poll.Poll(std::chrono::seconds{-1});
} else {
// Have timeout for workers to bootstrap.
return poll.Poll(timeout_);
}
} << [&] {
// this->Stop() closes the socket with a lock. Therefore, when the accept returns
// due to shutdown, the state is still valid (closed).
return listener_.Accept(sock, addr);
};
return rc;
};
while (state.ShouldContinue()) {
TCPSocket sock;
SockAddress addr;
this->ready_ = true;
auto rc = listener_.Accept(&sock, &addr);
auto rc = select_accept(&sock, &addr);
if (!rc.OK()) {
return Fail("Failed to accept connection.", std::move(rc));
return Fail("Failed to accept connection.", this->Stop() + std::move(rc));
}
auto worker = WorkerProxy{n_workers_, std::move(sock), std::move(addr)};
@ -252,7 +288,7 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
state.Error();
rc = handle_error(worker);
if (!rc.OK()) {
return Fail("Failed to handle abort.", std::move(rc));
return Fail("Failed to handle abort.", this->Stop() + std::move(rc));
}
}
@ -262,7 +298,7 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
state.Bootstrap();
}
if (!rc.OK()) {
return rc;
return this->Stop() + std::move(rc);
}
continue;
}
@ -289,12 +325,11 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
}
case proto::CMD::kInvalid:
default: {
return Fail("Invalid command received.");
return Fail("Invalid command received.", this->Stop());
}
}
}
ready_ = false;
return Success();
return this->Stop();
});
}
@ -303,11 +338,30 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
SafeColl(rc);
Json args{Object{}};
args["DMLC_TRACKER_URI"] = String{host_};
args["DMLC_TRACKER_PORT"] = this->Port();
args["dmlc_tracker_uri"] = String{host_};
args["dmlc_tracker_port"] = this->Port();
return args;
}
[[nodiscard]] Result RabitTracker::Stop() {
if (!this->Ready()) {
return Success();
}
ready_ = false;
std::lock_guard lock{listener_mu_};
if (this->listener_.IsClosed()) {
return Success();
}
return Success() << [&] {
// This should have the effect of stopping the `accept` call.
return this->listener_.Shutdown();
} << [&] {
return listener_.Close();
};
}
[[nodiscard]] Result GetHostAddress(std::string* out) {
auto rc = GetHostName(out);
if (!rc.OK()) {

View File

@ -36,15 +36,18 @@ namespace xgboost::collective {
* signal an error to the tracker and the tracker will notify other workers.
*/
class Tracker {
public:
enum class SortBy : std::int8_t {
kHost = 0,
kTask = 1,
};
protected:
// How to sort the workers, either by host name or by task ID. When using a multi-GPU
// setting, multiple workers can occupy the same host, in which case one should sort
// workers by task. Due to compatibility reason, the task ID is not always available, so
// we use host as the default.
enum class SortBy : std::int8_t {
kHost = 0,
kTask = 1,
} sortby_;
SortBy sortby_;
protected:
std::int32_t n_workers_{0};
@ -54,10 +57,7 @@ class Tracker {
public:
explicit Tracker(Json const& config);
Tracker(std::int32_t n_worders, std::int32_t port, std::chrono::seconds timeout)
: n_workers_{n_worders}, port_{port}, timeout_{timeout} {}
virtual ~Tracker() noexcept(false){}; // NOLINT
virtual ~Tracker() = default;
[[nodiscard]] Result WaitUntilReady() const;
@ -69,6 +69,11 @@ class Tracker {
* @brief Flag to indicate whether the server is running.
*/
[[nodiscard]] bool Ready() const { return ready_; }
/**
* @brief Shutdown the tracker, cannot be restarted again. Useful when the tracker hangs while
* calling accept.
*/
virtual Result Stop() { return Success(); }
};
class RabitTracker : public Tracker {
@ -127,28 +132,22 @@ class RabitTracker : public Tracker {
// record for how to reach out to workers if error happens.
std::vector<std::pair<std::string, std::int32_t>> worker_error_handles_;
// listening socket for incoming workers.
//
// At the moment, the listener calls accept without first polling. We can add an
// additional unix domain socket to allow cancelling the accept.
TCPSocket listener_;
// mutex for protecting the listener, used to prevent race when it's listening while
// another thread tries to shut it down.
std::mutex listener_mu_;
Result Bootstrap(std::vector<WorkerProxy>* p_workers);
public:
explicit RabitTracker(StringView host, std::int32_t n_worders, std::int32_t port,
std::chrono::seconds timeout)
: Tracker{n_worders, port, timeout}, host_{host.c_str(), host.size()} {
listener_ = TCPSocket::Create(SockDomain::kV4);
auto rc = listener_.Bind(host, &this->port_);
CHECK(rc.OK()) << rc.Report();
listener_.Listen();
}
explicit RabitTracker(Json const& config);
~RabitTracker() noexcept(false) override = default;
~RabitTracker() override = default;
std::future<Result> Run() override;
[[nodiscard]] Json WorkerArgs() const override;
// Stop the tracker without waiting. This is to prevent the tracker from hanging when
// one of the workers failes to start.
[[nodiscard]] Result Stop() override;
};
// Prob the public IP address of the host, need a better method.

View File

@ -14,7 +14,6 @@
#include <thrust/iterator/transform_output_iterator.h> // make_transform_output_iterator
#include <thrust/logical.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
#include <thrust/system/cuda/error.h>
#include <thrust/system_error.h>
#include <thrust/transform_scan.h>
@ -301,21 +300,22 @@ class MemoryLogger {
void RegisterAllocation(void *ptr, size_t n) {
device_allocations[ptr] = n;
currently_allocated_bytes += n;
peak_allocated_bytes =
std::max(peak_allocated_bytes, currently_allocated_bytes);
peak_allocated_bytes = std::max(peak_allocated_bytes, currently_allocated_bytes);
num_allocations++;
CHECK_GT(num_allocations, num_deallocations);
}
void RegisterDeallocation(void *ptr, size_t n, int current_device) {
auto itr = device_allocations.find(ptr);
if (itr == device_allocations.end()) {
LOG(WARNING) << "Attempting to deallocate " << n << " bytes on device "
<< current_device << " that was never allocated ";
LOG(WARNING) << "Attempting to deallocate " << n << " bytes on device " << current_device
<< " that was never allocated\n"
<< dmlc::StackTrace();
} else {
num_deallocations++;
CHECK_LE(num_deallocations, num_allocations);
currently_allocated_bytes -= itr->second;
device_allocations.erase(itr);
}
num_deallocations++;
CHECK_LE(num_deallocations, num_allocations);
currently_allocated_bytes -= itr->second;
device_allocations.erase(itr);
}
};
DeviceStats stats_;

View File

@ -11,7 +11,7 @@
#include "xgboost/logging.h"
namespace xgboost::error {
std::string DeprecatedFunc(StringView old, StringView since, StringView replacement) {
[[nodiscard]] std::string DeprecatedFunc(StringView old, StringView since, StringView replacement) {
std::stringstream ss;
ss << "`" << old << "` is deprecated since" << since << ", use `" << replacement << "` instead.";
return ss.str();

View File

@ -89,7 +89,7 @@ void WarnDeprecatedGPUId();
void WarnEmptyDataset();
std::string DeprecatedFunc(StringView old, StringView since, StringView replacement);
[[nodiscard]] std::string DeprecatedFunc(StringView old, StringView since, StringView replacement);
constexpr StringView InvalidCUDAOrdinal() {
return "Invalid device. `device` is required to be CUDA and there must be at least one GPU "

View File

@ -8,6 +8,7 @@
#define COMMON_HIST_UTIL_CUH_
#include <thrust/host_vector.h>
#include <thrust/sort.h> // for sort
#include <cstddef> // for size_t

View File

@ -6,7 +6,6 @@
#include <algorithm>
#include <cstdint>
#include <mutex>
#include "xgboost/data.h"
#include "xgboost/host_device_vector.h"

View File

@ -4,6 +4,7 @@
#include "quantile.h"
#include <limits>
#include <numeric> // for partial_sum
#include <utility>
#include "../collective/aggregator.h"

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2023 by XGBoost Contributors
* Copyright 2020-2024, XGBoost Contributors
*/
#include <thrust/binary_search.h>
#include <thrust/execution_policy.h>
@ -8,8 +8,8 @@
#include <thrust/transform_scan.h>
#include <thrust/unique.h>
#include <limits> // std::numeric_limits
#include <memory>
#include <limits> // std::numeric_limits
#include <numeric> // for partial_sum
#include <utility>
#include "../collective/communicator-inl.cuh"

View File

@ -1,8 +1,9 @@
/**
* Copyright 2020-2024, XGBoost Contributors
*/
#ifndef XGBOOST_COMMON_QUANTILE_CUH_
#define XGBOOST_COMMON_QUANTILE_CUH_
#include <memory>
#include "xgboost/span.h"
#include "xgboost/data.h"
#include "device_helpers.cuh"

View File

@ -1,9 +1,8 @@
/*!
* Copyright by Contributors 2019
/**
* Copyright 2019-2024, XGBoost Contributors
*/
#include "timer.h"
#include <sstream>
#include <utility>
#include "../collective/communicator-inl.h"
@ -61,6 +60,9 @@ void Monitor::Print() const {
kv.second.timer.elapsed)
.count());
}
if (stat_map.empty()) {
return;
}
LOG(CONSOLE) << "======== Monitor (" << rank << "): " << label_ << " ========";
this->PrintStatistics(stat_map);
}

View File

@ -11,7 +11,6 @@
#include <cmath> // for abs
#include <cstdint> // for uint64_t, int32_t, uint8_t, uint32_t
#include <cstring> // for size_t, strcmp, memcpy
#include <exception> // for exception
#include <iostream> // for operator<<, basic_ostream, basic_ostream::op...
#include <map> // for map, operator!=
#include <numeric> // for accumulate, partial_sum
@ -22,7 +21,6 @@
#include "../collective/communicator.h" // for Operation
#include "../common/algorithm.h" // for StableSort
#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry
#include "../common/common.h" // for Split
#include "../common/error_msg.h" // for GroupSize, GroupWeight, InfInData
#include "../common/group_data.h" // for ParallelGroupBuilder
#include "../common/io.h" // for PeekableInStream
@ -473,11 +471,11 @@ void MetaInfo::SetInfo(Context const& ctx, StringView key, StringView interface_
<< ", must have at least 1 column even if it's empty.";
auto const& first = get<Object const>(array.front());
auto ptr = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(first);
is_cuda = ArrayInterfaceHandler::IsCudaPtr(ptr);
is_cuda = first.find("stream") != first.cend() || ArrayInterfaceHandler::IsCudaPtr(ptr);
} else {
auto const& first = get<Object const>(j_interface);
auto ptr = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(first);
is_cuda = ArrayInterfaceHandler::IsCudaPtr(ptr);
is_cuda = first.find("stream") != first.cend() || ArrayInterfaceHandler::IsCudaPtr(ptr);
}
if (is_cuda) {
@ -567,46 +565,6 @@ void MetaInfo::SetInfoFromHost(Context const& ctx, StringView key, Json arr) {
}
}
void MetaInfo::SetInfo(Context const& ctx, const char* key, const void* dptr, DataType dtype,
size_t num) {
CHECK(key);
auto proc = [&](auto cast_d_ptr) {
using T = std::remove_pointer_t<decltype(cast_d_ptr)>;
auto t = linalg::TensorView<T, 1>(common::Span<T>{cast_d_ptr, num}, {num}, DeviceOrd::CPU());
CHECK(t.CContiguous());
Json interface {
linalg::ArrayInterface(t)
};
assert(ArrayInterface<1>{interface}.is_contiguous);
return interface;
};
// Legacy code using XGBoost dtype, which is a small subset of array interface types.
switch (dtype) {
case xgboost::DataType::kFloat32: {
auto cast_ptr = reinterpret_cast<const float*>(dptr);
this->SetInfoFromHost(ctx, key, proc(cast_ptr));
break;
}
case xgboost::DataType::kDouble: {
auto cast_ptr = reinterpret_cast<const double*>(dptr);
this->SetInfoFromHost(ctx, key, proc(cast_ptr));
break;
}
case xgboost::DataType::kUInt32: {
auto cast_ptr = reinterpret_cast<const uint32_t*>(dptr);
this->SetInfoFromHost(ctx, key, proc(cast_ptr));
break;
}
case xgboost::DataType::kUInt64: {
auto cast_ptr = reinterpret_cast<const uint64_t*>(dptr);
this->SetInfoFromHost(ctx, key, proc(cast_ptr));
break;
}
default:
LOG(FATAL) << "Unknown data type" << static_cast<uint8_t>(dtype);
}
}
void MetaInfo::GetInfo(char const* key, bst_ulong* out_len, DataType dtype,
const void** out_dptr) const {
if (dtype == DataType::kFloat32) {

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021-2023, XGBoost contributors
* Copyright 2021-2024, XGBoost contributors
*/
#include "file_iterator.h"
@ -10,7 +10,10 @@
#include <ostream> // for operator<<, basic_ostream, istringstream
#include <vector> // for vector
#include "../common/common.h" // for Split
#include "../common/common.h" // for Split
#include "xgboost/linalg.h" // for ArrayInterfaceStr, MakeVec
#include "xgboost/linalg.h"
#include "xgboost/logging.h" // for CHECK
#include "xgboost/string_view.h" // for operator<<, StringView
namespace xgboost::data {
@ -28,10 +31,10 @@ std::string ValidateFileFormat(std::string const& uri) {
for (size_t i = 0; i < arg_list.size(); ++i) {
std::istringstream is(arg_list[i]);
std::pair<std::string, std::string> kv;
CHECK(std::getline(is, kv.first, '=')) << "Invalid uri argument format"
<< " for key in arg " << i + 1;
CHECK(std::getline(is, kv.second)) << "Invalid uri argument format"
<< " for value in arg " << i + 1;
CHECK(std::getline(is, kv.first, '='))
<< "Invalid uri argument format" << " for key in arg " << i + 1;
CHECK(std::getline(is, kv.second))
<< "Invalid uri argument format" << " for value in arg " << i + 1;
args.insert(kv);
}
if (args.find("format") == args.cend()) {
@ -48,4 +51,41 @@ std::string ValidateFileFormat(std::string const& uri) {
return name_args[0] + "?" + name_args[1] + '#' + name_args_cache[1];
}
}
int FileIterator::Next() {
CHECK(parser_);
if (parser_->Next()) {
row_block_ = parser_->Value();
indptr_ = linalg::Make1dInterface(row_block_.offset, row_block_.size + 1);
values_ = linalg::Make1dInterface(row_block_.value, row_block_.offset[row_block_.size]);
indices_ = linalg::Make1dInterface(row_block_.index, row_block_.offset[row_block_.size]);
size_t n_columns =
*std::max_element(row_block_.index, row_block_.index + row_block_.offset[row_block_.size]);
// dmlc parser converts 1-based indexing back to 0-based indexing so we can ignore
// this condition and just add 1 to n_columns
n_columns += 1;
XGProxyDMatrixSetDataCSR(proxy_, indptr_.c_str(), indices_.c_str(), values_.c_str(), n_columns);
if (row_block_.label) {
auto str = linalg::Make1dInterface(row_block_.label, row_block_.size);
XGDMatrixSetInfoFromInterface(proxy_, "label", str.c_str());
}
if (row_block_.qid) {
auto str = linalg::Make1dInterface(row_block_.qid, row_block_.size);
XGDMatrixSetInfoFromInterface(proxy_, "qid", str.c_str());
}
if (row_block_.weight) {
auto str = linalg::Make1dInterface(row_block_.weight, row_block_.size);
XGDMatrixSetInfoFromInterface(proxy_, "weight", str.c_str());
}
// Continue iteration
return true;
} else {
// Stop iteration
return false;
}
}
} // namespace xgboost::data

View File

@ -1,20 +1,16 @@
/**
* Copyright 2021-2023, XGBoost contributors
* Copyright 2021-2024, XGBoost contributors
*/
#ifndef XGBOOST_DATA_FILE_ITERATOR_H_
#define XGBOOST_DATA_FILE_ITERATOR_H_
#include <algorithm> // for max_element
#include <cstddef> // for size_t
#include <cstdint> // for uint32_t
#include <memory> // for unique_ptr
#include <string> // for string
#include <utility> // for move
#include "dmlc/data.h" // for RowBlock, Parser
#include "xgboost/c_api.h" // for XGDMatrixSetDenseInfo, XGDMatrixFree, XGProxyDMatrixCreate
#include "xgboost/linalg.h" // for ArrayInterfaceStr, MakeVec
#include "xgboost/logging.h" // for CHECK
#include "xgboost/c_api.h" // for XGDMatrixFree, XGProxyDMatrixCreate
namespace xgboost::data {
[[nodiscard]] std::string ValidateFileFormat(std::string const& uri);
@ -53,41 +49,7 @@ class FileIterator {
XGDMatrixFree(proxy_);
}
int Next() {
CHECK(parser_);
if (parser_->Next()) {
row_block_ = parser_->Value();
using linalg::MakeVec;
indptr_ = ArrayInterfaceStr(MakeVec(row_block_.offset, row_block_.size + 1));
values_ = ArrayInterfaceStr(MakeVec(row_block_.value, row_block_.offset[row_block_.size]));
indices_ = ArrayInterfaceStr(MakeVec(row_block_.index, row_block_.offset[row_block_.size]));
size_t n_columns = *std::max_element(row_block_.index,
row_block_.index + row_block_.offset[row_block_.size]);
// dmlc parser converts 1-based indexing back to 0-based indexing so we can ignore
// this condition and just add 1 to n_columns
n_columns += 1;
XGProxyDMatrixSetDataCSR(proxy_, indptr_.c_str(), indices_.c_str(),
values_.c_str(), n_columns);
if (row_block_.label) {
XGDMatrixSetDenseInfo(proxy_, "label", row_block_.label, row_block_.size, 1);
}
if (row_block_.qid) {
XGDMatrixSetDenseInfo(proxy_, "qid", row_block_.qid, row_block_.size, 1);
}
if (row_block_.weight) {
XGDMatrixSetDenseInfo(proxy_, "weight", row_block_.weight, row_block_.size, 1);
}
// Continue iteration
return true;
} else {
// Stop iteration
return false;
}
}
int Next();
auto Proxy() -> decltype(proxy_) { return proxy_; }

View File

@ -1,5 +1,5 @@
/**
* Copyright 2014-2023, XGBoost Contributors
* Copyright 2014-2024, XGBoost Contributors
* \file sparse_page_source.h
*/
#ifndef XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_
@ -7,23 +7,26 @@
#include <algorithm> // for min
#include <atomic> // for atomic
#include <cstdio> // for remove
#include <future> // for async
#include <map>
#include <memory>
#include <mutex> // for mutex
#include <string>
#include <thread>
#include <utility> // for pair, move
#include <vector>
#include <memory> // for unique_ptr
#include <mutex> // for mutex
#include <string> // for string
#include <utility> // for pair, move
#include <vector> // for vector
#include "../common/common.h"
#include "../common/io.h" // for PrivateMmapConstStream
#include "../common/timer.h" // for Monitor, Timer
#include "adapter.h"
#include "proxy_dmatrix.h" // for DMatrixProxy
#include "sparse_page_writer.h" // for SparsePageFormat
#include "xgboost/base.h"
#include "xgboost/data.h"
#if !defined(XGBOOST_USE_CUDA)
#include "../common/common.h" // for AssertGPUSupport
#endif // !defined(XGBOOST_USE_CUDA)
#include "../common/io.h" // for PrivateMmapConstStream
#include "../common/timer.h" // for Monitor, Timer
#include "proxy_dmatrix.h" // for DMatrixProxy
#include "sparse_page_writer.h" // for SparsePageFormat
#include "xgboost/base.h" // for bst_feature_t
#include "xgboost/data.h" // for SparsePage, CSCPage
#include "xgboost/global_config.h" // for GlobalConfigThreadLocalStore
#include "xgboost/logging.h" // for CHECK_EQ
namespace xgboost::data {
inline void TryDeleteCacheFile(const std::string& file) {
@ -185,6 +188,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
exce_.Rethrow();
auto const config = *GlobalConfigThreadLocalStore::Get();
for (std::int32_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) {
fetch_it %= n_batches_; // ring
if (ring_->at(fetch_it).valid()) {
@ -192,7 +196,8 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
}
auto const* self = this; // make sure it's const
CHECK_LT(fetch_it, cache_info_->offset.size());
ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self, this]() {
ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self, config, this]() {
*GlobalConfigThreadLocalStore::Get() = config;
auto page = std::make_shared<S>();
this->exce_.Run([&] {
std::unique_ptr<SparsePageFormat<S>> fmt{CreatePageFormat<S>("raw")};

View File

@ -1,5 +1,5 @@
/**
* Copyright 2014-2023 by Contributors
* Copyright 2014-2024, XGBoost Contributors
* \file gbtree.cc
* \brief gradient boosted tree implementation.
* \author Tianqi Chen
@ -11,14 +11,12 @@
#include <algorithm>
#include <cstdint> // std::int32_t
#include <map>
#include <memory>
#include <numeric> // for iota
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "../common/common.h"
#include "../common/timer.h"
#include "../tree/param.h" // TrainParam
#include "gbtree_model.h"

View File

@ -10,15 +10,15 @@
#include <array>
#include <cmath>
#include <numeric> // for accumulate
#include "../collective/communicator-inl.h"
#include "../common/common.h" // MetricNoCache
#include "../common/common.h" // for AssertGPUSupport
#include "../common/math.h"
#include "../common/optional_weight.h" // OptionalWeights
#include "../common/pseudo_huber.h"
#include "../common/quantile_loss_utils.h" // QuantileLossParam
#include "../common/threading_utils.h"
#include "metric_common.h"
#include "metric_common.h" // MetricNoCache
#include "xgboost/collective/result.h" // for SafeColl
#include "xgboost/metric.h"

View File

@ -9,8 +9,6 @@
#include <string>
#include "../collective/aggregator.h"
#include "../collective/communicator-inl.h"
#include "../common/common.h"
#include "xgboost/metric.h"
namespace xgboost {

View File

@ -9,8 +9,8 @@
#include <array>
#include <atomic>
#include <cmath>
#include <numeric> // for accumulate
#include "../collective/communicator-inl.h"
#include "../common/math.h"
#include "../common/threading_utils.h"
#include "metric_common.h" // MetricNoCache

View File

@ -9,10 +9,9 @@
#include <array>
#include <memory>
#include <numeric> // for accumulate
#include <vector>
#include "../collective/communicator-inl.h"
#include "../common/math.h"
#include "../common/survival_util.h"
#include "../common/threading_utils.h"
#include "metric_common.h" // MetricNoCache

View File

@ -1,13 +1,13 @@
/**
* Copyright 2019-2023 by XGBoost Contributors
* Copyright 2019-2024, XGBoost Contributors
*/
#include <thrust/functional.h>
#include <thrust/random.h>
#include <thrust/sort.h> // for sort
#include <thrust/transform.h>
#include <xgboost/host_device_vector.h>
#include <xgboost/logging.h>
#include <algorithm>
#include <cstddef> // for size_t
#include <limits>
#include <utility>

Some files were not shown because too many files have changed in this diff Show More