merge latest change from upstream
commit 8b75204fed
.github/dependabot.yml (vendored, 8 changes)

@@ -8,7 +8,7 @@ updates:
   - package-ecosystem: "maven"
     directory: "/jvm-packages"
     schedule:
-      interval: "daily"
+      interval: "monthly"
   - package-ecosystem: "maven"
    directory: "/jvm-packages/xgboost4j"
     schedule:
@@ -16,11 +16,11 @@ updates:
   - package-ecosystem: "maven"
     directory: "/jvm-packages/xgboost4j-gpu"
     schedule:
-      interval: "daily"
+      interval: "monthly"
   - package-ecosystem: "maven"
     directory: "/jvm-packages/xgboost4j-example"
     schedule:
-      interval: "daily"
+      interval: "monthly"
   - package-ecosystem: "maven"
     directory: "/jvm-packages/xgboost4j-spark"
     schedule:
@@ -28,4 +28,4 @@ updates:
   - package-ecosystem: "maven"
     directory: "/jvm-packages/xgboost4j-spark-gpu"
     schedule:
-      interval: "daily"
+      interval: "monthly"
.github/workflows/r_tests.yml (vendored, 8 changes)

@@ -110,7 +110,7 @@ jobs:
     name: Test R package on Debian
     runs-on: ubuntu-latest
     container:
-      image: rhub/debian-gcc-devel
+      image: rhub/debian-gcc-release

     steps:
       - name: Install system dependencies
@@ -130,12 +130,12 @@ jobs:
       - name: Install dependencies
         shell: bash -l {0}
         run: |
-          /tmp/R-devel/bin/Rscript -e "source('./R-package/tests/helper_scripts/install_deps.R')"
+          Rscript -e "source('./R-package/tests/helper_scripts/install_deps.R')"

       - name: Test R
         shell: bash -l {0}
         run: |
-          python3 tests/ci_build/test_r_package.py --r=/tmp/R-devel/bin/R --build-tool=autotools --task=check
+          python3 tests/ci_build/test_r_package.py --r=/usr/bin/R --build-tool=autotools --task=check

       - uses: dorny/paths-filter@v2
         id: changes
@@ -147,4 +147,4 @@ jobs:
       - name: Run document check
         if: steps.changes.outputs.r_package == 'true'
         run: |
-          python3 tests/ci_build/test_r_package.py --r=/tmp/R-devel/bin/R --task=doc
+          python3 tests/ci_build/test_r_package.py --r=/usr/bin/R --task=doc
.github/workflows/update_rapids.yml (vendored, 4 changes)

@@ -3,7 +3,7 @@ name: update-rapids
 on:
   workflow_dispatch:
   schedule:
-    - cron: "0 20 * * *"  # Run once daily
+    - cron: "0 20 * * 1"  # Run once weekly

 permissions:
   pull-requests: write
@@ -32,7 +32,7 @@ jobs:
         run: |
           bash tests/buildkite/update-rapids.sh
       - name: Create Pull Request
-        uses: peter-evans/create-pull-request@v5
+        uses: peter-evans/create-pull-request@v6
         if: github.ref == 'refs/heads/master'
         with:
           add-paths: |
NEWS.md (2 changes)

@@ -2101,7 +2101,7 @@ This release marks a major milestone for the XGBoost project.
 ## v0.90 (2019.05.18)

 ### XGBoost Python package drops Python 2.x (#4379, #4381)
-Python 2.x is reaching its end-of-life at the end of this year. [Many scientific Python packages are now moving to drop Python 2.x](https://python3statement.org/).
+Python 2.x is reaching its end-of-life at the end of this year. [Many scientific Python packages are now moving to drop Python 2.x](https://python3statement.github.io/).

 ### XGBoost4J-Spark now requires Spark 2.4.x (#4377)
 * Spark 2.3 is reaching its end-of-life soon. See discussion at #4389.
R-package/R/utils.R

@@ -26,6 +26,11 @@ NVL <- function(x, val) {
            'multi:softprob', 'rank:pairwise', 'rank:ndcg', 'rank:map'))
 }

+.RANKING_OBJECTIVES <- function() {
+  return(c('binary:logistic', 'binary:logitraw', 'binary:hinge', 'multi:softmax',
+           'multi:softprob'))
+}
+
 #
 # Low-level functions for boosting --------------------------------------------
@@ -235,33 +240,43 @@ convert.labels <- function(labels, objective_name) {
 }

 # Generates random (stratified if needed) CV folds
-generate.cv.folds <- function(nfold, nrows, stratified, label, params) {
+generate.cv.folds <- function(nfold, nrows, stratified, label, group, params) {
+  if (NROW(group)) {
+    if (stratified) {
+      warning(
+        paste0(
+          "Stratified splitting is not supported when using 'group' attribute.",
+          " Will use unstratified splitting."
+        )
+      )
+    }
+    return(generate.group.folds(nfold, group))
+  }
+  objective <- params$objective
+  if (!is.character(objective)) {
+    warning("Will use unstratified splitting (custom objective used)")
+    stratified <- FALSE
+  }
+  # cannot stratify if label is NULL
+  if (stratified && is.null(label)) {
+    warning("Will use unstratified splitting (no 'labels' available)")
+    stratified <- FALSE
+  }
+
   # cannot do it for rank
-  objective <- params$objective
   if (is.character(objective) && strtrim(objective, 5) == 'rank:') {
-    stop("\n\tAutomatic generation of CV-folds is not implemented for ranking!\n",
+    stop("\n\tAutomatic generation of CV-folds is not implemented for ranking without 'group' field!\n",
          "\tConsider providing pre-computed CV-folds through the 'folds=' parameter.\n")
   }
   # shuffle
   rnd_idx <- sample.int(nrows)
-  if (stratified &&
-      length(label) == length(rnd_idx)) {
+  if (stratified && length(label) == length(rnd_idx)) {
     y <- label[rnd_idx]
-    # WARNING: some heuristic logic is employed to identify classification setting!
     # - For classification, need to convert y labels to factor before making the folds,
     #   and then do stratification by factor levels.
     # - For regression, leave y numeric and do stratification by quantiles.
     if (is.character(objective)) {
-      y <- convert.labels(y, params$objective)
-    } else {
-      # If no 'objective' given in params, it means that user either wants to
-      # use the default 'reg:squarederror' objective or has provided a custom
-      # obj function. Here, assume classification setting when y has 5 or less
-      # unique values:
-      if (length(unique(y)) <= 5) {
-        y <- factor(y)
-      }
+      y <- convert.labels(y, objective)
     }
     folds <- xgb.createFolds(y = y, k = nfold)
   } else {
@@ -277,6 +292,29 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, params) {
   return(folds)
 }

+generate.group.folds <- function(nfold, group) {
+  ngroups <- length(group) - 1
+  if (ngroups < nfold) {
+    stop("DMatrix has fewer groups than folds.")
+  }
+  seq_groups <- seq_len(ngroups)
+  indices <- lapply(seq_groups, function(gr) seq(group[gr] + 1, group[gr + 1]))
+  assignments <- base::split(seq_groups, as.integer(seq_groups %% nfold))
+  assignments <- unname(assignments)
+
+  out <- vector("list", nfold)
+  randomized_groups <- sample(ngroups)
+  for (idx in seq_len(nfold)) {
+    groups_idx_test <- randomized_groups[assignments[[idx]]]
+    groups_test <- indices[groups_idx_test]
+    idx_test <- unlist(groups_test)
+    attributes(idx_test)$group_test <- lengths(groups_test)
+    attributes(idx_test)$group_train <- lengths(indices[-groups_idx_test])
+    out[[idx]] <- idx_test
+  }
+  return(out)
+}
+
 # Creates CV folds stratified by the values of y.
 # It was borrowed from caret::createFolds and simplified
 # by always returning an unnamed list of fold indices.
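As a quick illustration (not part of the diff): each group-based fold is a vector of row indices covering whole query groups, with attributes holding the group sizes needed to rebuild the train/test DMatrices. A minimal sketch, assuming this commit's internal helper (reachable via `:::`) and the cumulative group pointer as returned by `getinfo(., "group")`:

    library(xgboost)
    # Cumulative group pointer for 4 query groups of sizes 2, 3, 1, 2:
    group <- c(0, 2, 5, 6, 8)
    set.seed(42)
    folds <- xgboost:::generate.group.folds(nfold = 2, group = group)
    folds[[1]]                          # row indices of whole test groups
    attributes(folds[[1]])$group_test   # group sizes for the test DMatrix
    attributes(folds[[1]])$group_train  # group sizes for the train DMatrix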
R-package/R/xgb.DMatrix.R

@@ -1259,8 +1259,11 @@ xgb.get.DMatrix.data <- function(dmat) {
 #' Get a new DMatrix containing the specified rows of
 #' original xgb.DMatrix object
 #'
-#' @param object Object of class "xgb.DMatrix"
-#' @param idxset a integer vector of indices of rows needed
+#' @param object Object of class "xgb.DMatrix".
+#' @param idxset An integer vector of indices of rows needed (base-1 indexing).
+#' @param allow_groups Whether to allow slicing an `xgb.DMatrix` with `group` (or
+#'   equivalently `qid`) field. Note that in such case, the result will not have
+#'   the groups anymore - they need to be set manually through `setinfo`.
 #' @param colset currently not used (columns subsetting is not available)
 #'
 #' @examples
@@ -1275,11 +1278,11 @@ xgb.get.DMatrix.data <- function(dmat) {
 #'
 #' @rdname xgb.slice.DMatrix
 #' @export
-xgb.slice.DMatrix <- function(object, idxset) {
+xgb.slice.DMatrix <- function(object, idxset, allow_groups = FALSE) {
   if (!inherits(object, "xgb.DMatrix")) {
     stop("object must be xgb.DMatrix")
   }
-  ret <- .Call(XGDMatrixSliceDMatrix_R, object, idxset)
+  ret <- .Call(XGDMatrixSliceDMatrix_R, object, idxset, allow_groups)

   attr_list <- attributes(object)
   nr <- nrow(object)
@@ -1296,7 +1299,15 @@ xgb.slice.DMatrix <- function(object, idxset) {
       }
     }
   }
-  return(structure(ret, class = "xgb.DMatrix"))
+
+  out <- structure(ret, class = "xgb.DMatrix")
+  parent_fields <- as.list(attributes(object)$fields)
+  if (NROW(parent_fields)) {
+    child_fields <- parent_fields[!(names(parent_fields) %in% c("group", "qid"))]
+    child_fields <- as.environment(child_fields)
+    attributes(out)$fields <- child_fields
+  }
+  return(out)
 }

 #' @rdname xgb.slice.DMatrix
@@ -1340,11 +1351,11 @@ print.xgb.DMatrix <- function(x, verbose = FALSE, ...) {
   }

   cat(class_print, ' dim:', nrow(x), 'x', ncol(x), ' info: ')
-  infos <- character(0)
-  if (xgb.DMatrix.hasinfo(x, 'label')) infos <- 'label'
-  if (xgb.DMatrix.hasinfo(x, 'weight')) infos <- c(infos, 'weight')
-  if (xgb.DMatrix.hasinfo(x, 'base_margin')) infos <- c(infos, 'base_margin')
-  if (length(infos) == 0) infos <- 'NA'
+  infos <- names(attributes(x)$fields)
+  infos <- infos[infos != "feature_name"]
+  if (!NROW(infos)) infos <- "NA"
+  infos <- infos[order(infos)]
+  infos <- paste(infos, collapse = ", ")
   cat(infos)
   cnames <- colnames(x)
   cat(' colnames:')
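For context (not part of the diff), a minimal sketch of the new `allow_groups` behavior, mirroring the tests added further below: slicing a grouped DMatrix now succeeds, but the result loses its `group` field and it must be re-attached through `setinfo()`. The data here is illustrative:

    library(xgboost)
    data(iris)
    x <- as.matrix(iris[, -5])
    set.seed(123)
    y <- sample(3, size = nrow(x), replace = TRUE)
    dm <- xgb.DMatrix(x, label = y, group = c(50, 50, 50))
    # Take the first two query groups; the slice has no 'group' field anymore.
    dm_slice <- xgb.slice.DMatrix(dm, seq_len(100), allow_groups = TRUE)
    setinfo(dm_slice, "group", c(50, 50))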
R-package/R/xgb.cv.R

@@ -1,6 +1,6 @@
 #' Cross Validation
 #'
-#' The cross validation function of xgboost
+#' The cross validation function of xgboost.
 #'
 #' @param params the list of parameters. The complete list of parameters is
 #'   available in the \href{http://xgboost.readthedocs.io/en/latest/parameter.html}{online documentation}. Below
@@ -19,13 +19,17 @@
 #'
 #'   See \code{\link{xgb.train}} for further details.
 #'   See also demo/ for walkthrough example in R.
-#' @param data takes an \code{xgb.DMatrix}, \code{matrix}, or \code{dgCMatrix} as the input.
+#'
+#'   Note that, while `params` accepts a `seed` entry and will use such parameter for model training if
+#'   supplied, this seed is not used for creation of train-test splits, which instead rely on R's own RNG
+#'   system - thus, for reproducible results, one needs to call the `set.seed` function beforehand.
+#' @param data An `xgb.DMatrix` object, with corresponding fields like `label` or bounds as required
+#'   for model training by the objective.
+#'
+#'   Note that only the basic `xgb.DMatrix` class is supported - variants such as `xgb.QuantileDMatrix`
+#'   or `xgb.ExternalDMatrix` are not supported here.
 #' @param nrounds the max number of iterations
 #' @param nfold the original dataset is randomly partitioned into \code{nfold} equal size subsamples.
-#' @param label vector of response values. Should be provided only when data is an R-matrix.
-#' @param missing is only used when input is a dense matrix. By default is set to NA, which means
-#'   that NA values should be considered as 'missing' by the algorithm.
-#'   Sometimes, 0 or other extreme value might be used to represent missing values.
 #' @param prediction A logical value indicating whether to return the test fold predictions
 #'   from each CV model. This parameter engages the \code{\link{xgb.cb.cv.predict}} callback.
 #' @param showsd \code{boolean}, whether to show standard deviation of cross validation
@@ -47,13 +51,30 @@
 #' @param feval customized evaluation function. Returns
 #'   \code{list(metric='metric-name', value='metric-value')} with given
 #'   prediction and dtrain.
-#' @param stratified a \code{boolean} indicating whether sampling of folds should be stratified
-#'   by the values of outcome labels.
+#' @param stratified A \code{boolean} indicating whether sampling of folds should be stratified
+#'   by the values of outcome labels. For real-valued labels in regression objectives,
+#'   stratification will be done by discretizing the labels into up to 5 buckets beforehand.
+#'
+#'   If passing "auto", will be set to `TRUE` if the objective in `params` is a classification
+#'   objective (from XGBoost's built-in objectives, doesn't apply to custom ones), and to
+#'   `FALSE` otherwise.
+#'
+#'   This parameter is ignored when `data` has a `group` field - in such case, the splitting
+#'   will be based on whole groups (note that this might make the folds have different sizes).
+#'
+#'   Value `TRUE` here is \bold{not} supported for custom objectives.
 #' @param folds \code{list} provides a possibility to use a list of pre-defined CV folds
 #'   (each element must be a vector of test fold's indices). When folds are supplied,
 #'   the \code{nfold} and \code{stratified} parameters are ignored.
+#'
+#'   If `data` has a `group` field and the objective requires this field, each fold (list element)
+#'   must additionally have two attributes (retrievable through \link{attributes}) named `group_test`
+#'   and `group_train`, which should hold the `group` to assign through \link{setinfo.xgb.DMatrix} to
+#'   the resulting DMatrices.
 #' @param train_folds \code{list} list specifying which indicies to use for training. If \code{NULL}
 #'   (the default) all indices not specified in \code{folds} will be used for training.
+#'
+#'   This is not supported when `data` has `group` field.
 #' @param verbose \code{boolean}, print the statistics during the process
 #' @param print_every_n Print each n-th iteration evaluation messages when \code{verbose>0}.
 #'   Default is 1 which means all messages are printed. This parameter is passed to the
@@ -118,13 +139,14 @@
 #' print(cv, verbose=TRUE)
 #'
 #' @export
-xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing = NA,
+xgb.cv <- function(params = list(), data, nrounds, nfold,
                    prediction = FALSE, showsd = TRUE, metrics = list(),
-                   obj = NULL, feval = NULL, stratified = TRUE, folds = NULL, train_folds = NULL,
+                   obj = NULL, feval = NULL, stratified = "auto", folds = NULL, train_folds = NULL,
                    verbose = TRUE, print_every_n = 1L,
                    early_stopping_rounds = NULL, maximize = NULL, callbacks = list(), ...) {

   check.deprecation(...)
+  stopifnot(inherits(data, "xgb.DMatrix"))
   if (inherits(data, "xgb.DMatrix") && .Call(XGCheckNullPtr_R, data)) {
     stop("'data' is an invalid 'xgb.DMatrix' object. Must be constructed again.")
   }
@@ -137,16 +159,22 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
   check.custom.obj()
   check.custom.eval()

-  # Check the labels
-  if ((inherits(data, 'xgb.DMatrix') && !xgb.DMatrix.hasinfo(data, 'label')) ||
-      (!inherits(data, 'xgb.DMatrix') && is.null(label))) {
-    stop("Labels must be provided for CV either through xgb.DMatrix, or through 'label=' when 'data' is matrix")
-  } else if (inherits(data, 'xgb.DMatrix')) {
-    if (!is.null(label))
-      warning("xgb.cv: label will be ignored, since data is of type xgb.DMatrix")
-    cv_label <- getinfo(data, 'label')
+  if (stratified == "auto") {
+    if (is.character(params$objective)) {
+      stratified <- (
+        (params$objective %in% .CLASSIFICATION_OBJECTIVES())
+        && !(params$objective %in% .RANKING_OBJECTIVES())
+      )
     } else {
-    cv_label <- label
+      stratified <- FALSE
+    }
+  }
+
+  # Check the labels and groups
+  cv_label <- getinfo(data, "label")
+  cv_group <- getinfo(data, "group")
+  if (!is.null(train_folds) && NROW(cv_group)) {
+    stop("'train_folds' is not supported for DMatrix object with 'group' field.")
   }

   # CV folds
@@ -157,7 +185,7 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
   } else {
     if (nfold <= 1)
       stop("'nfold' must be > 1")
-    folds <- generate.cv.folds(nfold, nrow(data), stratified, cv_label, params)
+    folds <- generate.cv.folds(nfold, nrow(data), stratified, cv_label, cv_group, params)
   }

   # Callbacks
@@ -195,20 +223,18 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing

   # create the booster-folds
   # train_folds
-  dall <- xgb.get.DMatrix(
-    data = data,
-    label = label,
-    missing = missing,
-    weight = NULL,
-    nthread = params$nthread
-  )
+  dall <- data
   bst_folds <- lapply(seq_along(folds), function(k) {
-    dtest <- xgb.slice.DMatrix(dall, folds[[k]])
+    dtest <- xgb.slice.DMatrix(dall, folds[[k]], allow_groups = TRUE)
     # code originally contributed by @RolandASc on stackoverflow
     if (is.null(train_folds))
-      dtrain <- xgb.slice.DMatrix(dall, unlist(folds[-k]))
+      dtrain <- xgb.slice.DMatrix(dall, unlist(folds[-k]), allow_groups = TRUE)
     else
-      dtrain <- xgb.slice.DMatrix(dall, train_folds[[k]])
+      dtrain <- xgb.slice.DMatrix(dall, train_folds[[k]], allow_groups = TRUE)
+    if (!is.null(attributes(folds[[k]])$group_test)) {
+      setinfo(dtest, "group", attributes(folds[[k]])$group_test)
+      setinfo(dtrain, "group", attributes(folds[[k]])$group_train)
+    }
     bst <- xgb.Booster(
       params = params,
       cachelist = list(dtrain, dtest),
@@ -312,7 +338,7 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
 #' @examples
 #' data(agaricus.train, package='xgboost')
 #' train <- agaricus.train
-#' cv <- xgb.cv(data = train$data, label = train$label, nfold = 5, max_depth = 2,
+#' cv <- xgb.cv(data = xgb.DMatrix(train$data, label = train$label), nfold = 5, max_depth = 2,
 #'              eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic")
 #' print(cv)
 #' print(cv, verbose=TRUE)
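A short usage sketch of the updated interface (it mirrors the "xgb.cv works for ranking" test added in this commit; the data is illustrative): `xgb.cv()` now takes an `xgb.DMatrix` only, and when the DMatrix has a `group` field the folds are built from whole query groups:

    library(xgboost)
    data(iris)
    x <- as.matrix(iris[, -5])
    y <- as.integer(iris$Petal.Width)
    dm <- xgb.DMatrix(x, label = y, group = rep(50, 3))
    res <- xgb.cv(
      params = list(objective = "rank:pairwise", max_depth = 3),
      data = dm, nrounds = 3, nfold = 2, verbose = FALSE
    )
    length(res$folds)  # 2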
R-package/man/print.xgb.cv.Rd

@@ -23,7 +23,7 @@ including the best iteration (when available).
 \examples{
 data(agaricus.train, package='xgboost')
 train <- agaricus.train
-cv <- xgb.cv(data = train$data, label = train$label, nfold = 5, max_depth = 2,
+cv <- xgb.cv(data = xgb.DMatrix(train$data, label = train$label), nfold = 5, max_depth = 2,
              eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic")
 print(cv)
 print(cv, verbose=TRUE)
R-package/man/xgb.cv.Rd

@@ -9,14 +9,12 @@ xgb.cv(
   data,
   nrounds,
   nfold,
-  label = NULL,
-  missing = NA,
   prediction = FALSE,
   showsd = TRUE,
   metrics = list(),
   obj = NULL,
   feval = NULL,
-  stratified = TRUE,
+  stratified = "auto",
   folds = NULL,
   train_folds = NULL,
   verbose = TRUE,
@@ -44,20 +42,23 @@ is a shorter summary:
 }

 See \code{\link{xgb.train}} for further details.
-See also demo/ for walkthrough example in R.}
+See also demo/ for walkthrough example in R.
+
+Note that, while \code{params} accepts a \code{seed} entry and will use such parameter for model training if
+supplied, this seed is not used for creation of train-test splits, which instead rely on R's own RNG
+system - thus, for reproducible results, one needs to call the \code{set.seed} function beforehand.}

-\item{data}{takes an \code{xgb.DMatrix}, \code{matrix}, or \code{dgCMatrix} as the input.}
+\item{data}{An \code{xgb.DMatrix} object, with corresponding fields like \code{label} or bounds as required
+for model training by the objective.
+
+\if{html}{\out{<div class="sourceCode">}}\preformatted{ Note that only the basic `xgb.DMatrix` class is supported - variants such as `xgb.QuantileDMatrix`
+ or `xgb.ExternalDMatrix` are not supported here.
+}\if{html}{\out{</div>}}}

 \item{nrounds}{the max number of iterations}

 \item{nfold}{the original dataset is randomly partitioned into \code{nfold} equal size subsamples.}

-\item{label}{vector of response values. Should be provided only when data is an R-matrix.}
-
-\item{missing}{is only used when input is a dense matrix. By default is set to NA, which means
-that NA values should be considered as 'missing' by the algorithm.
-Sometimes, 0 or other extreme value might be used to represent missing values.}
-
 \item{prediction}{A logical value indicating whether to return the test fold predictions
 from each CV model. This parameter engages the \code{\link{xgb.cb.cv.predict}} callback.}
@@ -84,15 +85,35 @@ gradient with given prediction and dtrain.}
 \code{list(metric='metric-name', value='metric-value')} with given
 prediction and dtrain.}

-\item{stratified}{a \code{boolean} indicating whether sampling of folds should be stratified
-by the values of outcome labels.}
+\item{stratified}{A \code{boolean} indicating whether sampling of folds should be stratified
+by the values of outcome labels. For real-valued labels in regression objectives,
+stratification will be done by discretizing the labels into up to 5 buckets beforehand.
+
+\if{html}{\out{<div class="sourceCode">}}\preformatted{ If passing "auto", will be set to `TRUE` if the objective in `params` is a classification
+objective (from XGBoost's built-in objectives, doesn't apply to custom ones), and to
+`FALSE` otherwise.
+
+This parameter is ignored when `data` has a `group` field - in such case, the splitting
+will be based on whole groups (note that this might make the folds have different sizes).
+
+Value `TRUE` here is \\bold\{not\} supported for custom objectives.
+}\if{html}{\out{</div>}}}

 \item{folds}{\code{list} provides a possibility to use a list of pre-defined CV folds
 (each element must be a vector of test fold's indices). When folds are supplied,
-the \code{nfold} and \code{stratified} parameters are ignored.}
+the \code{nfold} and \code{stratified} parameters are ignored.
+
+\if{html}{\out{<div class="sourceCode">}}\preformatted{ If `data` has a `group` field and the objective requires this field, each fold (list element)
+must additionally have two attributes (retrievable through \link{attributes}) named `group_test`
+and `group_train`, which should hold the `group` to assign through \link{setinfo.xgb.DMatrix} to
+the resulting DMatrices.
+}\if{html}{\out{</div>}}}

-\item{train_folds}{\code{list} list specifying which indicies to use for training. If \code{NULL}
-(the default) all indices not specified in \code{folds} will be used for training.}
+\item{train_folds}{\code{list} list specifying which indicies to use for training. If \code{NULL}
+(the default) all indices not specified in \code{folds} will be used for training.
+
+\if{html}{\out{<div class="sourceCode">}}\preformatted{ This is not supported when `data` has `group` field.
+}\if{html}{\out{</div>}}}

 \item{verbose}{\code{boolean}, print the statistics during the process}

@@ -142,7 +163,7 @@ such as saving also the models created during cross validation); or a list \code
 will contain elements such as \code{best_iteration} when using the early stopping callback (\link{xgb.cb.early.stop}).
 }
 \description{
-The cross validation function of xgboost
+The cross validation function of xgboost.
 }
 \details{
 The original sample is randomly partitioned into \code{nfold} equal size subsamples.
R-package/man/xgb.slice.DMatrix.Rd

@@ -6,14 +6,18 @@
 \title{Get a new DMatrix containing the specified rows of
 original xgb.DMatrix object}
 \usage{
-xgb.slice.DMatrix(object, idxset)
+xgb.slice.DMatrix(object, idxset, allow_groups = FALSE)

 \method{[}{xgb.DMatrix}(object, idxset, colset = NULL)
 }
 \arguments{
-\item{object}{Object of class "xgb.DMatrix"}
+\item{object}{Object of class "xgb.DMatrix".}

-\item{idxset}{a integer vector of indices of rows needed}
+\item{idxset}{An integer vector of indices of rows needed (base-1 indexing).}

+\item{allow_groups}{Whether to allow slicing an \code{xgb.DMatrix} with \code{group} (or
+equivalently \code{qid}) field. Note that in such case, the result will not have
+the groups anymore - they need to be set manually through \code{setinfo}.}
+
 \item{colset}{currently not used (columns subsetting is not available)}
 }
R-package/src/Makevars.in

@@ -99,10 +99,12 @@ OBJECTS= \
 	$(PKGROOT)/src/context.o \
 	$(PKGROOT)/src/logging.o \
 	$(PKGROOT)/src/global_config.o \
+	$(PKGROOT)/src/collective/result.o \
 	$(PKGROOT)/src/collective/allgather.o \
 	$(PKGROOT)/src/collective/allreduce.o \
 	$(PKGROOT)/src/collective/broadcast.o \
 	$(PKGROOT)/src/collective/comm.o \
+	$(PKGROOT)/src/collective/comm_group.o \
 	$(PKGROOT)/src/collective/coll.o \
 	$(PKGROOT)/src/collective/communicator-inl.o \
 	$(PKGROOT)/src/collective/tracker.o \

R-package/src/Makevars.win

@@ -99,10 +99,12 @@ OBJECTS= \
 	$(PKGROOT)/src/context.o \
 	$(PKGROOT)/src/logging.o \
 	$(PKGROOT)/src/global_config.o \
+	$(PKGROOT)/src/collective/result.o \
 	$(PKGROOT)/src/collective/allgather.o \
 	$(PKGROOT)/src/collective/allreduce.o \
 	$(PKGROOT)/src/collective/broadcast.o \
 	$(PKGROOT)/src/collective/comm.o \
+	$(PKGROOT)/src/collective/comm_group.o \
 	$(PKGROOT)/src/collective/coll.o \
 	$(PKGROOT)/src/collective/communicator-inl.o \
 	$(PKGROOT)/src/collective/tracker.o \
R-package/src/init.c

@@ -71,7 +71,7 @@ extern SEXP XGDMatrixGetDataAsCSR_R(SEXP);
 extern SEXP XGDMatrixSaveBinary_R(SEXP, SEXP, SEXP);
 extern SEXP XGDMatrixSetInfo_R(SEXP, SEXP, SEXP);
 extern SEXP XGDMatrixSetStrFeatureInfo_R(SEXP, SEXP, SEXP);
-extern SEXP XGDMatrixSliceDMatrix_R(SEXP, SEXP);
+extern SEXP XGDMatrixSliceDMatrix_R(SEXP, SEXP, SEXP);
 extern SEXP XGBSetGlobalConfig_R(SEXP);
 extern SEXP XGBGetGlobalConfig_R(void);
 extern SEXP XGBoosterFeatureScore_R(SEXP, SEXP);
@@ -134,7 +134,7 @@ static const R_CallMethodDef CallEntries[] = {
     {"XGDMatrixSaveBinary_R", (DL_FUNC) &XGDMatrixSaveBinary_R, 3},
     {"XGDMatrixSetInfo_R", (DL_FUNC) &XGDMatrixSetInfo_R, 3},
     {"XGDMatrixSetStrFeatureInfo_R", (DL_FUNC) &XGDMatrixSetStrFeatureInfo_R, 3},
-    {"XGDMatrixSliceDMatrix_R", (DL_FUNC) &XGDMatrixSliceDMatrix_R, 2},
+    {"XGDMatrixSliceDMatrix_R", (DL_FUNC) &XGDMatrixSliceDMatrix_R, 3},
     {"XGBSetGlobalConfig_R", (DL_FUNC) &XGBSetGlobalConfig_R, 1},
     {"XGBGetGlobalConfig_R", (DL_FUNC) &XGBGetGlobalConfig_R, 0},
     {"XGBoosterFeatureScore_R", (DL_FUNC) &XGBoosterFeatureScore_R, 2},
R-package/src/xgboost_R.cc

@@ -512,7 +512,7 @@ XGB_DLL SEXP XGDMatrixCreateFromCSR_R(SEXP indptr, SEXP indices, SEXP data, SEXP
   return ret;
 }

-XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset) {
+XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset, SEXP allow_groups) {
   SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
   R_API_BEGIN();
   R_xlen_t len = Rf_xlength(idxset);
@@ -531,7 +531,7 @@ XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset) {
     res_code = XGDMatrixSliceDMatrixEx(R_ExternalPtrAddr(handle),
                                        BeginPtr(idxvec), len,
                                        &res,
-                                       0);
+                                       Rf_asLogical(allow_groups));
   }
   CHECK_CALL(res_code);
   R_SetExternalPtrAddr(ret, res);
R-package/src/xgboost_R.h

@@ -112,9 +112,10 @@ XGB_DLL SEXP XGDMatrixCreateFromCSR_R(SEXP indptr, SEXP indices, SEXP data, SEXP
  * \brief create a new dmatrix from sliced content of existing matrix
  * \param handle instance of data matrix to be sliced
  * \param idxset index set
+ * \param allow_groups Whether to allow slicing the DMatrix if it has a 'group' field
  * \return a sliced new matrix
  */
-XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset);
+XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset, SEXP allow_groups);

 /*!
  * \brief load a data matrix into binary file
R-package/tests/testthat/test_basic.R

@@ -334,7 +334,7 @@ test_that("xgb.cv works", {
   set.seed(11)
   expect_output(
     cv <- xgb.cv(
-      data = train$data, label = train$label, max_depth = 2, nfold = 5,
+      data = xgb.DMatrix(train$data, label = train$label), max_depth = 2, nfold = 5,
       eta = 1., nthread = n_threads, nrounds = 2, objective = "binary:logistic",
       eval_metric = "error", verbose = TRUE
     ),
@@ -357,13 +357,13 @@ test_that("xgb.cv works with stratified folds", {
   cv <- xgb.cv(
     data = dtrain, max_depth = 2, nfold = 5,
     eta = 1., nthread = n_threads, nrounds = 2, objective = "binary:logistic",
-    verbose = TRUE, stratified = FALSE
+    verbose = FALSE, stratified = FALSE
   )
   set.seed(314159)
   cv2 <- xgb.cv(
     data = dtrain, max_depth = 2, nfold = 5,
     eta = 1., nthread = n_threads, nrounds = 2, objective = "binary:logistic",
-    verbose = TRUE, stratified = TRUE
+    verbose = FALSE, stratified = TRUE
   )
   # Stratified folds should result in a different evaluation logs
   expect_true(all(cv$evaluation_log[, test_logloss_mean] != cv2$evaluation_log[, test_logloss_mean]))
@@ -885,3 +885,57 @@ test_that("Seed in params override PRNG from R", {
     )
   )
 })
+
+test_that("xgb.cv works for AFT", {
+  X <- matrix(c(1, -1, -1, 1, 0, 1, 1, 0), nrow = 4, byrow = TRUE)  # 4x2 matrix
+  dtrain <- xgb.DMatrix(X, nthread = n_threads)
+
+  params <- list(objective = 'survival:aft', learning_rate = 0.2, max_depth = 2L)
+
+  # data must have bounds
+  expect_error(
+    xgb.cv(
+      params = params,
+      data = dtrain,
+      nround = 5L,
+      nfold = 4L,
+      nthread = n_threads
+    )
+  )
+
+  setinfo(dtrain, 'label_lower_bound', c(2, 3, 0, 4))
+  setinfo(dtrain, 'label_upper_bound', c(2, Inf, 4, 5))
+
+  # automatic stratified splitting is turned off
+  expect_warning(
+    xgb.cv(
+      params = params, data = dtrain, nround = 5L, nfold = 4L,
+      nthread = n_threads, stratified = TRUE, verbose = FALSE
+    )
+  )
+
+  # this works without any issue
+  expect_no_warning(
+    xgb.cv(params = params, data = dtrain, nround = 5L, nfold = 4L, verbose = FALSE)
+  )
+})
+
+test_that("xgb.cv works for ranking", {
+  data(iris)
+  x <- iris[, -(4:5)]
+  y <- as.integer(iris$Petal.Width)
+  group <- rep(50, 3)
+  dm <- xgb.DMatrix(x, label = y, group = group)
+  res <- xgb.cv(
+    data = dm,
+    params = list(
+      objective = "rank:pairwise",
+      max_depth = 3
+    ),
+    nrounds = 3,
+    nfold = 2,
+    verbose = FALSE,
+    stratified = FALSE
+  )
+  expect_equal(length(res$folds), 2L)
+})
R-package/tests/testthat/test_callbacks.R

@@ -367,7 +367,7 @@ test_that("prediction in early-stopping xgb.cv works", {
   expect_output(
     cv <- xgb.cv(param, dtrain, nfold = 5, eta = 0.1, nrounds = 20,
                  early_stopping_rounds = 5, maximize = FALSE, stratified = FALSE,
-                 prediction = TRUE, base_score = 0.5)
+                 prediction = TRUE, base_score = 0.5, verbose = TRUE)
   , "Stopping. Best iteration")

   expect_false(is.null(cv$early_stop$best_iteration))
@@ -387,7 +387,7 @@ test_that("prediction in xgb.cv for softprob works", {
   lb <- as.numeric(iris$Species) - 1
   set.seed(11)
   expect_warning(
-    cv <- xgb.cv(data = as.matrix(iris[, -5]), label = lb, nfold = 4,
+    cv <- xgb.cv(data = xgb.DMatrix(as.matrix(iris[, -5]), label = lb), nfold = 4,
                  eta = 0.5, nrounds = 5, max_depth = 3, nthread = n_threads,
                  subsample = 0.8, gamma = 2, verbose = 0,
                  prediction = TRUE, objective = "multi:softprob", num_class = 3)
R-package/tests/testthat/test_dmatrix.R

@@ -243,7 +243,7 @@ test_that("xgb.DMatrix: print", {
   txt <- capture.output({
     print(dtrain)
   })
-  expect_equal(txt, "xgb.DMatrix dim: 6513 x 126 info: label weight base_margin colnames: yes")
+  expect_equal(txt, "xgb.DMatrix dim: 6513 x 126 info: base_margin, label, weight colnames: yes")

   # DMatrix with just features
   dtrain <- xgb.DMatrix(
@@ -724,6 +724,44 @@ test_that("xgb.DMatrix: quantile cuts look correct", {
   )
 })
+
+test_that("xgb.DMatrix: slicing keeps field indicators", {
+  data(mtcars)
+  x <- as.matrix(mtcars[, -1])
+  y <- mtcars[, 1]
+  dm <- xgb.DMatrix(
+    data = x,
+    label_lower_bound = -y,
+    label_upper_bound = y,
+    nthread = 1
+  )
+  idx_take <- seq(1, 5)
+  dm_slice <- xgb.slice.DMatrix(dm, idx_take)
+
+  expect_true(xgb.DMatrix.hasinfo(dm_slice, "label_lower_bound"))
+  expect_true(xgb.DMatrix.hasinfo(dm_slice, "label_upper_bound"))
+  expect_false(xgb.DMatrix.hasinfo(dm_slice, "label"))
+
+  expect_equal(getinfo(dm_slice, "label_lower_bound"), -y[idx_take], tolerance = 1e-6)
+  expect_equal(getinfo(dm_slice, "label_upper_bound"), y[idx_take], tolerance = 1e-6)
+})
+
+test_that("xgb.DMatrix: can slice with groups", {
+  data(iris)
+  x <- as.matrix(iris[, -5])
+  set.seed(123)
+  y <- sample(3, size = nrow(x), replace = TRUE)
+  group <- c(50, 50, 50)
+  dm <- xgb.DMatrix(x, label = y, group = group, nthread = 1)
+  idx_take <- seq(1, 50)
+  dm_slice <- xgb.slice.DMatrix(dm, idx_take, allow_groups = TRUE)
+
+  expect_true(xgb.DMatrix.hasinfo(dm_slice, "label"))
+  expect_false(xgb.DMatrix.hasinfo(dm_slice, "group"))
+  expect_false(xgb.DMatrix.hasinfo(dm_slice, "qid"))
+  expect_null(getinfo(dm_slice, "group"))
+  expect_equal(getinfo(dm_slice, "label"), y[idx_take], tolerance = 1e-6)
+})
+
 test_that("xgb.DMatrix: can read CSV", {
   txt <- paste(
     "1,2,3",
@@ -40,7 +40,7 @@ def main(client):
     # you can pass output directly into `predict` too.
     prediction = dxgb.predict(client, bst, dtrain)
     print("Evaluation history:", history)
-    return prediction
+    print("Error:", da.sqrt((prediction - y) ** 2).mean().compute())


 if __name__ == "__main__":
doc/contrib/unit_tests.rst

@@ -144,6 +144,14 @@ which provides higher flexibility. For example:

   ctest --verbose

+If you need to debug errors on Windows using the debugger from VS, you can append the gtest flags in `test_main.cc`:
+
+.. code-block::
+
+  ::testing::GTEST_FLAG(filter) = "Suite.Test";
+  ::testing::GTEST_FLAG(repeat) = 10;
+
+
 ***********************************************
 Sanitizers: Detect memory errors and data races
 ***********************************************
doc/index.rst

@@ -28,7 +28,7 @@ Contents
   Python Package <python/index>
   R Package <R-package/index>
   JVM Package <jvm/index>
-  Ruby Package <https://github.com/ankane/xgb>
+  Ruby Package <https://github.com/ankane/xgboost-ruby>
   Swift Package <https://github.com/kongzii/SwiftXGBoost>
   Julia Package <julia>
   C Package <c>
doc/tutorials/learning_to_rank.rst

@@ -52,7 +52,7 @@ Notice that the samples are sorted based on their query index in a non-decreasing order.
   X, y = make_classification(random_state=seed)
   rng = np.random.default_rng(seed)
   n_query_groups = 3
-  qid = rng.integers(0, 3, size=X.shape[0])
+  qid = rng.integers(0, n_query_groups, size=X.shape[0])

   # Sort the inputs based on query index
   sorted_idx = np.argsort(qid)
@@ -65,14 +65,14 @@ The simplest way to train a ranking model is by using the scikit-learn estimator
 .. code-block:: python

   ranker = xgb.XGBRanker(tree_method="hist", lambdarank_num_pair_per_sample=8, objective="rank:ndcg", lambdarank_pair_method="topk")
-  ranker.fit(X, y, qid=qid)
+  ranker.fit(X, y, qid=qid[sorted_idx])

 Please note that, as of writing, there's no learning-to-rank interface in scikit-learn. As a result, the :py:class:`xgboost.XGBRanker` class does not fully conform the scikit-learn estimator guideline and can not be directly used with some of its utility functions. For instances, the ``auc_score`` and ``ndcg_score`` in scikit-learn don't consider query group information nor the pairwise loss. Most of the metrics are implemented as part of XGBoost, but to use scikit-learn utilities like :py:func:`sklearn.model_selection.cross_validation`, we need to make some adjustments in order to pass the ``qid`` as an additional parameter for :py:meth:`xgboost.XGBRanker.score`. Given a data frame ``X`` (either pandas or cuDF), add the column ``qid`` as follows:

 .. code-block:: python

   df = pd.DataFrame(X, columns=[str(i) for i in range(X.shape[1])])
-  df["qid"] = qid
+  df["qid"] = qid[sorted_idx]
   ranker.fit(df, y)  # No need to pass qid as a separate argument

   from sklearn.model_selection import StratifiedGroupKFold, cross_val_score
@ -1,5 +1,5 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2015~2023 by XGBoost Contributors
|
* Copyright 2015-2024, XGBoost Contributors
|
||||||
* \file c_api.h
|
* \file c_api.h
|
||||||
* \author Tianqi Chen
|
* \author Tianqi Chen
|
||||||
* \brief C API of XGBoost, used for interfacing to other languages.
|
* \brief C API of XGBoost, used for interfacing to other languages.
|
||||||
@ -639,21 +639,14 @@ XGB_DLL int XGDMatrixSetInfoFromInterface(DMatrixHandle handle,
|
|||||||
* \param len length of array
|
* \param len length of array
|
||||||
* \return 0 when success, -1 when failure happens
|
* \return 0 when success, -1 when failure happens
|
||||||
*/
|
*/
|
||||||
XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle,
|
XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle, const char *field, const float *array,
|
||||||
const char *field,
|
|
||||||
const float *array,
|
|
||||||
bst_ulong len);
|
bst_ulong len);
|
||||||
/*!
|
/**
|
||||||
* \brief set uint32 vector to a content in info
|
* @deprecated since 2.1.0
|
||||||
* \param handle a instance of data matrix
|
*
|
||||||
* \param field field name
|
* Use @ref XGDMatrixSetInfoFromInterface instead.
|
||||||
* \param array pointer to unsigned int vector
|
|
||||||
* \param len length of array
|
|
||||||
* \return 0 when success, -1 when failure happens
|
|
||||||
*/
|
*/
|
||||||
XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle,
|
XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle, const char *field, const unsigned *array,
|
||||||
const char *field,
|
|
||||||
const unsigned *array,
|
|
||||||
bst_ulong len);
|
bst_ulong len);
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
@ -725,42 +718,13 @@ XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field,
|
|||||||
bst_ulong *size,
|
bst_ulong *size,
|
||||||
const char ***out_features);
|
const char ***out_features);
|
||||||
|
|
||||||
/*!
|
/**
|
||||||
* \brief Set meta info from dense matrix. Valid field names are:
|
* @deprecated since 2.1.0
|
||||||
*
|
*
|
||||||
* - label
|
* Use @ref XGDMatrixSetInfoFromInterface instead.
|
||||||
* - weight
|
|
||||||
* - base_margin
|
|
||||||
* - group
|
|
||||||
* - label_lower_bound
|
|
||||||
* - label_upper_bound
|
|
||||||
* - feature_weights
|
|
||||||
*
|
|
||||||
* \param handle An instance of data matrix
|
|
||||||
* \param field Field name
|
|
||||||
* \param data Pointer to consecutive memory storing data.
|
|
||||||
* \param size Size of the data, this is relative to size of type. (Meaning NOT number
|
|
||||||
* of bytes.)
|
|
||||||
* \param type Indicator of data type. This is defined in xgboost::DataType enum class.
|
|
||||||
* - float = 1
|
|
||||||
* - double = 2
|
|
||||||
* - uint32_t = 3
|
|
||||||
* - uint64_t = 4
|
|
||||||
* \return 0 when success, -1 when failure happens
|
|
||||||
*/
|
*/
|
||||||
XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field,
|
XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field, void const *data,
|
||||||
void const *data, bst_ulong size, int type);
|
bst_ulong size, int type);
|
||||||
|
|
||||||
/*!
|
|
||||||
* \brief (deprecated) Use XGDMatrixSetUIntInfo instead. Set group of the training matrix
|
|
||||||
* \param handle a instance of data matrix
|
|
||||||
* \param group pointer to group size
|
|
||||||
* \param len length of array
|
|
||||||
* \return 0 when success, -1 when failure happens
|
|
||||||
*/
|
|
||||||
XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle,
|
|
||||||
const unsigned *group,
|
|
||||||
bst_ulong len);
|
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief get float info vector from matrix.
|
* \brief get float info vector from matrix.
|
||||||
@@ -1591,7 +1555,7 @@ XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle);

 /**
  * @brief Get the arguments needed for running workers. This should be called after
- *        XGTrackerRun() and XGTrackerWait()
+ *        XGTrackerRun().
  *
  * @param handle The handle to the tracker.
  * @param args   The arguments returned as a JSON document.
@@ -1601,16 +1565,19 @@ XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle);
 XGB_DLL int XGTrackerWorkerArgs(TrackerHandle handle, char const **args);

 /**
- * @brief Run the tracker.
+ * @brief Start the tracker. The tracker runs in the background and this function returns
+ *        once the tracker is started.
  *
  * @param handle The handle to the tracker.
+ * @param config Unused at the moment, preserved for the future.
  *
  * @return 0 for success, -1 for failure.
  */
-XGB_DLL int XGTrackerRun(TrackerHandle handle);
+XGB_DLL int XGTrackerRun(TrackerHandle handle, char const *config);

 /**
- * @brief Wait for the tracker to finish, should be called after XGTrackerRun().
+ * @brief Wait for the tracker to finish, should be called after XGTrackerRun(). This
+ *        function will block until the tracker task is finished or timeout is reached.
  *
  * @param handle The handle to the tracker.
  * @param config JSON encoded configuration. No argument is required yet, preserved for
@@ -1618,11 +1585,12 @@ XGB_DLL int XGTrackerRun(TrackerHandle handle);
  *
  * @return 0 for success, -1 for failure.
  */
-XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config);
+XGB_DLL int XGTrackerWaitFor(TrackerHandle handle, char const *config);

 /**
- * @brief Free a tracker instance. XGTrackerWait() is called internally. If the tracker
- *        cannot close properly, manual interruption is required.
+ * @brief Free a tracker instance. This should be called after XGTrackerWaitFor(). If the
+ *        tracker is not properly waited, this function will shutdown all connections with
+ *        the tracker, potentially leading to undefined behavior.
  *
  * @param handle The handle to the tracker.
  *
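Taken together, the tracker hunks split the old blocking XGTrackerRun() into a non-blocking start plus an explicit XGTrackerWaitFor(). A sketch of the resulting lifecycle; the free-routine name XGTrackerFree and the configuration keys are assumptions inferred from the surrounding comments, not shown in this diff:

// Sketch: start the tracker in the background, hand its args to workers,
// then block until it finishes before freeing it.
#include <iostream>
#include <xgboost/c_api.h>

void RunTracker() {
  TrackerHandle tracker{nullptr};
  if (XGTrackerCreate(R"({"n_workers": 2})", &tracker) != 0) { return; }
  if (XGTrackerRun(tracker, "{}") != 0) { return; }  // returns once started
  char const *args{nullptr};
  XGTrackerWorkerArgs(tracker, &args);  // JSON document for launching workers
  std::cout << args << std::endl;
  XGTrackerWaitFor(tracker, "{}");      // block until the tracker task ends
  XGTrackerFree(tracker);               // wait first; freeing an un-waited
                                        // tracker force-closes connections
}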
@@ -3,12 +3,10 @@
  */
 #pragma once

-#include <xgboost/logging.h>
+#include <cstdint>       // for int32_t

 #include <memory>        // for unique_ptr
-#include <sstream>       // for stringstream
-#include <stack>         // for stack
 #include <string>        // for string
+#include <system_error>  // for error_code
 #include <utility>       // for move

 namespace xgboost::collective {
@@ -48,48 +46,19 @@ struct ResultImpl {
     return cur_eq;
   }

-  [[nodiscard]] std::string Report() {
-    std::stringstream ss;
-    ss << "\n- " << this->message;
-    if (this->errc != std::error_code{}) {
-      ss << " system error:" << this->errc.message();
-    }
-
-    auto ptr = prev.get();
-    while (ptr) {
-      ss << "\n- ";
-      ss << ptr->message;
-
-      if (ptr->errc != std::error_code{}) {
-        ss << " " << ptr->errc.message();
-      }
-      ptr = ptr->prev.get();
-    }
-
-    return ss.str();
-  }
-  [[nodiscard]] auto Code() const {
-    // Find the root error.
-    std::stack<ResultImpl const*> stack;
-    auto ptr = this;
-    while (ptr) {
-      stack.push(ptr);
-      if (ptr->prev) {
-        ptr = ptr->prev.get();
-      } else {
-        break;
-      }
-    }
-    while (!stack.empty()) {
-      auto frame = stack.top();
-      stack.pop();
-      if (frame->errc != std::error_code{}) {
-        return frame->errc;
-      }
-    }
-    return std::error_code{};
-  }
+  [[nodiscard]] std::string Report() const;
+  [[nodiscard]] std::error_code Code() const;
+
+  void Concat(std::unique_ptr<ResultImpl> rhs);
 };

+#if (!defined(__GNUC__) && !defined(__clang__)) || defined(__MINGW32__)
+#define __builtin_FILE() nullptr
+#define __builtin_LINE() (-1)
+std::string MakeMsg(std::string&& msg, char const*, std::int32_t);
+#else
+std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line);
+#endif
 }  // namespace detail

 /**
@@ -131,8 +100,21 @@ struct Result {
     }
     return *impl_ == *that.impl_;
   }

+  friend Result operator+(Result&& lhs, Result&& rhs);
 };

+[[nodiscard]] inline Result operator+(Result&& lhs, Result&& rhs) {
+  if (lhs.OK()) {
+    return std::forward<Result>(rhs);
+  }
+  if (rhs.OK()) {
+    return std::forward<Result>(lhs);
+  }
+  lhs.impl_->Concat(std::move(rhs.impl_));
+  return std::forward<Result>(lhs);
+}
+
 /**
  * @brief Return success.
  */
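The new operator+ keeps whichever side failed and, when both failed, chains the second error behind the first via ResultImpl::Concat. A small sketch of folding several results into one combined report, assuming Result is move-assignable as its unique_ptr member implies:

// Sketch: reduce per-step results into a single error chain.
#include <utility>
#include <vector>
#include "xgboost/collective/result.h"

xgboost::collective::Result CombineAll(std::vector<xgboost::collective::Result> *results) {
  auto combined = xgboost::collective::Success();
  for (auto &r : *results) {
    combined = std::move(combined) + std::move(r);  // first failure becomes the root
  }
  return combined;  // Report() walks the whole chain
}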
@@ -140,38 +122,43 @@ struct Result {
 /**
  * @brief Return failure.
  */
-[[nodiscard]] inline auto Fail(std::string msg) { return Result{std::move(msg)}; }
+[[nodiscard]] inline auto Fail(std::string msg, char const* file = __builtin_FILE(),
+                               std::int32_t line = __builtin_LINE()) {
+  return Result{detail::MakeMsg(std::move(msg), file, line)};
+}
 /**
  * @brief Return failure with `errno`.
  */
-[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc) {
-  return Result{std::move(msg), std::move(errc)};
+[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc,
+                               char const* file = __builtin_FILE(),
+                               std::int32_t line = __builtin_LINE()) {
+  return Result{detail::MakeMsg(std::move(msg), file, line), std::move(errc)};
 }
 /**
  * @brief Return failure with a previous error.
  */
-[[nodiscard]] inline auto Fail(std::string msg, Result&& prev) {
-  return Result{std::move(msg), std::forward<Result>(prev)};
+[[nodiscard]] inline auto Fail(std::string msg, Result&& prev, char const* file = __builtin_FILE(),
+                               std::int32_t line = __builtin_LINE()) {
+  return Result{detail::MakeMsg(std::move(msg), file, line), std::forward<Result>(prev)};
 }
 /**
  * @brief Return failure with a previous error and a new `errno`.
  */
-[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc, Result&& prev) {
-  return Result{std::move(msg), std::move(errc), std::forward<Result>(prev)};
+[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc, Result&& prev,
+                               char const* file = __builtin_FILE(),
+                               std::int32_t line = __builtin_LINE()) {
+  return Result{detail::MakeMsg(std::move(msg), file, line), std::move(errc),
+                std::forward<Result>(prev)};
 }

 // We don't have monad, a simple helper would do.
 template <typename Fn>
-[[nodiscard]] Result operator<<(Result&& r, Fn&& fn) {
+[[nodiscard]] std::enable_if_t<std::is_invocable_v<Fn>, Result> operator<<(Result&& r, Fn&& fn) {
   if (!r.OK()) {
     return std::forward<Result>(r);
   }
   return fn();
 }

-inline void SafeColl(Result const& rc) {
-  if (!rc.OK()) {
-    LOG(FATAL) << rc.Report();
-  }
-}
+void SafeColl(Result const& rc);
 }  // namespace xgboost::collective
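With the defaulted __builtin_FILE()/__builtin_LINE() arguments, every Fail() call now records its capture site without the caller spelling it out, and the constrained operator<< sequences fallible steps. A sketch in the style this header enables; the predicate names are placeholders:

// Sketch: each lambda runs only if the previous step succeeded; a failure
// short-circuits the chain, and Report() includes the file/line of Fail().
#include "xgboost/collective/result.h"

namespace coll = xgboost::collective;

coll::Result Connect(bool port_ok, bool handshake_ok) {
  return coll::Success() << [&] {
    return port_ok ? coll::Success() : coll::Fail("Invalid port.");
  } << [&] {
    return handshake_ok ? coll::Success() : coll::Fail("Handshake failed.");
  };
}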
@@ -1,5 +1,5 @@
 /**
- * Copyright (c) 2022-2023, XGBoost Contributors
+ * Copyright (c) 2022-2024, XGBoost Contributors
  */
 #pragma once

@@ -12,7 +12,6 @@
 #include <cstddef>       // std::size_t
 #include <cstdint>       // std::int32_t, std::uint16_t
 #include <cstring>       // memset
-#include <limits>        // std::numeric_limits
 #include <string>        // std::string
 #include <system_error>  // std::error_code, std::system_category
 #include <utility>       // std::swap
@@ -125,6 +124,21 @@ inline std::int32_t CloseSocket(SocketT fd) {
 #endif
 }

+inline std::int32_t ShutdownSocket(SocketT fd) {
+#if defined(_WIN32)
+  auto rc = shutdown(fd, SD_BOTH);
+  if (rc != 0 && LastError() == WSANOTINITIALISED) {
+    return 0;
+  }
+#else
+  auto rc = shutdown(fd, SHUT_RDWR);
+  if (rc != 0 && LastError() == ENOTCONN) {
+    return 0;
+  }
+#endif
+  return rc;
+}
+
 inline bool ErrorWouldBlock(std::int32_t errsv) noexcept(true) {
 #ifdef _WIN32
   return errsv == WSAEWOULDBLOCK;
@@ -468,19 +482,30 @@ class TCPSocket {
       *addr = SockAddress{SockAddrV6{caddr}};
       *out = TCPSocket{newfd};
     }
+    // On MacOS, this is automatically set to async socket if the parent socket is async
+    // We make sure all socket are blocking by default.
+    //
+    // On Windows, a closed socket is returned during shutdown. We guard against it when
+    // setting non-blocking.
+    if (!out->IsClosed()) {
+      return out->NonBlocking(false);
+    }
     return Success();
   }

   ~TCPSocket() {
     if (!IsClosed()) {
-      Close();
+      auto rc = this->Close();
+      if (!rc.OK()) {
+        LOG(WARNING) << rc.Report();
+      }
     }
   }

   TCPSocket(TCPSocket const &that) = delete;
   TCPSocket(TCPSocket &&that) noexcept(true) { std::swap(this->handle_, that.handle_); }
   TCPSocket &operator=(TCPSocket const &that) = delete;
-  TCPSocket &operator=(TCPSocket &&that) {
+  TCPSocket &operator=(TCPSocket &&that) noexcept(true) {
     std::swap(this->handle_, that.handle_);
     return *this;
   }
@@ -489,36 +514,49 @@ class TCPSocket {
    */
   [[nodiscard]] HandleT const &Handle() const { return handle_; }
   /**
-   * \brief Listen to incoming requests. Should be called after bind.
+   * @brief Listen to incoming requests. Should be called after bind.
    */
-  void Listen(std::int32_t backlog = 16) { xgboost_CHECK_SYS_CALL(listen(handle_, backlog), 0); }
+  [[nodiscard]] Result Listen(std::int32_t backlog = 16) {
+    if (listen(handle_, backlog) != 0) {
+      return system::FailWithCode("Failed to listen.");
+    }
+    return Success();
+  }
   /**
-   * \brief Bind socket to INADDR_ANY, return the port selected by the OS.
+   * @brief Bind socket to INADDR_ANY, return the port selected by the OS.
    */
-  [[nodiscard]] in_port_t BindHost() {
+  [[nodiscard]] Result BindHost(std::int32_t* p_out) {
+    // Use int32 instead of in_port_t for consistency. We take port as parameter from
+    // users using other languages, the port is usually stored and passed around as int.
     if (Domain() == SockDomain::kV6) {
       auto addr = SockAddrV6::InaddrAny();
       auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
-      xgboost_CHECK_SYS_CALL(
-          bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)), 0);
+      if (bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)) != 0) {
+        return system::FailWithCode("bind failed.");
+      }

       sockaddr_in6 res_addr;
       socklen_t addrlen = sizeof(res_addr);
-      xgboost_CHECK_SYS_CALL(
-          getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen), 0);
-      return ntohs(res_addr.sin6_port);
+      if (getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen) != 0) {
+        return system::FailWithCode("getsockname failed.");
+      }
+      *p_out = ntohs(res_addr.sin6_port);
     } else {
       auto addr = SockAddrV4::InaddrAny();
       auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
-      xgboost_CHECK_SYS_CALL(
-          bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)), 0);
+      if (bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)) != 0) {
+        return system::FailWithCode("bind failed.");
+      }

       sockaddr_in res_addr;
       socklen_t addrlen = sizeof(res_addr);
-      xgboost_CHECK_SYS_CALL(
-          getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen), 0);
-      return ntohs(res_addr.sin_port);
+      if (getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen) != 0) {
+        return system::FailWithCode("getsockname failed.");
+      }
+      *p_out = ntohs(res_addr.sin_port);
     }
+
+    return Success();
   }

   [[nodiscard]] auto Port() const {
@@ -631,26 +669,49 @@ class TCPSocket {
    */
   std::size_t Send(StringView str);
   /**
-   * \brief Receive string, format is matched with the Python socket wrapper in RABIT.
+   * @brief Receive string, format is matched with the Python socket wrapper in RABIT.
    */
-  std::size_t Recv(std::string *p_str);
+  [[nodiscard]] Result Recv(std::string *p_str);
   /**
-   * \brief Close the socket, called automatically in destructor if the socket is not closed.
+   * @brief Close the socket, called automatically in destructor if the socket is not closed.
    */
-  void Close() {
+  [[nodiscard]] Result Close() {
     if (InvalidSocket() != handle_) {
-#if defined(_WIN32)
       auto rc = system::CloseSocket(handle_);
+#if defined(_WIN32)
       // it's possible that we close TCP sockets after finalizing WSA due to detached thread.
       if (rc != 0 && system::LastError() != WSANOTINITIALISED) {
-        system::ThrowAtError("close", rc);
+        return system::FailWithCode("Failed to close the socket.");
       }
 #else
-      xgboost_CHECK_SYS_CALL(system::CloseSocket(handle_), 0);
+      if (rc != 0) {
+        return system::FailWithCode("Failed to close the socket.");
+      }
 #endif
       handle_ = InvalidSocket();
     }
+    return Success();
   }
+  /**
+   * @brief Call shutdown on the socket.
+   */
+  [[nodiscard]] Result Shutdown() {
+    if (this->IsClosed()) {
+      return Success();
+    }
+    auto rc = system::ShutdownSocket(this->Handle());
+#if defined(_WIN32)
+    // Windows cannot shutdown a socket if it's not connected.
+    if (rc == -1 && system::LastError() == WSAENOTCONN) {
+      return Success();
+    }
+#endif
+    if (rc != 0) {
+      return system::FailWithCode("Failed to shutdown socket.");
+    }
+    return Success();
+  }

   /**
    * \brief Create a TCP socket on specified domain.
    */
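The socket methods that previously aborted through xgboost_CHECK_SYS_CALL now return Result, so server setup becomes a chain the caller must consume. A sketch under the assumption that the Create() factory hinted at by the trailing comment takes a SockDomain, and that kV4 exists alongside the kV6 shown above:

// Sketch: bind to an OS-chosen port, listen, then shut down and close,
// surfacing every failure through Result instead of aborting.
#include <cstdint>
#include "xgboost/collective/result.h"
#include "xgboost/collective/socket.h"

xgboost::collective::Result Serve() {
  using namespace xgboost::collective;  // NOLINT
  auto sock = TCPSocket::Create(SockDomain::kV4);  // assumed factory
  std::int32_t port{0};
  auto rc = sock.BindHost(&port)                   // port via out-parameter now
            << [&] { return sock.Listen(/*backlog=*/16); };
  if (!rc.OK()) {
    return rc;
  }
  // ... accept and serve connections ...
  return sock.Shutdown() << [&] { return sock.Close(); };
}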
@@ -19,7 +19,6 @@
 #include <algorithm>
 #include <limits>
 #include <memory>
-#include <numeric>
 #include <string>
 #include <utility>
 #include <vector>
@@ -137,14 +136,6 @@ class MetaInfo {
    * \param fo The output stream.
    */
  void SaveBinary(dmlc::Stream* fo) const;
-  /*!
-   * \brief Set information in the meta info.
-   * \param key The key of the information.
-   * \param dptr The data pointer of the source array.
-   * \param dtype The type of the source data.
-   * \param num Number of elements in the source array.
-   */
-  void SetInfo(Context const& ctx, const char* key, const void* dptr, DataType dtype, size_t num);
   /*!
    * \brief Set information in the meta info with array interface.
    * \param key The key of the information.
@@ -517,10 +508,6 @@ class DMatrix {
   DMatrix() = default;
   /*! \brief meta information of the dataset */
   virtual MetaInfo& Info() = 0;
-  virtual void SetInfo(const char* key, const void* dptr, DataType dtype, size_t num) {
-    auto const& ctx = *this->Ctx();
-    this->Info().SetInfo(ctx, key, dptr, dtype, num);
-  }
   virtual void SetInfo(const char* key, std::string const& interface_str) {
     auto const& ctx = *this->Ctx();
     this->Info().SetInfo(ctx, key, StringView{interface_str});
@@ -190,13 +190,14 @@ constexpr auto ArrToTuple(T (&arr)[N]) {
 // uint division optimization inspired by the CIndexer in cupy. Division operation is
 // slow on both CPU and GPU, especially 64 bit integer. So here we first try to avoid 64
 // bit when the index is smaller, then try to avoid division when it's exp of 2.
-template <typename I, int32_t D>
+template <typename I, std::int32_t D>
 LINALG_HD auto UnravelImpl(I idx, common::Span<size_t const, D> shape) {
-  size_t index[D]{0};
+  std::size_t index[D]{0};
   static_assert(std::is_signed<decltype(D)>::value,
                 "Don't change the type without changing the for loop.");
+  auto const sptr = shape.data();
   for (int32_t dim = D; --dim > 0;) {
-    auto s = static_cast<std::remove_const_t<std::remove_reference_t<I>>>(shape[dim]);
+    auto s = static_cast<std::remove_const_t<std::remove_reference_t<I>>>(sptr[dim]);
     if (s & (s - 1)) {
       auto t = idx / s;
       index[dim] = idx - t * s;
@@ -745,6 +746,14 @@ auto ArrayInterfaceStr(TensorView<T, D> const &t) {
   return str;
 }

+template <typename T>
+auto Make1dInterface(T const *vec, std::size_t len) {
+  Context ctx;
+  auto t = linalg::MakeTensorView(&ctx, common::Span{vec, len}, len);
+  auto str = linalg::ArrayInterfaceStr(t);
+  return str;
+}
+
 /**
  * \brief A tensor storage. To use it for other functionality like slicing one needs to
  *        obtain a view first. This way we can use it on both host and device.
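Make1dInterface is the bridge used by the JNI changes later in this diff: it renders a host pointer plus length as the JSON array-interface string that XGDMatrixSetInfoFromInterface consumes. A short usage sketch (the commented output is approximate):

// Sketch: produce a 1-d array-interface document for a host buffer.
#include <iostream>
#include <vector>
#include "xgboost/linalg.h"

int main() {
  std::vector<float> weights{0.5f, 1.0f, 2.0f};
  auto str = xgboost::linalg::Make1dInterface(weights.data(), weights.size());
  // Roughly: {"data": [<address>, true], "shape": [3], "typestr": "<f4", ...}
  std::cout << str << std::endl;
  return 0;
}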
@@ -30,9 +30,8 @@
 #define XGBOOST_SPAN_H_

 #include <xgboost/base.h>
-#include <xgboost/logging.h>

-#include <cinttypes>  // size_t
+#include <cstddef>    // size_t
 #include <cstdio>
 #include <iterator>
 #include <limits>     // numeric_limits
@@ -75,8 +74,7 @@

 #endif  // defined(_MSC_VER) && _MSC_VER < 1910

-namespace xgboost {
-namespace common {
+namespace xgboost::common {

 #if defined(__CUDA_ARCH__)
 // Usual logging facility is not available inside device code.
@@ -744,8 +742,8 @@ class IterSpan {
     return it_ + size();
   }
 };
-}  // namespace common
-}  // namespace xgboost
+}  // namespace xgboost::common

 #if defined(_MSC_VER) &&_MSC_VER < 1910
 #undef constexpr
@@ -33,7 +33,7 @@
     <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
     <maven.compiler.source>1.8</maven.compiler.source>
     <maven.compiler.target>1.8</maven.compiler.target>
-    <flink.version>1.18.0</flink.version>
+    <flink.version>1.19.0</flink.version>
     <junit.version>4.13.2</junit.version>
     <spark.version>3.4.1</spark.version>
     <spark.version.gpu>3.4.1</spark.version.gpu>
@@ -46,9 +46,9 @@
     <cudf.version>23.12.1</cudf.version>
    <spark.rapids.version>23.12.1</spark.rapids.version>
     <cudf.classifier>cuda12</cudf.classifier>
+    <scalatest.version>3.2.18</scalatest.version>
+    <scala-collection-compat.version>2.12.0</scala-collection-compat.version>
     <use.hip>OFF</use.hip>
-    <scalatest.version>3.2.17</scalatest.version>
-    <scala-collection-compat.version>2.11.0</scala-collection-compat.version>

     <!-- SPARK-36796 for JDK-17 test-->
     <extraJavaTestArgs>
@@ -124,7 +124,7 @@
       <plugin>
         <groupId>org.apache.maven.plugins</groupId>
         <artifactId>maven-jar-plugin</artifactId>
-        <version>3.3.0</version>
+        <version>3.4.0</version>
         <executions>
           <execution>
             <id>empty-javadoc-jar</id>
@@ -153,7 +153,7 @@
       <plugin>
         <groupId>org.apache.maven.plugins</groupId>
         <artifactId>maven-gpg-plugin</artifactId>
-        <version>3.1.0</version>
+        <version>3.2.3</version>
         <executions>
           <execution>
             <id>sign-artifacts</id>
@@ -167,7 +167,7 @@
       <plugin>
         <groupId>org.apache.maven.plugins</groupId>
         <artifactId>maven-source-plugin</artifactId>
-        <version>3.3.0</version>
+        <version>3.3.1</version>
         <executions>
           <execution>
             <id>attach-sources</id>
@@ -205,7 +205,7 @@
       <plugin>
         <groupId>org.apache.maven.plugins</groupId>
         <artifactId>maven-assembly-plugin</artifactId>
-        <version>3.6.0</version>
+        <version>3.7.1</version>
         <configuration>
           <descriptorRefs>
             <descriptorRef>jar-with-dependencies</descriptorRef>
@@ -446,7 +446,7 @@
       <plugin>
         <groupId>org.apache.maven.plugins</groupId>
         <artifactId>maven-surefire-plugin</artifactId>
-        <version>3.2.2</version>
+        <version>3.2.5</version>
         <configuration>
           <skipTests>false</skipTests>
           <useSystemClassLoader>false</useSystemClassLoader>
@@ -488,12 +488,12 @@
     <dependency>
       <groupId>com.esotericsoftware</groupId>
       <artifactId>kryo</artifactId>
-      <version>5.5.0</version>
+      <version>5.6.0</version>
     </dependency>
     <dependency>
       <groupId>commons-logging</groupId>
       <artifactId>commons-logging</artifactId>
-      <version>1.3.0</version>
+      <version>1.3.1</version>
     </dependency>
     <dependency>
       <groupId>org.scalatest</groupId>
@@ -72,7 +72,7 @@
       <plugin>
         <groupId>org.apache.maven.plugins</groupId>
         <artifactId>maven-javadoc-plugin</artifactId>
-        <version>3.6.2</version>
+        <version>3.6.3</version>
         <configuration>
           <show>protected</show>
           <nohelp>true</nohelp>
@@ -88,7 +88,7 @@
       <plugin>
         <artifactId>exec-maven-plugin</artifactId>
         <groupId>org.codehaus.mojo</groupId>
-        <version>3.1.0</version>
+        <version>3.2.0</version>
         <executions>
           <execution>
             <id>native</id>
@@ -115,7 +115,7 @@
       <plugin>
         <groupId>org.apache.maven.plugins</groupId>
         <artifactId>maven-jar-plugin</artifactId>
-        <version>3.3.0</version>
+        <version>3.4.0</version>
         <executions>
           <execution>
             <goals>
@@ -22,7 +22,7 @@ pom_template = """
     <scala.version>{scala_version}</scala.version>
     <scalatest.version>3.2.15</scalatest.version>
     <scala.binary.version>{scala_binary_version}</scala.binary.version>
-    <kryo.version>5.5.0</kryo.version>
+    <kryo.version>5.6.0</kryo.version>
   </properties>

   <dependencies>
@@ -60,7 +60,7 @@
       <plugin>
         <groupId>org.apache.maven.plugins</groupId>
         <artifactId>maven-javadoc-plugin</artifactId>
-        <version>3.6.2</version>
+        <version>3.6.3</version>
         <configuration>
           <show>protected</show>
           <nohelp>true</nohelp>
@@ -76,7 +76,7 @@
       <plugin>
         <artifactId>exec-maven-plugin</artifactId>
         <groupId>org.codehaus.mojo</groupId>
-        <version>3.1.0</version>
+        <version>3.2.0</version>
         <executions>
           <execution>
             <id>native</id>
@@ -99,7 +99,7 @@
       <plugin>
         <groupId>org.apache.maven.plugins</groupId>
         <artifactId>maven-jar-plugin</artifactId>
-        <version>3.3.0</version>
+        <version>3.4.0</version>
         <executions>
           <execution>
             <goals>
@@ -408,7 +408,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetFloatI

   jfloat* array = jenv->GetFloatArrayElements(jarray, NULL);
   bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray);
-  int ret = XGDMatrixSetFloatInfo(handle, field, (float const *)array, len);
+  auto str = xgboost::linalg::Make1dInterface(array, len);
+  int ret = XGDMatrixSetInfoFromInterface(handle, field, str.c_str());
   JVM_CHECK_CALL(ret);
   //release
   if (field) jenv->ReleaseStringUTFChars(jfield, field);
@@ -427,7 +428,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetUIntIn
   const char* field = jenv->GetStringUTFChars(jfield, 0);
   jint* array = jenv->GetIntArrayElements(jarray, NULL);
   bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray);
-  int ret = XGDMatrixSetUIntInfo(handle, (char const *)field, (unsigned int const *)array, len);
+  auto str = xgboost::linalg::Make1dInterface(array, len);
+  int ret = XGDMatrixSetInfoFromInterface(handle, field, str.c_str());
   JVM_CHECK_CALL(ret);
   //release
   if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field);
@@ -730,8 +732,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredictFr
   if (jmargin) {
     margin = jenv->GetFloatArrayElements(jmargin, nullptr);
     JVM_CHECK_CALL(XGProxyDMatrixCreate(&proxy));
-    JVM_CHECK_CALL(
-        XGDMatrixSetFloatInfo(proxy, "base_margin", margin, jenv->GetArrayLength(jmargin)));
+    auto str = xgboost::linalg::Make1dInterface(margin, jenv->GetArrayLength(jmargin));
+    JVM_CHECK_CALL(XGDMatrixSetInfoFromInterface(proxy, "base_margin", str.c_str()));
   }

   bst_ulong const *out_shape;
@@ -89,19 +89,15 @@ Coll *FederatedColl::MakeCUDAVar() {

 [[nodiscard]] Result FederatedColl::Broadcast(Comm const &comm, common::Span<std::int8_t> data,
                                               std::int32_t root) {
-  if (comm.Rank() == root) {
-    return BroadcastImpl(comm, &this->sequence_number_, data, root);
-  } else {
-    return BroadcastImpl(comm, &this->sequence_number_, data, root);
-  }
+  return BroadcastImpl(comm, &this->sequence_number_, data, root);
 }

-[[nodiscard]] Result FederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data,
-                                              std::int64_t size) {
+[[nodiscard]] Result FederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data) {
   using namespace federated;  // NOLINT
   auto fed = dynamic_cast<FederatedComm const *>(&comm);
   CHECK(fed);
   auto stub = fed->Handle();
+  auto size = data.size_bytes() / comm.World();

   auto offset = comm.Rank() * size;
   auto segment = data.subspan(offset, size);
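Dropping the explicit size parameter makes the Allgather contract implicit: the span must cover world * per_rank bytes, and the per-rank segment is derived as data.size_bytes() / comm.World(). A caller-side sketch under that contract; the internal comm/coll headers needed for Comm and Coll are elided here:

// Sketch: the buffer spans every rank; this rank's slice is pre-filled and
// the collective fills in the rest.
#include <cstdint>
#include <cstring>
#include <vector>

xgboost::collective::Result GatherRanks(xgboost::collective::Comm const &comm,
                                        xgboost::collective::Coll *coll) {
  std::size_t const per_rank = sizeof(std::int32_t);
  std::vector<std::int8_t> buf(per_rank * comm.World(), 0);
  auto rank = static_cast<std::int32_t>(comm.Rank());
  std::memcpy(buf.data() + comm.Rank() * per_rank, &rank, per_rank);
  return coll->Allgather(comm, xgboost::common::Span<std::int8_t>{buf.data(), buf.size()});
}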
@@ -53,8 +53,7 @@ Coll *FederatedColl::MakeCUDAVar() {
   };
 }

-[[nodiscard]] Result CUDAFederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data,
-                                                  std::int64_t size) {
+[[nodiscard]] Result CUDAFederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data) {
   auto cufed = dynamic_cast<CUDAFederatedComm const *>(&comm);
   CHECK(cufed);
   std::vector<std::int8_t> h_data(data.size());
@@ -63,7 +62,7 @@ Coll *FederatedColl::MakeCUDAVar() {
     return GetCUDAResult(
         cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost));
   } << [&] {
-    return p_impl_->Allgather(comm, common::Span{h_data.data(), h_data.size()}, size);
+    return p_impl_->Allgather(comm, common::Span{h_data.data(), h_data.size()});
   } << [&] {
     return GetCUDAResult(cudaMemcpyAsync(data.data(), h_data.data(), data.size(),
                                          cudaMemcpyHostToDevice, cufed->Stream()));
@@ -1,5 +1,5 @@
 /**
- * Copyright 2023, XGBoost contributors
+ * Copyright 2023-2024, XGBoost contributors
  */
 #include "../../src/collective/comm.h"  // for Comm, Coll
 #include "federated_coll.h"             // for FederatedColl
@@ -16,8 +16,7 @@ class CUDAFederatedColl : public Coll {
                                  ArrayInterfaceHandler::Type type, Op op) override;
   [[nodiscard]] Result Broadcast(Comm const &comm, common::Span<std::int8_t> data,
                                  std::int32_t root) override;
-  [[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data,
-                                 std::int64_t size) override;
+  [[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data) override;
   [[nodiscard]] Result AllgatherV(Comm const &comm, common::Span<std::int8_t const> data,
                                   common::Span<std::int64_t const> sizes,
                                   common::Span<std::int64_t> recv_segments,
@@ -1,12 +1,9 @@
 /**
- * Copyright 2023, XGBoost contributors
+ * Copyright 2023-2024, XGBoost contributors
  */
 #pragma once
 #include "../../src/collective/coll.h"    // for Coll
 #include "../../src/collective/comm.h"    // for Comm
-#include "../../src/common/io.h"          // for ReadAll
-#include "../../src/common/json_utils.h"  // for OptionalArg
-#include "xgboost/json.h"                 // for Json

 namespace xgboost::collective {
 class FederatedColl : public Coll {
@@ -20,8 +17,7 @@ class FederatedColl : public Coll {
                                  ArrayInterfaceHandler::Type type, Op op) override;
   [[nodiscard]] Result Broadcast(Comm const &comm, common::Span<std::int8_t> data,
                                  std::int32_t root) override;
-  [[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data,
-                                 std::int64_t) override;
+  [[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data) override;
   [[nodiscard]] Result AllgatherV(Comm const &comm, common::Span<std::int8_t const> data,
                                   common::Span<std::int64_t const> sizes,
                                   common::Span<std::int64_t> recv_segments,
|||||||
@ -1,5 +1,5 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2023, XGBoost Contributors
|
* Copyright 2023-2024, XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
@ -9,7 +9,6 @@
|
|||||||
#include "../../src/common/device_helpers.cuh" // for CUDAStreamView
|
#include "../../src/common/device_helpers.cuh" // for CUDAStreamView
|
||||||
#include "federated_comm.h" // for FederatedComm
|
#include "federated_comm.h" // for FederatedComm
|
||||||
#include "xgboost/context.h" // for Context
|
#include "xgboost/context.h" // for Context
|
||||||
#include "xgboost/logging.h"
|
|
||||||
|
|
||||||
namespace xgboost::collective {
|
namespace xgboost::collective {
|
||||||
class CUDAFederatedComm : public FederatedComm {
|
class CUDAFederatedComm : public FederatedComm {
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2023, XGBoost contributors
|
* Copyright 2023-2024, XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
@ -11,7 +11,6 @@
|
|||||||
#include <string> // for string
|
#include <string> // for string
|
||||||
|
|
||||||
#include "../../src/collective/comm.h" // for HostComm
|
#include "../../src/collective/comm.h" // for HostComm
|
||||||
#include "../../src/common/json_utils.h" // for OptionalArg
|
|
||||||
#include "xgboost/json.h"
|
#include "xgboost/json.h"
|
||||||
|
|
||||||
namespace xgboost::collective {
|
namespace xgboost::collective {
|
||||||
@ -51,6 +50,10 @@ class FederatedComm : public HostComm {
|
|||||||
std::int32_t rank) {
|
std::int32_t rank) {
|
||||||
this->Init(host, port, world, rank, {}, {}, {});
|
this->Init(host, port, world, rank, {}, {}, {});
|
||||||
}
|
}
|
||||||
|
[[nodiscard]] Result Shutdown() final {
|
||||||
|
this->ResetState();
|
||||||
|
return Success();
|
||||||
|
}
|
||||||
~FederatedComm() override { stub_.reset(); }
|
~FederatedComm() override { stub_.reset(); }
|
||||||
|
|
||||||
[[nodiscard]] std::shared_ptr<Channel> Chan(std::int32_t) const override {
|
[[nodiscard]] std::shared_ptr<Channel> Chan(std::int32_t) const override {
|
||||||
@ -65,5 +68,13 @@ class FederatedComm : public HostComm {
|
|||||||
[[nodiscard]] federated::Federated::Stub* Handle() const { return stub_.get(); }
|
[[nodiscard]] federated::Federated::Stub* Handle() const { return stub_.get(); }
|
||||||
|
|
||||||
[[nodiscard]] Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const override;
|
[[nodiscard]] Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const override;
|
||||||
|
/**
|
||||||
|
* @brief Get a string ID for the current process.
|
||||||
|
*/
|
||||||
|
[[nodiscard]] Result ProcessorName(std::string* out) const final {
|
||||||
|
auto rank = this->Rank();
|
||||||
|
*out = "rank:" + std::to_string(rank);
|
||||||
|
return Success();
|
||||||
|
};
|
||||||
};
|
};
|
||||||
} // namespace xgboost::collective
|
} // namespace xgboost::collective
|
||||||
|
|||||||
@ -1,22 +1,18 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2022-2023, XGBoost contributors
|
* Copyright 2022-2024, XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <federated.old.grpc.pb.h>
|
#include <federated.old.grpc.pb.h>
|
||||||
|
|
||||||
#include <cstdint> // for int32_t
|
#include <cstdint> // for int32_t
|
||||||
#include <future> // for future
|
|
||||||
|
|
||||||
#include "../../src/collective/in_memory_handler.h"
|
#include "../../src/collective/in_memory_handler.h"
|
||||||
#include "../../src/collective/tracker.h" // for Tracker
|
|
||||||
#include "xgboost/collective/result.h" // for Result
|
|
||||||
|
|
||||||
namespace xgboost::federated {
|
namespace xgboost::federated {
|
||||||
class FederatedService final : public Federated::Service {
|
class FederatedService final : public Federated::Service {
|
||||||
public:
|
public:
|
||||||
explicit FederatedService(std::int32_t world_size)
|
explicit FederatedService(std::int32_t world_size) : handler_{world_size} {}
|
||||||
: handler_{static_cast<std::size_t>(world_size)} {}
|
|
||||||
|
|
||||||
grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
|
grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
|
||||||
AllgatherReply* reply) override;
|
AllgatherReply* reply) override;
|
||||||
|
|||||||
@@ -125,14 +125,14 @@ Result FederatedTracker::Shutdown() {

 [[nodiscard]] Json FederatedTracker::WorkerArgs() const {
   auto rc = this->WaitUntilReady();
-  CHECK(rc.OK()) << rc.Report();
+  SafeColl(rc);

   std::string host;
   rc = GetHostAddress(&host);
   CHECK(rc.OK());
   Json args{Object{}};
-  args["DMLC_TRACKER_URI"] = String{host};
-  args["DMLC_TRACKER_PORT"] = this->Port();
+  args["dmlc_tracker_uri"] = String{host};
+  args["dmlc_tracker_port"] = this->Port();
   return args;
 }
 }  // namespace xgboost::collective
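The lower-cased keys change what workers read back from WorkerArgs(). A sketch of consuming the renamed fields through xgboost's Json accessors; the values in the comment are illustrative:

// Sketch: read the renamed tracker keys from the WorkerArgs() document.
#include <iostream>
#include "xgboost/json.h"

void PrintWorkerArgs(xgboost::Json args) {
  auto const &uri = xgboost::get<xgboost::String const>(args["dmlc_tracker_uri"]);
  auto const &port = xgboost::get<xgboost::Integer const>(args["dmlc_tracker_port"]);
  std::cout << uri << ":" << port << std::endl;  // e.g. 192.0.2.1:9091
}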
@@ -17,8 +17,7 @@ namespace xgboost::collective {
 namespace federated {
 class FederatedService final : public Federated::Service {
  public:
-  explicit FederatedService(std::int32_t world_size)
-      : handler_{static_cast<std::size_t>(world_size)} {}
+  explicit FederatedService(std::int32_t world_size) : handler_{world_size} {}

   grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
                          AllgatherReply* reply) override;
334
plugin/sycl/common/hist_util.cc
Normal file
@@ -0,0 +1,334 @@
+/*!
+ * Copyright 2017-2023 by Contributors
+ * \file hist_util.cc
+ */
+#include <vector>
+#include <limits>
+#include <algorithm>
+
+#include "../data/gradient_index.h"
+#include "hist_util.h"
+
+#include <CL/sycl.hpp>
+
+namespace xgboost {
+namespace sycl {
+namespace common {
+
+/*!
+ * \brief Fill histogram with zeroes
+ */
+template<typename GradientSumT>
+void InitHist(::sycl::queue qu, GHistRow<GradientSumT, MemoryType::on_device>* hist,
+              size_t size, ::sycl::event* event) {
+  *event = qu.fill(hist->Begin(),
+                   xgboost::detail::GradientPairInternal<GradientSumT>(), size, *event);
+}
+template void InitHist(::sycl::queue qu,
+                       GHistRow<float, MemoryType::on_device>* hist,
+                       size_t size, ::sycl::event* event);
+template void InitHist(::sycl::queue qu,
+                       GHistRow<double, MemoryType::on_device>* hist,
+                       size_t size, ::sycl::event* event);
+
+/*!
+ * \brief Compute Subtraction: dst = src1 - src2
+ */
+template<typename GradientSumT>
+::sycl::event SubtractionHist(::sycl::queue qu,
+                              GHistRow<GradientSumT, MemoryType::on_device>* dst,
+                              const GHistRow<GradientSumT, MemoryType::on_device>& src1,
+                              const GHistRow<GradientSumT, MemoryType::on_device>& src2,
+                              size_t size, ::sycl::event event_priv) {
+  GradientSumT* pdst = reinterpret_cast<GradientSumT*>(dst->Data());
+  const GradientSumT* psrc1 = reinterpret_cast<const GradientSumT*>(src1.DataConst());
+  const GradientSumT* psrc2 = reinterpret_cast<const GradientSumT*>(src2.DataConst());
+
+  auto event_final = qu.submit([&](::sycl::handler& cgh) {
+    cgh.depends_on(event_priv);
+    cgh.parallel_for<>(::sycl::range<1>(2 * size), [pdst, psrc1, psrc2](::sycl::item<1> pid) {
+      const size_t i = pid.get_id(0);
+      pdst[i] = psrc1[i] - psrc2[i];
+    });
+  });
+  return event_final;
+}
+template ::sycl::event SubtractionHist(::sycl::queue qu,
+                                       GHistRow<float, MemoryType::on_device>* dst,
+                                       const GHistRow<float, MemoryType::on_device>& src1,
+                                       const GHistRow<float, MemoryType::on_device>& src2,
+                                       size_t size, ::sycl::event event_priv);
+template ::sycl::event SubtractionHist(::sycl::queue qu,
+                                       GHistRow<double, MemoryType::on_device>* dst,
+                                       const GHistRow<double, MemoryType::on_device>& src1,
+                                       const GHistRow<double, MemoryType::on_device>& src2,
+                                       size_t size, ::sycl::event event_priv);
+
+// Kernel with buffer using
+template<typename FPType, typename BinIdxType, bool isDense>
+::sycl::event BuildHistKernel(::sycl::queue qu,
+                              const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
+                              const RowSetCollection::Elem& row_indices,
+                              const GHistIndexMatrix& gmat,
+                              GHistRow<FPType, MemoryType::on_device>* hist,
+                              GHistRow<FPType, MemoryType::on_device>* hist_buffer,
+                              ::sycl::event event_priv) {
+  const size_t size = row_indices.Size();
+  const size_t* rid = row_indices.begin;
+  const size_t n_columns = isDense ? gmat.nfeatures : gmat.row_stride;
+  const GradientPair::ValueT* pgh =
+      reinterpret_cast<const GradientPair::ValueT*>(gpair_device.DataConst());
+  const BinIdxType* gradient_index = gmat.index.data<BinIdxType>();
+  const uint32_t* offsets = gmat.index.Offset();
+  FPType* hist_data = reinterpret_cast<FPType*>(hist->Data());
+  const size_t nbins = gmat.nbins;
+
+  const size_t max_work_group_size =
+      qu.get_device().get_info<::sycl::info::device::max_work_group_size>();
+  const size_t work_group_size = n_columns < max_work_group_size ? n_columns : max_work_group_size;
+
+  const size_t max_nblocks = hist_buffer->Size() / (nbins * 2);
+  const size_t min_block_size = 128;
+  size_t nblocks = std::min(max_nblocks, size / min_block_size + !!(size % min_block_size));
+  const size_t block_size = size / nblocks + !!(size % nblocks);
+  FPType* hist_buffer_data = reinterpret_cast<FPType*>(hist_buffer->Data());
+
+  auto event_fill = qu.fill(hist_buffer_data, FPType(0), nblocks * nbins * 2, event_priv);
+  auto event_main = qu.submit([&](::sycl::handler& cgh) {
+    cgh.depends_on(event_fill);
+    cgh.parallel_for<>(::sycl::nd_range<2>(::sycl::range<2>(nblocks, work_group_size),
+                                           ::sycl::range<2>(1, work_group_size)),
+                       [=](::sycl::nd_item<2> pid) {
+      size_t block = pid.get_global_id(0);
+      size_t feat = pid.get_global_id(1);
+
+      FPType* hist_local = hist_buffer_data + block * nbins * 2;
+      for (size_t idx = 0; idx < block_size; ++idx) {
+        size_t i = block * block_size + idx;
+        if (i < size) {
+          const size_t icol_start = n_columns * rid[i];
+          const size_t idx_gh = rid[i];
+
+          pid.barrier(::sycl::access::fence_space::local_space);
+          const BinIdxType* gr_index_local = gradient_index + icol_start;
+
+          for (size_t j = feat; j < n_columns; j += work_group_size) {
+            uint32_t idx_bin = static_cast<uint32_t>(gr_index_local[j]);
+            if constexpr (isDense) {
+              idx_bin += offsets[j];
+            }
+            if (idx_bin < nbins) {
+              hist_local[2 * idx_bin] += pgh[2 * idx_gh];
+              hist_local[2 * idx_bin+1] += pgh[2 * idx_gh+1];
+            }
+          }
+        }
+      }
+    });
+  });
+
+  auto event_save = qu.submit([&](::sycl::handler& cgh) {
+    cgh.depends_on(event_main);
+    cgh.parallel_for<>(::sycl::range<1>(nbins), [=](::sycl::item<1> pid) {
+      size_t idx_bin = pid.get_id(0);
+
+      FPType gsum = 0.0f;
+      FPType hsum = 0.0f;
+
+      for (size_t j = 0; j < nblocks; ++j) {
+        gsum += hist_buffer_data[j * nbins * 2 + 2 * idx_bin];
+        hsum += hist_buffer_data[j * nbins * 2 + 2 * idx_bin + 1];
+      }
+
+      hist_data[2 * idx_bin] = gsum;
+      hist_data[2 * idx_bin + 1] = hsum;
+    });
+  });
+  return event_save;
+}
+
+// Kernel with atomic using
+template<typename FPType, typename BinIdxType, bool isDense>
+::sycl::event BuildHistKernel(::sycl::queue qu,
+                              const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
+                              const RowSetCollection::Elem& row_indices,
+                              const GHistIndexMatrix& gmat,
+                              GHistRow<FPType, MemoryType::on_device>* hist,
+                              ::sycl::event event_priv) {
+  const size_t size = row_indices.Size();
+  const size_t* rid = row_indices.begin;
+  const size_t n_columns = isDense ? gmat.nfeatures : gmat.row_stride;
+  const GradientPair::ValueT* pgh =
+      reinterpret_cast<const GradientPair::ValueT*>(gpair_device.DataConst());
+  const BinIdxType* gradient_index = gmat.index.data<BinIdxType>();
+  const uint32_t* offsets = gmat.index.Offset();
+  FPType* hist_data = reinterpret_cast<FPType*>(hist->Data());
+  const size_t nbins = gmat.nbins;
+
+  const size_t max_work_group_size =
+      qu.get_device().get_info<::sycl::info::device::max_work_group_size>();
+  const size_t feat_local = n_columns < max_work_group_size ? n_columns : max_work_group_size;
+
+  auto event_fill = qu.fill(hist_data, FPType(0), nbins * 2, event_priv);
+  auto event_main = qu.submit([&](::sycl::handler& cgh) {
+    cgh.depends_on(event_fill);
+    cgh.parallel_for<>(::sycl::range<2>(size, feat_local),
+                       [=](::sycl::item<2> pid) {
+      size_t i = pid.get_id(0);
+      size_t feat = pid.get_id(1);
+
+      const size_t icol_start = n_columns * rid[i];
+      const size_t idx_gh = rid[i];
+
+      const BinIdxType* gr_index_local = gradient_index + icol_start;
+
+      for (size_t j = feat; j < n_columns; j += feat_local) {
+        uint32_t idx_bin = static_cast<uint32_t>(gr_index_local[j]);
+        if constexpr (isDense) {
+          idx_bin += offsets[j];
+        }
+        if (idx_bin < nbins) {
+          AtomicRef<FPType> gsum(hist_data[2 * idx_bin]);
+          AtomicRef<FPType> hsum(hist_data[2 * idx_bin + 1]);
+          gsum.fetch_add(pgh[2 * idx_gh]);
+          hsum.fetch_add(pgh[2 * idx_gh + 1]);
+        }
+      }
+    });
+  });
+  return event_main;
+}
+
+template<typename FPType, typename BinIdxType>
+::sycl::event BuildHistDispatchKernel(
+    ::sycl::queue qu,
+    const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
+    const RowSetCollection::Elem& row_indices,
+    const GHistIndexMatrix& gmat,
+    GHistRow<FPType, MemoryType::on_device>* hist,
+    bool isDense,
+    GHistRow<FPType, MemoryType::on_device>* hist_buffer,
+    ::sycl::event events_priv,
+    bool force_atomic_use) {
+  const size_t size = row_indices.Size();
+  const size_t n_columns = isDense ? gmat.nfeatures : gmat.row_stride;
+  const size_t nbins = gmat.nbins;
+
+  // max cycle size, while atomics are still effective
+  const size_t max_cycle_size_atomics = nbins;
+  const size_t cycle_size = size;
+
+  // TODO(razdoburdin): replace the add-hock dispatching criteria by more sutable one
+  bool use_atomic = (size < nbins) || (gmat.max_num_bins == gmat.nbins / n_columns);
+
+  // force_atomic_use flag is used only for testing
+  use_atomic = use_atomic || force_atomic_use;
+  if (!use_atomic) {
+    if (isDense) {
+      return BuildHistKernel<FPType, BinIdxType, true>(qu, gpair_device, row_indices,
+                                                       gmat, hist, hist_buffer,
+                                                       events_priv);
+    } else {
+      return BuildHistKernel<FPType, uint32_t, false>(qu, gpair_device, row_indices,
+                                                      gmat, hist, hist_buffer,
+                                                      events_priv);
+    }
+  } else {
+    if (isDense) {
+      return BuildHistKernel<FPType, BinIdxType, true>(qu, gpair_device, row_indices,
+                                                       gmat, hist, events_priv);
+    } else {
+      return BuildHistKernel<FPType, uint32_t, false>(qu, gpair_device, row_indices,
+                                                      gmat, hist, events_priv);
    }
+  }
+}
+
+template<typename FPType>
+::sycl::event BuildHistKernel(::sycl::queue qu,
+                              const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
+                              const RowSetCollection::Elem& row_indices,
+                              const GHistIndexMatrix& gmat, const bool isDense,
+                              GHistRow<FPType, MemoryType::on_device>* hist,
+                              GHistRow<FPType, MemoryType::on_device>* hist_buffer,
+                              ::sycl::event event_priv,
+                              bool force_atomic_use) {
+  const bool is_dense = isDense;
+  switch (gmat.index.GetBinTypeSize()) {
+    case BinTypeSize::kUint8BinsTypeSize:
+      return BuildHistDispatchKernel<FPType, uint8_t>(qu, gpair_device, row_indices,
+                                                      gmat, hist, is_dense, hist_buffer,
+                                                      event_priv, force_atomic_use);
+      break;
+    case BinTypeSize::kUint16BinsTypeSize:
+      return BuildHistDispatchKernel<FPType, uint16_t>(qu, gpair_device, row_indices,
+                                                       gmat, hist, is_dense, hist_buffer,
+                                                       event_priv, force_atomic_use);
+      break;
+    case BinTypeSize::kUint32BinsTypeSize:
+      return BuildHistDispatchKernel<FPType, uint32_t>(qu, gpair_device, row_indices,
+                                                       gmat, hist, is_dense, hist_buffer,
+                                                       event_priv, force_atomic_use);
+      break;
+    default:
+      CHECK(false);  // no default behavior
+  }
+}
+
+template <typename GradientSumT>
+::sycl::event GHistBuilder<GradientSumT>::BuildHist(
+    const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
+    const RowSetCollection::Elem& row_indices,
+    const GHistIndexMatrix &gmat,
+    GHistRowT<MemoryType::on_device>* hist,
+    bool isDense,
+    GHistRowT<MemoryType::on_device>* hist_buffer,
+    ::sycl::event event_priv,
+    bool force_atomic_use) {
+  return BuildHistKernel<GradientSumT>(qu_, gpair_device, row_indices, gmat,
+                                       isDense, hist, hist_buffer, event_priv,
+                                       force_atomic_use);
+}
+
+template
+::sycl::event GHistBuilder<float>::BuildHist(
+    const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
+    const RowSetCollection::Elem& row_indices,
+    const GHistIndexMatrix& gmat,
+    GHistRow<float, MemoryType::on_device>* hist,
+    bool isDense,
+    GHistRow<float, MemoryType::on_device>* hist_buffer,
|
||||||
|
::sycl::event event_priv,
|
||||||
|
bool force_atomic_use);
|
||||||
|
template
|
||||||
|
::sycl::event GHistBuilder<double>::BuildHist(
|
||||||
|
const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
|
||||||
|
const RowSetCollection::Elem& row_indices,
|
||||||
|
const GHistIndexMatrix& gmat,
|
||||||
|
GHistRow<double, MemoryType::on_device>* hist,
|
||||||
|
bool isDense,
|
||||||
|
GHistRow<double, MemoryType::on_device>* hist_buffer,
|
||||||
|
::sycl::event event_priv,
|
||||||
|
bool force_atomic_use);
|
||||||
|
|
||||||
|
template<typename GradientSumT>
|
||||||
|
void GHistBuilder<GradientSumT>::SubtractionTrick(GHistRowT<MemoryType::on_device>* self,
|
||||||
|
const GHistRowT<MemoryType::on_device>& sibling,
|
||||||
|
const GHistRowT<MemoryType::on_device>& parent) {
|
||||||
|
const size_t size = self->Size();
|
||||||
|
CHECK_EQ(sibling.Size(), size);
|
||||||
|
CHECK_EQ(parent.Size(), size);
|
||||||
|
|
||||||
|
SubtractionHist(qu_, self, parent, sibling, size, ::sycl::event());
|
||||||
|
}
|
||||||
|
template
|
||||||
|
void GHistBuilder<float>::SubtractionTrick(GHistRow<float, MemoryType::on_device>* self,
|
||||||
|
const GHistRow<float, MemoryType::on_device>& sibling,
|
||||||
|
const GHistRow<float, MemoryType::on_device>& parent);
|
||||||
|
template
|
||||||
|
void GHistBuilder<double>::SubtractionTrick(GHistRow<double, MemoryType::on_device>* self,
|
||||||
|
const GHistRow<double, MemoryType::on_device>& sibling,
|
||||||
|
const GHistRow<double, MemoryType::on_device>& parent);
|
||||||
|
} // namespace common
|
||||||
|
} // namespace sycl
|
||||||
|
} // namespace xgboost
|
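The dispatch in BuildHistDispatchKernel chooses between the two kernels above: atomics are preferred when there are fewer rows than bins, since each bin then sees little contention, while the buffered kernel pays for private per-work-group histograms plus a reduction. A minimal standalone sketch of that decision, using hypothetical names rather than the plugin's actual types:

#include <cstddef>

enum class HistStrategy { kAtomic, kBuffered };

// Sketch of the heuristic only; the real criterion in BuildHistDispatchKernel
// also inspects the bin layout (gmat.max_num_bins == gmat.nbins / n_columns).
HistStrategy ChooseHistStrategy(std::size_t n_rows, std::size_t n_bins,
                                bool force_atomic) {
  bool use_atomic = (n_rows < n_bins) || force_atomic;
  return use_atomic ? HistStrategy::kAtomic : HistStrategy::kBuffered;
}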
plugin/sycl/common/hist_util.h (new file, 89 lines)
@@ -0,0 +1,89 @@
/*!
 * Copyright 2017-2023 by Contributors
 * \file hist_util.h
 */
#ifndef PLUGIN_SYCL_COMMON_HIST_UTIL_H_
#define PLUGIN_SYCL_COMMON_HIST_UTIL_H_

#include <vector>
#include <unordered_map>
#include <memory>

#include "../data.h"
#include "row_set.h"

#include "../../src/common/hist_util.h"
#include "../data/gradient_index.h"

#include <CL/sycl.hpp>

namespace xgboost {
namespace sycl {
namespace common {

template<typename GradientSumT, MemoryType memory_type = MemoryType::shared>
using GHistRow = USMVector<xgboost::detail::GradientPairInternal<GradientSumT>, memory_type>;

using BinTypeSize = ::xgboost::common::BinTypeSize;

class ColumnMatrix;

/*!
 * \brief Fill histogram with zeroes
 */
template<typename GradientSumT>
void InitHist(::sycl::queue qu,
              GHistRow<GradientSumT, MemoryType::on_device>* hist,
              size_t size, ::sycl::event* event);

/*!
 * \brief Compute subtraction: dst = src1 - src2
 */
template<typename GradientSumT>
::sycl::event SubtractionHist(::sycl::queue qu,
                              GHistRow<GradientSumT, MemoryType::on_device>* dst,
                              const GHistRow<GradientSumT, MemoryType::on_device>& src1,
                              const GHistRow<GradientSumT, MemoryType::on_device>& src2,
                              size_t size, ::sycl::event event_priv);

/*!
 * \brief Builder for histograms of gradient statistics
 */
template<typename GradientSumT>
class GHistBuilder {
 public:
  template<MemoryType memory_type = MemoryType::shared>
  using GHistRowT = GHistRow<GradientSumT, memory_type>;

  GHistBuilder() = default;
  GHistBuilder(::sycl::queue qu, uint32_t nbins) : qu_{qu}, nbins_{nbins} {}

  // Construct a histogram via histogram aggregation
  ::sycl::event BuildHist(const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
                          const RowSetCollection::Elem& row_indices,
                          const GHistIndexMatrix& gmat,
                          GHistRowT<MemoryType::on_device>* HistCollection,
                          bool isDense,
                          GHistRowT<MemoryType::on_device>* hist_buffer,
                          ::sycl::event event,
                          bool force_atomic_use = false);

  // Construct a histogram via subtraction trick
  void SubtractionTrick(GHistRowT<MemoryType::on_device>* self,
                        const GHistRowT<MemoryType::on_device>& sibling,
                        const GHistRowT<MemoryType::on_device>& parent);

  uint32_t GetNumBins() const {
    return nbins_;
  }

 private:
  /*! \brief Number of all bins over all features */
  uint32_t nbins_ { 0 };

  ::sycl::queue qu_;
};
}  // namespace common
}  // namespace sycl
}  // namespace xgboost
#endif  // PLUGIN_SYCL_COMMON_HIST_UTIL_H_
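SubtractionTrick declared above relies on the invariant that every row of a parent node lands in exactly one child, so the parent histogram equals left plus right bin-wise and the sibling can be derived as parent minus self instead of a second scan. A host-side sketch of the same arithmetic, with std::vector standing in for GHistRow:

#include <cstddef>
#include <utility>
#include <vector>

using GradPair = std::pair<double, double>;  // (sum of gradients, sum of hessians)

// sibling[i] = parent[i] - self[i] for every bin i; this is what lets the
// builder compute one child by scanning rows and derive the other for free.
void SubtractionTrickHost(std::vector<GradPair>* sibling,
                          const std::vector<GradPair>& self,
                          const std::vector<GradPair>& parent) {
  sibling->resize(parent.size());
  for (std::size_t i = 0; i < parent.size(); ++i) {
    (*sibling)[i].first = parent[i].first - self[i].first;
    (*sibling)[i].second = parent[i].second - self[i].second;
  }
}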
plugin/sycl/tree/updater_quantile_hist.cc (new file, 55 lines)
@@ -0,0 +1,55 @@
/*!
 * Copyright 2017-2024 by Contributors
 * \file updater_quantile_hist.cc
 */
#include <vector>

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
#pragma GCC diagnostic ignored "-W#pragma-messages"
#include "xgboost/tree_updater.h"
#pragma GCC diagnostic pop

#include "xgboost/logging.h"

#include "updater_quantile_hist.h"
#include "../data.h"

namespace xgboost {
namespace sycl {
namespace tree {

DMLC_REGISTRY_FILE_TAG(updater_quantile_hist_sycl);

DMLC_REGISTER_PARAMETER(HistMakerTrainParam);

void QuantileHistMaker::Configure(const Args& args) {
  const DeviceOrd device_spec = ctx_->Device();
  qu_ = device_manager.GetQueue(device_spec);

  param_.UpdateAllowUnknown(args);
  hist_maker_param_.UpdateAllowUnknown(args);
}

void QuantileHistMaker::Update(xgboost::tree::TrainParam const *param,
                               linalg::Matrix<GradientPair>* gpair,
                               DMatrix *dmat,
                               xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
                               const std::vector<RegTree *> &trees) {
  LOG(FATAL) << "Not Implemented yet";
}

bool QuantileHistMaker::UpdatePredictionCache(const DMatrix* data,
                                              linalg::MatrixView<float> out_preds) {
  LOG(FATAL) << "Not Implemented yet";
}

XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker_sycl")
.describe("Grow tree using quantized histogram with SYCL.")
.set_body(
    [](Context const* ctx, ObjInfo const * task) {
      return new QuantileHistMaker(ctx, task);
    });
}  // namespace tree
}  // namespace sycl
}  // namespace xgboost
plugin/sycl/tree/updater_quantile_hist.h (new file, 91 lines)
@@ -0,0 +1,91 @@
/*!
 * Copyright 2017-2024 by Contributors
 * \file updater_quantile_hist.h
 */
#ifndef PLUGIN_SYCL_TREE_UPDATER_QUANTILE_HIST_H_
#define PLUGIN_SYCL_TREE_UPDATER_QUANTILE_HIST_H_

#include <dmlc/timer.h>
#include <xgboost/tree_updater.h>

#include <vector>

#include "../data/gradient_index.h"
#include "../common/hist_util.h"
#include "../common/row_set.h"
#include "../common/partition_builder.h"
#include "split_evaluator.h"
#include "../device_manager.h"

#include "xgboost/data.h"
#include "xgboost/json.h"
#include "../../src/tree/constraints.h"
#include "../../src/common/random.h"

namespace xgboost {
namespace sycl {
namespace tree {

// training parameters specific to this algorithm
struct HistMakerTrainParam
    : public XGBoostParameter<HistMakerTrainParam> {
  bool single_precision_histogram = false;
  // declare parameters
  DMLC_DECLARE_PARAMETER(HistMakerTrainParam) {
    DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe(
        "Use single precision to build histograms.");
  }
};

/*! \brief Construct a tree using quantized feature values with the SYCL backend */
class QuantileHistMaker: public TreeUpdater {
 public:
  QuantileHistMaker(Context const* ctx, ObjInfo const * task) :
                    TreeUpdater(ctx), task_{task} {
    updater_monitor_.Init("SYCLQuantileHistMaker");
  }
  void Configure(const Args& args) override;

  void Update(xgboost::tree::TrainParam const *param,
              linalg::Matrix<GradientPair>* gpair,
              DMatrix* dmat,
              xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
              const std::vector<RegTree*>& trees) override;

  bool UpdatePredictionCache(const DMatrix* data,
                             linalg::MatrixView<float> out_preds) override;

  void LoadConfig(Json const& in) override {
    auto const& config = get<Object const>(in);
    FromJson(config.at("train_param"), &this->param_);
    FromJson(config.at("sycl_hist_train_param"), &this->hist_maker_param_);
  }

  void SaveConfig(Json* p_out) const override {
    auto& out = *p_out;
    out["train_param"] = ToJson(param_);
    out["sycl_hist_train_param"] = ToJson(hist_maker_param_);
  }

  char const* Name() const override {
    return "grow_quantile_histmaker_sycl";
  }

 protected:
  HistMakerTrainParam hist_maker_param_;
  // training parameter
  xgboost::tree::TrainParam param_;

  xgboost::common::Monitor updater_monitor_;

  ::sycl::queue qu_;
  DeviceManager device_manager;
  ObjInfo const *task_{nullptr};
};

}  // namespace tree
}  // namespace sycl
}  // namespace xgboost

#endif  // PLUGIN_SYCL_TREE_UPDATER_QUANTILE_HIST_H_
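The class above registers itself under the name grow_quantile_histmaker_sycl. As a rough usage sketch only (this is not part of the commit, error handling is omitted, and the DMatrix setup is assumed to follow the standard XGBoost C API flow), a registered updater can be selected by name through the public C API:

#include <xgboost/c_api.h>

// Assumes dtrain was created with one of the XGDMatrixCreate* functions.
void TrainWithSyclUpdater(DMatrixHandle dtrain) {
  BoosterHandle booster;
  XGBoosterCreate(&dtrain, 1, &booster);
  XGBoosterSetParam(booster, "updater", "grow_quantile_histmaker_sycl");
  for (int iter = 0; iter < 10; ++iter) {
    XGBoosterUpdateOneIter(booster, iter, dtrain);
  }
  XGBoosterFree(booster);
}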
@@ -909,9 +909,19 @@ def _transform_cudf_df(
     enable_categorical: bool,
 ) -> Tuple[ctypes.c_void_p, list, Optional[FeatureNames], Optional[FeatureTypes]]:
     try:
-        from cudf.api.types import is_categorical_dtype
+        from cudf.api.types import is_bool_dtype, is_categorical_dtype
     except ImportError:
         from cudf.utils.dtypes import is_categorical_dtype
+        from pandas.api.types import is_bool_dtype
+
+    # Work around https://github.com/dmlc/xgboost/issues/10181
+    if _is_cudf_ser(data):
+        if is_bool_dtype(data.dtype):
+            data = data.astype(np.uint8)
+    else:
+        data = data.astype(
+            {col: np.uint8 for col in data.select_dtypes(include="bool")}
+        )
+
     if _is_cudf_ser(data):
         dtypes = [data.dtype]
@@ -347,15 +347,14 @@ class _SparkXGBParams(
             predict_params[param.name] = self.getOrDefault(param)
         return predict_params

-    def _validate_gpu_params(self) -> None:
+    def _validate_gpu_params(
+        self, spark_version: str, conf: SparkConf, is_local: bool = False
+    ) -> None:
         """Validate the gpu parameters and gpu configurations"""

         if self._run_on_gpu():
-            ss = _get_spark_session()
-            sc = ss.sparkContext
-
-            if _is_local(sc):
-                # Support GPU training in Spark local mode is just for debugging
+            if is_local:
+                # Supporting GPU training in Spark local mode is just for debugging
                 # purposes, so it's okay for printing the below warning instead of
                 # checking the real gpu numbers and raising the exception.
                 get_logger(self.__class__.__name__).warning(
@@ -364,33 +363,41 @@ class _SparkXGBParams(
                     self.getOrDefault(self.num_workers),
                 )
             else:
-                executor_gpus = sc.getConf().get("spark.executor.resource.gpu.amount")
+                executor_gpus = conf.get("spark.executor.resource.gpu.amount")
                 if executor_gpus is None:
                     raise ValueError(
                         "The `spark.executor.resource.gpu.amount` is required for training"
                         " on GPU."
                     )
-                if not (
-                    ss.version >= "3.4.0"
-                    and _is_standalone_or_localcluster(sc.getConf())
-                ):
-                    # We will enable stage-level scheduling in spark 3.4.0+ which doesn't
-                    # require spark.task.resource.gpu.amount to be set explicitly
-                    gpu_per_task = sc.getConf().get("spark.task.resource.gpu.amount")
+                gpu_per_task = conf.get("spark.task.resource.gpu.amount")
+                if gpu_per_task is not None and float(gpu_per_task) > 1.0:
+                    get_logger(self.__class__.__name__).warning(
+                        "The configuration assigns %s GPUs to each Spark task, but each "
+                        "XGBoost training task only utilizes 1 GPU, which will lead to "
+                        "unnecessary GPU waste",
+                        gpu_per_task,
+                    )
+                # For 3.5.1+, Spark supports task stage-level scheduling for
+                # Yarn/K8s/Standalone/Local cluster
+                # From 3.4.0 ~ 3.5.0, Spark only supports task stage-level scheduling for
+                # Standalone/Local cluster
+                # For spark below 3.4.0, task stage-level scheduling is not supported.
+                #
+                # With stage-level scheduling, spark.task.resource.gpu.amount is not required
+                # to be set explicitly. Or else, spark.task.resource.gpu.amount is a must-have
+                # and must be set to 1.0
+                if spark_version < "3.4.0" or (
+                    "3.4.0" <= spark_version < "3.5.1"
+                    and not _is_standalone_or_localcluster(conf)
+                ):
                     if gpu_per_task is not None:
                         if float(gpu_per_task) < 1.0:
                             raise ValueError(
-                                "XGBoost doesn't support GPU fractional configurations. "
-                                "Please set `spark.task.resource.gpu.amount=spark.executor"
-                                ".resource.gpu.amount`"
+                                "XGBoost doesn't support GPU fractional configurations. Please set "
+                                "`spark.task.resource.gpu.amount=spark.executor.resource.gpu."
+                                "amount`. To enable GPU fractional configurations, you can try "
+                                "standalone/localcluster with spark 3.4.0+ and "
+                                "YARN/K8S with spark 3.5.1+"
                             )
-                        if float(gpu_per_task) > 1.0:
-                            get_logger(self.__class__.__name__).warning(
-                                "%s GPUs for each Spark task is configured, but each "
-                                "XGBoost training task uses only 1 GPU.",
-                                gpu_per_task,
-                            )
                     else:
                         raise ValueError(
@@ -475,7 +482,9 @@ class _SparkXGBParams(
                 "`pyspark.ml.linalg.Vector` type."
             )

-        self._validate_gpu_params()
+        ss = _get_spark_session()
+        sc = ss.sparkContext
+        self._validate_gpu_params(ss.version, sc.getConf(), _is_local(sc))

     def _run_on_gpu(self) -> bool:
         """If train or transform on the gpu according to the parameters"""
@@ -925,10 +934,14 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
             )
             return True

-        if not _is_standalone_or_localcluster(conf):
+        if (
+            "3.4.0" <= spark_version < "3.5.1"
+            and not _is_standalone_or_localcluster(conf)
+        ):
             self.logger.info(
-                "Stage-level scheduling in xgboost requires spark standalone or "
-                "local-cluster mode"
+                "For %s, Stage-level scheduling in xgboost requires spark standalone "
+                "or local-cluster mode",
+                spark_version,
             )
             return True
@@ -980,7 +993,9 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
         """Try to enable stage-level scheduling"""
         ss = _get_spark_session()
         conf = ss.sparkContext.getConf()
-        if self._skip_stage_level_scheduling(ss.version, conf):
+        if _is_local(ss.sparkContext) or self._skip_stage_level_scheduling(
+            ss.version, conf
+        ):
             return rdd

         # executor_cores will not be None
@@ -1052,6 +1067,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):

         dev_ordinal = None
         use_qdm = _can_use_qdm(booster_params.get("tree_method", None))
+        verbosity = booster_params.get("verbosity", 1)
         msg = "Training on CPUs"
         if run_on_gpu:
             dev_ordinal = (
@@ -1089,6 +1105,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):

         evals_result: Dict[str, Any] = {}
         with CommunicatorContext(context, **_rabit_args):
+            with xgboost.config_context(verbosity=verbosity):
                 dtrain, dvalid = create_dmatrix_from_partitions(
                     pandas_df_iter,
                     feature_prop.features_cols_names,
@@ -14,7 +14,8 @@ import pyspark
 from pyspark import BarrierTaskContext, SparkConf, SparkContext, SparkFiles, TaskContext
 from pyspark.sql.session import SparkSession

-from xgboost import Booster, XGBModel, collective
+from xgboost import Booster, XGBModel
+from xgboost.collective import CommunicatorContext as CCtx
 from xgboost.tracker import RabitTracker
@@ -42,22 +43,12 @@ def _get_default_params_from_func(
     return filtered_params_dict


-class CommunicatorContext:
-    """A context controlling collective communicator initialization and finalization.
-
-    This isn't specificially necessary (note Part 3), but it is more understandable
-    coding-wise.
-
-    """
+class CommunicatorContext(CCtx):  # pylint: disable=too-few-public-methods
+    """Context with PySpark specific task ID."""

     def __init__(self, context: BarrierTaskContext, **args: Any) -> None:
-        self.args = args
-        self.args["DMLC_TASK_ID"] = str(context.partitionId())
-
-    def __enter__(self) -> None:
-        collective.init(**self.args)
-
-    def __exit__(self, *args: Any) -> None:
-        collective.finalize()
+        args["DMLC_TASK_ID"] = str(context.partitionId())
+        super().__init__(**args)


 def _start_tracker(context: BarrierTaskContext, n_workers: int) -> Dict[str, Any]:
@@ -429,8 +429,8 @@ def make_categorical(
     categories = np.arange(0, n_categories)
     for col in df.columns:
         if rng.binomial(1, cat_ratio, size=1)[0] == 1:
-            df.loc[:, col] = df[col].astype("category")
-            df.loc[:, col] = df[col].cat.set_categories(categories)
+            df[col] = df[col].astype("category")
+            df[col] = df[col].cat.set_categories(categories)

     if sparsity > 0.0:
         for i in range(n_features):
@@ -100,6 +100,24 @@ std::enable_if_t<std::is_integral_v<E>, xgboost::collective::Result> PollError(E
   if ((revents & POLLNVAL) != 0) {
     return xgboost::system::FailWithCode("Invalid polling request.");
   }
+  if ((revents & POLLHUP) != 0) {
+    // Excerpt from the Linux manual:
+    //
+    // Note that when reading from a channel such as a pipe or a stream socket, this event
+    // merely indicates that the peer closed its end of the channel. Subsequent reads from
+    // the channel will return 0 (end of file) only after all outstanding data in the
+    // channel has been consumed.
+    //
+    // We don't usually have a barrier for exiting workers, so it's normal to have one end
+    // exit while the other is still reading data.
+    return xgboost::collective::Success();
+  }
+#if defined(POLLRDHUP)
+  // Linux-only flag
+  if ((revents & POLLRDHUP) != 0) {
+    return xgboost::system::FailWithCode("Poll hung up on the other end.");
+  }
+#endif  // defined(POLLRDHUP)
   return xgboost::collective::Success();
 }
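The new POLLHUP branch encodes a subtle point: a hang-up by itself is not an error, because buffered data remains readable until recv() returns 0. A standalone sketch of the same triage using plain poll.h, not xgboost's Result machinery:

#include <poll.h>

// Returns 0 while the descriptor is still usable, -1 on a hard failure.
int TriageRevents(short revents) {
  if (revents & (POLLERR | POLLNVAL)) {
    return -1;  // genuine error: surface it to the caller
  }
  if (revents & POLLHUP) {
    return 0;   // peer closed its end; drain the socket until recv() == 0
  }
  return 0;
}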
@@ -179,9 +197,11 @@ struct PollHelper {
   }
   std::int32_t ret = PollImpl(fdset.data(), fdset.size(), timeout);
   if (ret == 0) {
-    return xgboost::collective::Fail("Poll timeout.", std::make_error_code(std::errc::timed_out));
+    return xgboost::collective::Fail(
+        "Poll timeout:" + std::to_string(timeout.count()) + " seconds.",
+        std::make_error_code(std::errc::timed_out));
   } else if (ret < 0) {
-    return xgboost::system::FailWithCode("Poll failed.");
+    return xgboost::system::FailWithCode("Poll failed, nfds:" + std::to_string(fdset.size()));
   }

   for (auto& pfd : fdset) {
@@ -132,7 +132,7 @@ bool AllreduceBase::Shutdown() {
   try {
     for (auto &all_link : all_links) {
       if (!all_link.sock.IsClosed()) {
-        all_link.sock.Close();
+        SafeColl(all_link.sock.Close());
       }
     }
     all_links.clear();
@@ -146,7 +146,7 @@ bool AllreduceBase::Shutdown() {
       LOG(FATAL) << rc.Report();
     }
     tracker.Send(xgboost::StringView{"shutdown"});
-    tracker.Close();
+    SafeColl(tracker.Close());
    xgboost::system::SocketFinalize();
     return true;
   } catch (std::exception const &e) {
@@ -167,7 +167,7 @@ void AllreduceBase::TrackerPrint(const std::string &msg) {

   tracker.Send(xgboost::StringView{"print"});
   tracker.Send(xgboost::StringView{msg});
-  tracker.Close();
+  SafeColl(tracker.Close());
 }

 // util to parse data with unit suffix
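Throughout these hunks, raw socket calls are wrapped in SafeColl, which turns an ignored return value into a checked Result. An illustrative sketch of the pattern only, assuming a simplified Result type (xgboost's real one lives in xgboost/collective/result.h):

#include <cstdlib>
#include <iostream>
#include <string>

struct Result {
  bool ok;
  std::string message;
};

// Sketch: check a collective Result immediately and fail loudly, so errors
// from calls like sock.Close() can no longer be dropped silently.
inline void SafeCollSketch(Result const& rc) {
  if (!rc.ok) {
    std::cerr << "collective error: " << rc.message << '\n';
    std::abort();
  }
}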
@@ -332,15 +332,15 @@ void AllreduceBase::SetParam(const char *name, const char *val) {

   auto sock_listen{xgboost::collective::TCPSocket::Create(tracker.Domain())};
   // create listening socket
-  int port = sock_listen.BindHost();
-  utils::Check(port != -1, "ReConnectLink fail to bind the ports specified");
-  sock_listen.Listen();
+  std::int32_t port{0};
+  SafeColl(sock_listen.BindHost(&port));
+  SafeColl(sock_listen.Listen());

   // get number of to connect and number of to accept nodes from tracker
   int num_conn, num_accept, num_error = 1;
   do {
     for (auto & all_link : all_links) {
-      all_link.sock.Close();
+      SafeColl(all_link.sock.Close());
     }
     // tracker construct goodset
     Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
@@ -352,7 +352,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
       LinkRecord r;
       int hport, hrank;
       std::string hname;
-      tracker.Recv(&hname);
+      SafeColl(tracker.Recv(&hname));
       Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "ReConnectLink failure 9");
       Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), "ReConnectLink failure 10");
       // connect to peer
@@ -360,7 +360,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
                            timeout_sec, &r.sock)
               .OK()) {
         num_error += 1;
-        r.sock.Close();
+        SafeColl(r.sock.Close());
         continue;
       }
       Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
@@ -386,7 +386,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
     // send back socket listening port to tracker
     Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port), "ReConnectLink failure 14");
     // close connection to tracker
-    tracker.Close();
+    SafeColl(tracker.Close());

     // listen to incoming links
     for (int i = 0; i < num_accept; ++i) {
@@ -408,7 +408,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
     }
     if (!match) all_links.emplace_back(std::move(r));
   }
-  sock_listen.Close();
+  SafeColl(sock_listen.Close());

   this->parent_index = -1;
   // setup tree links and ring structure
@@ -635,7 +635,7 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
         Recv(sendrecvbuf + size_down_in, total_size - size_down_in);

     if (len == 0) {
-      links[parent_index].sock.Close();
+      SafeColl(links[parent_index].sock.Close());
       return ReportError(&links[parent_index], kRecvZeroLen);
     }
     if (len != -1) {
@@ -270,7 +270,7 @@ class AllreduceBase : public IEngine {
       ssize_t len = sock.Recv(buffer_head + offset, nmax);
       // length equals 0, remote disconnected
       if (len == 0) {
-        sock.Close(); return kRecvZeroLen;
+        SafeColl(sock.Close()); return kRecvZeroLen;
       }
       if (len == -1) return Errno2Return();
       size_read += static_cast<size_t>(len);
@@ -289,7 +289,7 @@ class AllreduceBase : public IEngine {
       ssize_t len = sock.Recv(p + size_read, max_size - size_read);
       // length equals 0, remote disconnected
       if (len == 0) {
-        sock.Close(); return kRecvZeroLen;
+        SafeColl(sock.Close()); return kRecvZeroLen;
       }
       if (len == -1) return Errno2Return();
       size_read += static_cast<size_t>(len);
@ -1,5 +1,5 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2014-2024 by XGBoost Contributors
|
* Copyright 2014-2024, XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#include "xgboost/c_api.h"
|
#include "xgboost/c_api.h"
|
||||||
|
|
||||||
@ -617,8 +617,8 @@ XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle, const char *field, const
|
|||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
xgboost_CHECK_C_ARG_PTR(field);
|
xgboost_CHECK_C_ARG_PTR(field);
|
||||||
auto const& p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
|
auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
|
||||||
p_fmat->SetInfo(field, info, xgboost::DataType::kFloat32, len);
|
p_fmat->SetInfo(field, linalg::Make1dInterface(info, len));
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -637,8 +637,9 @@ XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle, const char *field, const
   API_BEGIN();
   CHECK_HANDLE();
   xgboost_CHECK_C_ARG_PTR(field);
+  LOG(WARNING) << error::DeprecatedFunc(__func__, "2.1.0", "XGDMatrixSetInfoFromInterface");
   auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
-  p_fmat->SetInfo(field, info, xgboost::DataType::kUInt32, len);
+  p_fmat->SetInfo(field, linalg::Make1dInterface(info, len));
   API_END();
 }
@@ -682,19 +683,52 @@ XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field, void
                                   xgboost::bst_ulong size, int type) {
   API_BEGIN();
   CHECK_HANDLE();
+  LOG(WARNING) << error::DeprecatedFunc(__func__, "2.1.0", "XGDMatrixSetInfoFromInterface");
   auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
   CHECK(type >= 1 && type <= 4);
   xgboost_CHECK_C_ARG_PTR(field);
-  p_fmat->SetInfo(field, data, static_cast<DataType>(type), size);
-  API_END();
-}
-
-XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle, const unsigned *group, xgboost::bst_ulong len) {
-  API_BEGIN();
-  CHECK_HANDLE();
-  LOG(WARNING) << "XGDMatrixSetGroup is deprecated, use `XGDMatrixSetUIntInfo` instead.";
-  auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
-  p_fmat->SetInfo("group", group, xgboost::DataType::kUInt32, len);
+
+  Context ctx;
+  auto dtype = static_cast<DataType>(type);
+  std::string str;
+  auto proc = [&](auto cast_d_ptr) {
+    using T = std::remove_pointer_t<decltype(cast_d_ptr)>;
+    auto t = linalg::TensorView<T, 1>(
+        common::Span<T>{cast_d_ptr, static_cast<typename common::Span<T>::index_type>(size)},
+        {size}, DeviceOrd::CPU());
+    CHECK(t.CContiguous());
+    Json iface{linalg::ArrayInterface(t)};
+    CHECK(ArrayInterface<1>{iface}.is_contiguous);
+    str = Json::Dump(iface);
+    return str;
+  };
+
+  // Legacy code using XGBoost dtype, which is a small subset of array interface types.
+  switch (dtype) {
+    case xgboost::DataType::kFloat32: {
+      auto cast_ptr = reinterpret_cast<const float *>(data);
+      p_fmat->Info().SetInfo(ctx, field, proc(cast_ptr));
+      break;
+    }
+    case xgboost::DataType::kDouble: {
+      auto cast_ptr = reinterpret_cast<const double *>(data);
+      p_fmat->Info().SetInfo(ctx, field, proc(cast_ptr));
+      break;
+    }
+    case xgboost::DataType::kUInt32: {
+      auto cast_ptr = reinterpret_cast<const uint32_t *>(data);
+      p_fmat->Info().SetInfo(ctx, field, proc(cast_ptr));
+      break;
+    }
+    case xgboost::DataType::kUInt64: {
+      auto cast_ptr = reinterpret_cast<const uint64_t *>(data);
+      p_fmat->Info().SetInfo(ctx, field, proc(cast_ptr));
+      break;
+    }
+    default:
+      LOG(FATAL) << "Unknown data type" << static_cast<uint8_t>(dtype);
+  }
+
   API_END();
 }
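The lambda above serializes the legacy pointer/dtype pair into an array-interface JSON document. For a float buffer the resulting string looks roughly like the sketch below (hypothetical helper, not the library's code; the field layout follows the __array_interface__ protocol):

#include <cstddef>
#include <cstdint>
#include <sstream>
#include <string>

// Produces e.g. {"data": [140234, true], "shape": [4], "typestr": "<f4", "version": 3}
std::string Make1dArrayInterfaceSketch(const float* data, std::size_t n) {
  std::ostringstream os;
  os << "{\"data\": [" << reinterpret_cast<std::uintptr_t>(data) << ", true], "
     << "\"shape\": [" << n << "], \"typestr\": \"<f4\", \"version\": 3}";
  return os.str();
}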
@@ -990,7 +1024,7 @@ XGB_DLL int XGBoosterBoostOneIter(BoosterHandle handle, DMatrixHandle dtrain, bs
                                   bst_float *hess, xgboost::bst_ulong len) {
   API_BEGIN();
   CHECK_HANDLE();
-  error::DeprecatedFunc(__func__, "2.1.0", "XGBoosterTrainOneIter");
+  LOG(WARNING) << error::DeprecatedFunc(__func__, "2.1.0", "XGBoosterTrainOneIter");
   auto *learner = static_cast<Learner *>(handle);
   auto ctx = learner->Ctx()->MakeCPU();
@@ -1,17 +1,18 @@
 /**
- * Copyright 2021-2023, XGBoost Contributors
+ * Copyright 2021-2024, XGBoost Contributors
  */
 #ifndef XGBOOST_C_API_C_API_UTILS_H_
 #define XGBOOST_C_API_C_API_UTILS_H_

-#include <algorithm>
-#include <cstddef>
-#include <functional>
+#include <algorithm>   // for min
+#include <cstddef>     // for size_t
+#include <functional>  // for multiplies
 #include <memory>      // for shared_ptr
+#include <numeric>     // for accumulate
 #include <string>      // for string
 #include <tuple>       // for make_tuple
 #include <utility>     // for move
-#include <vector>
+#include <vector>      // for vector

 #include "../common/json_utils.h"  // for TypeCheck
 #include "xgboost/c_api.h"
@@ -1,15 +1,17 @@
 /**
- * Copyright 2023, XGBoost Contributors
+ * Copyright 2023-2024, XGBoost Contributors
  */
 #include <chrono>       // for seconds
-#include <cstddef>      // for size_t
 #include <future>       // for future
 #include <memory>       // for unique_ptr
 #include <string>       // for string
+#include <thread>       // for sleep_for
 #include <type_traits>  // for is_same_v, remove_pointer_t
 #include <utility>      // for pair

+#include "../collective/comm.h"     // for DefaultTimeoutSec
 #include "../collective/tracker.h"  // for RabitTracker
+#include "../common/timer.h"        // for Timer
 #include "c_api_error.h"            // for API_BEGIN
 #include "xgboost/c_api.h"
 #include "xgboost/collective/result.h"  // for Result
@@ -26,7 +28,7 @@ using namespace xgboost;  // NOLINT

 namespace {
 using TrackerHandleT =
-    std::pair<std::unique_ptr<collective::Tracker>, std::shared_future<collective::Result>>;
+    std::pair<std::shared_ptr<collective::Tracker>, std::shared_future<collective::Result>>;

 TrackerHandleT *GetTrackerHandle(TrackerHandle handle) {
   xgboost_CHECK_C_ARG_PTR(handle);
@@ -40,17 +42,29 @@ struct CollAPIEntry {
 };
 using CollAPIThreadLocalStore = dmlc::ThreadLocalStore<CollAPIEntry>;

-void WaitImpl(TrackerHandleT *ptr) {
-  std::chrono::seconds wait_for{100};
+void WaitImpl(TrackerHandleT *ptr, std::chrono::seconds timeout) {
+  constexpr std::int64_t kDft{collective::DefaultTimeoutSec()};
+  std::chrono::seconds wait_for{timeout.count() != 0 ? std::min(kDft, timeout.count()) : kDft};
+
+  common::Timer timer;
+  timer.Start();
+
+  auto ref = ptr->first;  // hold a reference so that a concurrent free doesn't delete it while waiting.
+
   auto fut = ptr->second;
   while (fut.valid()) {
     auto res = fut.wait_for(wait_for);
     CHECK(res != std::future_status::deferred);

     if (res == std::future_status::ready) {
       auto const &rc = ptr->second.get();
-      CHECK(rc.OK()) << rc.Report();
+      collective::SafeColl(rc);
       break;
     }
+
+    if (timer.Duration() > timeout && timeout.count() != 0) {
+      collective::SafeColl(collective::Fail("Timeout waiting for the tracker."));
+    }
   }
 }
 }  // namespace
@@ -62,15 +76,15 @@ XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle) {
   Json jconfig = Json::Load(config);

   auto type = RequiredArg<String>(jconfig, "dmlc_communicator", __func__);
-  std::unique_ptr<collective::Tracker> tptr;
+  std::shared_ptr<collective::Tracker> tptr;
   if (type == "federated") {
 #if defined(XGBOOST_USE_FEDERATED)
-    tptr = std::make_unique<collective::FederatedTracker>(jconfig);
+    tptr = std::make_shared<collective::FederatedTracker>(jconfig);
 #else
     LOG(FATAL) << error::NoFederated();
 #endif  // defined(XGBOOST_USE_FEDERATED)
   } else if (type == "rabit") {
-    tptr = std::make_unique<collective::RabitTracker>(jconfig);
+    tptr = std::make_shared<collective::RabitTracker>(jconfig);
   } else {
     LOG(FATAL) << "Unknown communicator:" << type;
   }
@@ -93,7 +107,7 @@ XGB_DLL int XGTrackerWorkerArgs(TrackerHandle handle, char const **args) {
   API_END();
 }

-XGB_DLL int XGTrackerRun(TrackerHandle handle) {
+XGB_DLL int XGTrackerRun(TrackerHandle handle, char const *) {
   API_BEGIN();
   auto *ptr = GetTrackerHandle(handle);
   CHECK(!ptr->second.valid()) << "Tracker is already running.";
@@ -101,19 +115,39 @@ XGB_DLL int XGTrackerRun(TrackerHandle handle) {
   API_END();
 }

-XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config) {
+XGB_DLL int XGTrackerWaitFor(TrackerHandle handle, char const *config) {
   API_BEGIN();
   auto *ptr = GetTrackerHandle(handle);
   xgboost_CHECK_C_ARG_PTR(config);
   auto jconfig = Json::Load(StringView{config});
-  WaitImpl(ptr);
+  // Internally, 0 indicates no timeout, which is the default since we don't want to
+  // interrupt the model training.
+  xgboost_CHECK_C_ARG_PTR(config);
+  auto timeout = OptionalArg<Integer>(jconfig, "timeout", std::int64_t{0});
+  WaitImpl(ptr, std::chrono::seconds{timeout});
   API_END();
 }

 XGB_DLL int XGTrackerFree(TrackerHandle handle) {
   API_BEGIN();
+  using namespace std::chrono_literals;  // NOLINT
   auto *ptr = GetTrackerHandle(handle);
-  WaitImpl(ptr);
+  ptr->first->Stop();
+  // The wait is not strictly necessary since we just called stop; we reuse the function
+  // to do any potential cleanups.
+  WaitImpl(ptr, ptr->first->Timeout());
+  common::Timer timer;
+  timer.Start();
+  // Make sure no one else is waiting on the tracker.
+  while (!ptr->first.unique()) {
+    auto ela = timer.Duration().count();
+    if (ela > ptr->first->Timeout().count()) {
+      LOG(WARNING) << "Time out " << ptr->first->Timeout().count()
+                   << " seconds reached for TrackerFree, killing the tracker.";
+      break;
+    }
+    std::this_thread::sleep_for(64ms);
+  }
   delete ptr;
   API_END();
 }
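The tracker functions above all funnel through WaitImpl, which polls a shared future in slices so a caller-supplied timeout can cut the wait short. A minimal sketch of that loop using only the standard library (names are illustrative):

#include <chrono>
#include <future>
#include <stdexcept>

// Poll `fut` once per second; timeout.count() == 0 means wait forever,
// matching the convention used by XGTrackerWaitFor above.
void WaitWithTimeout(std::shared_future<void> fut, std::chrono::seconds timeout) {
  auto begin = std::chrono::steady_clock::now();
  while (fut.valid()) {
    if (fut.wait_for(std::chrono::seconds{1}) == std::future_status::ready) {
      fut.get();  // rethrows any stored exception
      return;
    }
    if (timeout.count() != 0 && std::chrono::steady_clock::now() - begin > timeout) {
      throw std::runtime_error("Timeout waiting for the tracker.");
    }
  }
}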
@@ -165,7 +165,7 @@ template <typename T>
 T GlobalRatio(Context const* ctx, MetaInfo const& info, T dividend, T divisor) {
   std::array<T, 2> results{dividend, divisor};
   auto rc = GlobalSum(ctx, info, linalg::MakeVec(results.data(), results.size()));
-  collective::SafeColl(rc);
+  SafeColl(rc);
   std::tie(dividend, divisor) = std::tuple_cat(results);
   if (divisor <= 0) {
     return std::numeric_limits<T>::quiet_NaN();
@@ -1,5 +1,5 @@
 /**
- * Copyright 2023, XGBoost Contributors
+ * Copyright 2023-2024, XGBoost Contributors
  */
 #include "allgather.h"

@@ -7,6 +7,7 @@
 #include <cstddef>  // for size_t
 #include <cstdint>  // for int8_t, int32_t, int64_t
 #include <memory>   // for shared_ptr
+#include <utility>  // for move

 #include "broadcast.h"
 #include "comm.h"  // for Comm, Channel
@@ -29,16 +30,22 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
     auto rc = Success() << [&] {
       auto send_rank = (rank + world - r + worker_off) % world;
       auto send_off = send_rank * segment_size;
-      send_off = std::min(send_off, data.size_bytes());
-      auto send_seg = data.subspan(send_off, std::min(segment_size, data.size_bytes() - send_off));
+      bool is_last_segment = send_rank == (world - 1);
+      auto send_nbytes = is_last_segment ? (data.size_bytes() - send_off) : segment_size;
+      auto send_seg = data.subspan(send_off, send_nbytes);
+      CHECK_NE(send_seg.size(), 0);
       return next_ch->SendAll(send_seg.data(), send_seg.size_bytes());
     } << [&] {
       auto recv_rank = (rank + world - r - 1 + worker_off) % world;
       auto recv_off = recv_rank * segment_size;
-      recv_off = std::min(recv_off, data.size_bytes());
-      auto recv_seg = data.subspan(recv_off, std::min(segment_size, data.size_bytes() - recv_off));
+      bool is_last_segment = recv_rank == (world - 1);
+      auto recv_nbytes = is_last_segment ? (data.size_bytes() - recv_off) : segment_size;
+      auto recv_seg = data.subspan(recv_off, recv_nbytes);
+      CHECK_NE(recv_seg.size(), 0);
       return prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
-    } << [&] { return prev_ch->Block(); };
+    } << [&] {
+      return comm.Block();
+    };
     if (!rc.OK()) {
       return rc;
     }
@@ -91,7 +98,9 @@ namespace detail {
       auto recv_size = sizes[recv_rank];
       auto recv_seg = erased_result.subspan(recv_off, recv_size);
       return prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
-    } << [&] { return prev_ch->Block(); };
+    } << [&] {
+      return prev_ch->Block();
+    };
     if (!rc.OK()) {
       return rc;
     }
@@ -99,4 +108,47 @@ namespace detail {
   return comm.Block();
 }
 }  // namespace detail
+
+[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
+    Context const* ctx, CommGroup const& comm, std::vector<std::vector<char>> const& input) {
+  auto n_inputs = input.size();
+  std::vector<std::int64_t> sizes(n_inputs);
+  std::transform(input.cbegin(), input.cend(), sizes.begin(),
+                 [](auto const& vec) { return vec.size(); });
+
+  std::vector<std::int64_t> recv_segments(comm.World() + 1, 0);
+
+  HostDeviceVector<std::int8_t> recv;
+  auto rc =
+      AllgatherV(ctx, comm, linalg::MakeVec(sizes.data(), sizes.size()), &recv_segments, &recv);
+  SafeColl(rc);
+
+  auto global_sizes = common::RestoreType<std::int64_t const>(recv.ConstHostSpan());
+  std::vector<std::int64_t> offset(global_sizes.size() + 1);
+  offset[0] = 0;
+  for (std::size_t i = 1; i < offset.size(); i++) {
+    offset[i] = offset[i - 1] + global_sizes[i - 1];
+  }
+
+  std::vector<char> collected;
+  for (auto const& vec : input) {
+    collected.insert(collected.end(), vec.cbegin(), vec.cend());
+  }
+  rc = AllgatherV(ctx, comm, linalg::MakeVec(collected.data(), collected.size()), &recv_segments,
+                  &recv);
+  SafeColl(rc);
+  auto out = common::RestoreType<char const>(recv.ConstHostSpan());
+
+  std::vector<std::vector<char>> result;
+  for (std::size_t i = 1; i < offset.size(); ++i) {
+    std::vector<char> local(out.cbegin() + offset[i - 1], out.cbegin() + offset[i]);
+    result.emplace_back(std::move(local));
+  }
+  return result;
+}
+
+[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
+    Context const* ctx, std::vector<std::vector<char>> const& input) {
+  return VectorAllgatherV(ctx, *GlobalCommGroup(), input);
+}
 }  // namespace xgboost::collective
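The segmenting change above replaces offset clamping with an explicit last-segment rule: the per-rank segment size is rounded down, and the final rank absorbs the remainder, so no send or receive can end up empty. A small worked sketch of the arithmetic (illustrative helper, not part of the commit):

#include <cstddef>
#include <utility>

// For 10 bytes over 4 workers: segments are 2, 2, 2 and 4 bytes.
std::pair<std::size_t, std::size_t> SegmentOf(std::size_t total_bytes,
                                              std::size_t world,
                                              std::size_t rank) {
  std::size_t seg = total_bytes / world;  // rounded-down segment size
  std::size_t off = rank * seg;
  std::size_t n = (rank + 1 == world) ? total_bytes - off : seg;
  return {off, n};
}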
|||||||
@ -1,25 +1,27 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2023, XGBoost Contributors
|
* Copyright 2023-2024, XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#pragma once
|
#pragma once
|
||||||
#include <cstddef> // for size_t
|
#include <cstddef> // for size_t
|
||||||
#include <cstdint> // for int32_t
|
#include <cstdint> // for int32_t
|
||||||
#include <memory> // for shared_ptr
|
#include <memory> // for shared_ptr
|
||||||
#include <numeric> // for accumulate
|
#include <numeric> // for accumulate
|
||||||
|
#include <string> // for string
|
||||||
#include <type_traits> // for remove_cv_t
|
#include <type_traits> // for remove_cv_t
|
||||||
#include <vector> // for vector
|
#include <vector> // for vector
|
||||||
|
|
||||||
#include "../common/type.h" // for EraseType
|
#include "../common/type.h" // for EraseType
|
||||||
#include "comm.h" // for Comm, Channel
|
#include "comm.h" // for Comm, Channel
|
||||||
|
#include "comm_group.h" // for CommGroup
|
||||||
#include "xgboost/collective/result.h" // for Result
|
#include "xgboost/collective/result.h" // for Result
|
||||||
#include "xgboost/linalg.h"
|
#include "xgboost/linalg.h" // for MakeVec
|
||||||
#include "xgboost/span.h" // for Span
|
#include "xgboost/span.h" // for Span
|
||||||
|
|
||||||
namespace xgboost::collective {
|
namespace xgboost::collective {
|
||||||
namespace cpu_impl {
|
namespace cpu_impl {
|
||||||
/**
|
/**
|
||||||
* @param worker_off Segment offset. For example, if the rank 2 worker specifies
|
* @param worker_off Segment offset. For example, if the rank 2 worker specifies
|
||||||
* worker_off = 1, then it owns the third segment.
|
* worker_off = 1, then it owns the third segment (2 + 1).
|
||||||
*/
|
*/
|
||||||
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data,
|
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data,
|
||||||
std::size_t segment_size, std::int32_t worker_off,
|
std::size_t segment_size, std::int32_t worker_off,
|
||||||
@@ -51,8 +53,10 @@ inline void AllgatherVOffset(common::Span<std::int64_t const> sizes,
 }  // namespace detail

 template <typename T>
-[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<T> data, std::size_t size) {
-  auto n_bytes = sizeof(T) * size;
+[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<T> data) {
+  // This function is also used for ring allreduce, hence we allow the last segment to be
+  // larger due to round-down.
+  auto n_bytes_per_segment = data.size_bytes() / comm.World();
   auto erased = common::EraseType(data);

   auto rank = comm.Rank();
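The round-down behaviour mentioned in the new comment is easiest to see with concrete numbers; the figures below are illustrative only:

    // 10 bytes of data, world size 4:
    //   n_bytes_per_segment = 10 / 4 = 2
    // Ranks 0-2 own 2-byte segments; the last segment absorbs the remainder
    // and spans 10 - 3 * 2 = 4 bytes.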
@@ -61,7 +65,7 @@ template <typename T>

   auto prev_ch = comm.Chan(prev);
   auto next_ch = comm.Chan(next);
-  auto rc = cpu_impl::RingAllgather(comm, erased, n_bytes, 0, prev_ch, next_ch);
+  auto rc = cpu_impl::RingAllgather(comm, erased, n_bytes_per_segment, 0, prev_ch, next_ch);
   if (!rc.OK()) {
     return rc;
   }
@@ -76,7 +80,7 @@ template <typename T>

   std::vector<std::int64_t> sizes(world, 0);
   sizes[rank] = data.size_bytes();
-  auto rc = RingAllgather(comm, common::Span{sizes.data(), sizes.size()}, 1);
+  auto rc = RingAllgather(comm, common::Span{sizes.data(), sizes.size()});
   if (!rc.OK()) {
     return rc;
   }
@@ -98,4 +102,115 @@ template <typename T>

   return detail::RingAllgatherV(comm, sizes, s_segments, erased_result);
 }
+
+template <typename T>
+[[nodiscard]] Result Allgather(Context const* ctx, CommGroup const& comm,
+                               linalg::VectorView<T> data) {
+  if (!comm.IsDistributed()) {
+    return Success();
+  }
+  CHECK(data.Contiguous());
+  auto erased = common::EraseType(data.Values());
+
+  auto const& cctx = comm.Ctx(ctx, data.Device());
+  auto backend = comm.Backend(data.Device());
+  return backend->Allgather(cctx, erased);
+}
+
+/**
+ * @brief Gather all data from all workers.
+ *
+ * @param data The input and output buffer, needs to be pre-allocated by the caller.
+ */
+template <typename T>
+[[nodiscard]] Result Allgather(Context const* ctx, linalg::VectorView<T> data) {
+  auto const& cg = *GlobalCommGroup();
+  if (data.Size() % cg.World() != 0) {
+    return Fail("The total number of elements should be multiple of the number of workers.");
+  }
+  return Allgather(ctx, cg, data);
+}
+
+template <typename T>
+[[nodiscard]] Result AllgatherV(Context const* ctx, CommGroup const& comm,
+                                linalg::VectorView<T> data,
+                                std::vector<std::int64_t>* recv_segments,
+                                HostDeviceVector<std::int8_t>* recv) {
+  if (!comm.IsDistributed()) {
+    return Success();
+  }
+  std::vector<std::int64_t> sizes(comm.World(), 0);
+  sizes[comm.Rank()] = data.Values().size_bytes();
+  auto erased_sizes = common::EraseType(common::Span{sizes.data(), sizes.size()});
+  auto rc = comm.Backend(DeviceOrd::CPU())
+                ->Allgather(comm.Ctx(ctx, DeviceOrd::CPU()), erased_sizes);
+  if (!rc.OK()) {
+    return rc;
+  }
+
+  recv_segments->resize(sizes.size() + 1);
+  detail::AllgatherVOffset(sizes, common::Span{recv_segments->data(), recv_segments->size()});
+  auto total_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), 0LL);
+  recv->SetDevice(data.Device());
+  recv->Resize(total_bytes);
+
+  auto s_segments = common::Span{recv_segments->data(), recv_segments->size()};
+
+  auto backend = comm.Backend(data.Device());
+  auto erased = common::EraseType(data.Values());
+
+  return backend->AllgatherV(
+      comm.Ctx(ctx, data.Device()), erased, common::Span{sizes.data(), sizes.size()}, s_segments,
+      data.Device().IsCUDA() ? recv->DeviceSpan() : recv->HostSpan(), AllgatherVAlgo::kBcast);
+}
+
+/**
+ * @brief Allgather with variable length data.
+ *
+ * @param data The input data.
+ * @param recv_segments segment size for each worker. [0, 2, 5] means [0, 2) elements are
+ *                      from the first worker, [2, 5) elements are from the second one.
+ * @param recv The buffer storing the result.
+ */
+template <typename T>
+[[nodiscard]] Result AllgatherV(Context const* ctx, linalg::VectorView<T> data,
+                                std::vector<std::int64_t>* recv_segments,
+                                HostDeviceVector<std::int8_t>* recv) {
+  return AllgatherV(ctx, *GlobalCommGroup(), data, recv_segments, recv);
+}
+
+[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
+    Context const* ctx, CommGroup const& comm, std::vector<std::vector<char>> const& input);
+
+/**
+ * @brief Gathers variable-length data from all processes and distributes it to all processes.
+ *
+ * @param inputs All the inputs from the local worker. The number of inputs can vary
+ *               across different workers. Along with which, the size of each vector in
+ *               the input can also vary.
+ *
+ * @return The AllgatherV result, containing vectors from all workers.
+ */
+[[nodiscard]] std::vector<std::vector<char>> VectorAllgatherV(
+    Context const* ctx, std::vector<std::vector<char>> const& input);
+
+/**
+ * @brief Gathers variable-length strings from all processes and distributes them to all processes.
+ * @param input Variable-length list of variable-length strings.
+ */
+[[nodiscard]] inline Result AllgatherStrings(std::vector<std::string> const& input,
+                                             std::vector<std::string>* p_result) {
+  std::vector<std::vector<char>> inputs(input.size());
+  for (std::size_t i = 0; i < input.size(); ++i) {
+    inputs[i] = {input[i].cbegin(), input[i].cend()};
+  }
+  Context ctx;
+  auto out = VectorAllgatherV(&ctx, *GlobalCommGroup(), inputs);
+  auto& result = *p_result;
+  result.resize(out.size());
+  for (std::size_t i = 0; i < out.size(); ++i) {
+    result[i] = {out[i].cbegin(), out[i].cend()};
+  }
+  return Success();
+}
 }  // namespace xgboost::collective
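A hedged usage sketch of the new AllgatherStrings helper; the string value is made up:

    // Sketch only: requires an initialized global communicator.
    std::vector<std::string> local{"some-local-value"};  // hypothetical per-worker payload
    std::vector<std::string> merged;
    SafeColl(AllgatherStrings(local, &merged));  // merged holds the strings from all ranks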
@@ -1,5 +1,5 @@
 /**
- * Copyright 2023, XGBoost Contributors
+ * Copyright 2023-2024, XGBoost Contributors
 */
 #include "allreduce.h"

@@ -16,7 +16,44 @@
 #include "xgboost/span.h"  // for Span

 namespace xgboost::collective::cpu_impl {
+namespace {
 template <typename T>
+Result RingAllreduceSmall(Comm const& comm, common::Span<std::int8_t> data, Func const& op) {
+  auto rank = comm.Rank();
+  auto world = comm.World();
+
+  auto next_ch = comm.Chan(BootstrapNext(rank, world));
+  auto prev_ch = comm.Chan(BootstrapPrev(rank, world));
+
+  std::vector<std::int8_t> buffer(data.size_bytes() * world, 0);
+  auto s_buffer = common::Span{buffer.data(), buffer.size()};
+
+  auto offset = data.size_bytes() * rank;
+  auto self = s_buffer.subspan(offset, data.size_bytes());
+  std::copy_n(data.data(), data.size_bytes(), self.data());
+
+  auto typed = common::RestoreType<T>(s_buffer);
+  auto rc = RingAllgather(comm, typed);
+
+  if (!rc.OK()) {
+    return rc;
+  }
+  auto first = s_buffer.subspan(0, data.size_bytes());
+  CHECK_EQ(first.size(), data.size());
+
+  for (std::int32_t r = 1; r < world; ++r) {
+    auto offset = data.size_bytes() * r;
+    auto buf = s_buffer.subspan(offset, data.size_bytes());
+    op(buf, first);
+  }
+  std::copy_n(first.data(), first.size(), data.data());
+
+  return Success();
+}
+}  // namespace
+
+template <typename T>
+// note that n_bytes_in_seg is calculated with round-down.
 Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
                               std::size_t n_bytes_in_seg, Func const& op) {
   auto rank = comm.Rank();
@@ -27,33 +64,39 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
   auto next_ch = comm.Chan(dst_rank);
   auto prev_ch = comm.Chan(src_rank);

-  std::vector<std::int8_t> buffer(n_bytes_in_seg, 0);
+  std::vector<std::int8_t> buffer(data.size_bytes() - (world - 1) * n_bytes_in_seg, 0);
   auto s_buf = common::Span{buffer.data(), buffer.size()};

   for (std::int32_t r = 0; r < world - 1; ++r) {
+    common::Span<std::int8_t> seg, recv_seg;
+    auto rc = Success() << [&] {
       // send to ring next
-    auto send_off = ((rank + world - r) % world) * n_bytes_in_seg;
-    send_off = std::min(send_off, data.size_bytes());
-    auto seg_nbytes = std::min(data.size_bytes() - send_off, n_bytes_in_seg);
-    auto send_seg = data.subspan(send_off, seg_nbytes);
-
-    auto rc = next_ch->SendAll(send_seg);
-    if (!rc.OK()) {
-      return rc;
-    }
-
-    // receive from ring prev
-    auto recv_off = ((rank + world - r - 1) % world) * n_bytes_in_seg;
-    recv_off = std::min(recv_off, data.size_bytes());
-    seg_nbytes = std::min(data.size_bytes() - recv_off, n_bytes_in_seg);
+      auto send_rank = (rank + world - r) % world;
+      auto send_off = send_rank * n_bytes_in_seg;
+
+      bool is_last_segment = send_rank == (world - 1);
+
+      auto seg_nbytes = is_last_segment ? data.size_bytes() - send_off : n_bytes_in_seg;
       CHECK_EQ(seg_nbytes % sizeof(T), 0);
-    auto recv_seg = data.subspan(recv_off, seg_nbytes);
-    auto seg = s_buf.subspan(0, recv_seg.size());

-    rc = std::move(rc) << [&] { return prev_ch->RecvAll(seg); } << [&] { return comm.Block(); };
-    if (!rc.OK()) {
-      return rc;
-    }
+      auto send_seg = data.subspan(send_off, seg_nbytes);
+      return next_ch->SendAll(send_seg);
+    } << [&] {
+      // receive from ring prev
+      auto recv_rank = (rank + world - r - 1) % world;
+      auto recv_off = recv_rank * n_bytes_in_seg;
+
+      bool is_last_segment = recv_rank == (world - 1);
+
+      auto seg_nbytes = is_last_segment ? (data.size_bytes() - recv_off) : n_bytes_in_seg;
+      CHECK_EQ(seg_nbytes % sizeof(T), 0);
+
+      recv_seg = data.subspan(recv_off, seg_nbytes);
+      seg = s_buf.subspan(0, recv_seg.size());
+      return prev_ch->RecvAll(seg);
+    } << [&] {
+      return comm.Block();
+    };

     // accumulate to recv_seg
     CHECK_EQ(seg.size(), recv_seg.size());
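The rewrite above folds the send, receive, and block steps into the Result-chaining operator used throughout these files. As a generic sketch with hypothetical StepOne/StepTwo, each lambda runs only when every preceding step succeeded, so a single rc carries the first failure out of the pipeline:

    // Sketch of the chaining idiom; StepOne/StepTwo are placeholders.
    auto rc = Success() << [&] { return StepOne(); }   // skipped if a prior step failed
                        << [&] { return StepTwo(); };
    if (!rc.OK()) {
      return rc;  // the first error propagates unchanged
    }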
@@ -68,6 +111,9 @@ Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func const&
   if (comm.World() == 1) {
     return Success();
   }
+  if (data.size_bytes() == 0) {
+    return Success();
+  }
   return DispatchDType(type, [&](auto t) {
     using T = decltype(t);
     // Divide the data into segments according to the number of workers.
@@ -75,7 +121,11 @@ Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func const&
     CHECK_EQ(data.size_bytes() % n_bytes_elem, 0);
     auto n = data.size_bytes() / n_bytes_elem;
     auto world = comm.World();
-    auto n_bytes_in_seg = common::DivRoundUp(n, world) * sizeof(T);
+    if (n < static_cast<decltype(n)>(world)) {
+      return RingAllreduceSmall<T>(comm, data, op);
+    }
+
+    auto n_bytes_in_seg = (n / world) * sizeof(T);
     auto rc = RingScatterReduceTyped<T>(comm, data, n_bytes_in_seg, op);
     if (!rc.OK()) {
       return rc;
@@ -88,7 +138,9 @@ Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func const&

     return std::move(rc) << [&] {
       return RingAllgather(comm, data, n_bytes_in_seg, 1, prev_ch, next_ch);
-    } << [&] { return comm.Block(); };
+    } << [&] {
+      return comm.Block();
+    };
   });
 }
 }  // namespace xgboost::collective::cpu_impl
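The switch from DivRoundUp to plain integer division changes how segments are laid out; a worked example with illustrative numbers only:

    // n = 10 elements, world = 4, sizeof(T) = 4:
    //   old: n_bytes_in_seg = DivRoundUp(10, 4) * 4 = 3 * 4 = 12 bytes, clamped per rank;
    //   new: n_bytes_in_seg = (10 / 4) * 4 = 2 * 4 = 8 bytes, and the last rank's
    //        segment stretches to cover the remaining 10 - 3 * 2 = 4 elements.
    // When n < world the regular segments would be empty, hence the new
    // RingAllreduceSmall path that allgathers whole buffers instead.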
@@ -1,15 +1,18 @@
 /**
- * Copyright 2023, XGBoost Contributors
+ * Copyright 2023-2024, XGBoost Contributors
 */
 #pragma once
 #include <cstdint>      // for int8_t
 #include <functional>   // for function
 #include <type_traits>  // for is_invocable_v, enable_if_t
+#include <vector>       // for vector

 #include "../common/type.h"             // for EraseType, RestoreType
-#include "../data/array_interface.h"    // for ArrayInterfaceHandler
+#include "../data/array_interface.h"    // for ToDType, ArrayInterfaceHandler
 #include "comm.h"                       // for Comm, RestoreType
+#include "comm_group.h"                 // for GlobalCommGroup
 #include "xgboost/collective/result.h"  // for Result
+#include "xgboost/context.h"            // for Context
 #include "xgboost/span.h"               // for Span

 namespace xgboost::collective {
@@ -27,8 +30,7 @@ std::enable_if_t<std::is_invocable_v<Fn, common::Span<T const>, common::Span<T>>, Result>
   auto erased = common::EraseType(data);
   auto type = ToDType<T>::kType;

-  auto erased_fn = [type, redop](common::Span<std::int8_t const> lhs,
-                                 common::Span<std::int8_t> out) {
+  auto erased_fn = [redop](common::Span<std::int8_t const> lhs, common::Span<std::int8_t> out) {
     CHECK_EQ(lhs.size(), out.size()) << "Invalid input for reduction.";
     auto lhs_t = common::RestoreType<T const>(lhs);
     auto rhs_t = common::RestoreType<T>(out);
@@ -37,4 +39,40 @@ std::enable_if_t<std::is_invocable_v<Fn, common::Span<T const>, common::Span<T>>, Result>

   return cpu_impl::RingAllreduce(comm, erased, erased_fn, type);
 }
+
+template <typename T, std::int32_t kDim>
+[[nodiscard]] Result Allreduce(Context const* ctx, CommGroup const& comm,
+                               linalg::TensorView<T, kDim> data, Op op) {
+  if (!comm.IsDistributed()) {
+    return Success();
+  }
+  CHECK(data.Contiguous());
+  auto erased = common::EraseType(data.Values());
+  auto type = ToDType<T>::kType;
+
+  auto backend = comm.Backend(data.Device());
+  return backend->Allreduce(comm.Ctx(ctx, data.Device()), erased, type, op);
+}
+
+template <typename T, std::int32_t kDim>
+[[nodiscard]] Result Allreduce(Context const* ctx, linalg::TensorView<T, kDim> data, Op op) {
+  return Allreduce(ctx, *GlobalCommGroup(), data, op);
+}
+
+/**
+ * @brief Specialization for std::vector.
+ */
+template <typename T, typename Alloc>
+[[nodiscard]] Result Allreduce(Context const* ctx, std::vector<T, Alloc>* data, Op op) {
+  return Allreduce(ctx, linalg::MakeVec(data->data(), data->size()), op);
+}
+
+/**
+ * @brief Specialization for scalar value.
+ */
+template <typename T>
+[[nodiscard]] std::enable_if_t<std::is_standard_layout_v<T> && std::is_trivial_v<T>, Result>
+Allreduce(Context const* ctx, T* data, Op op) {
+  return Allreduce(ctx, linalg::MakeVec(data, 1), op);
+}
 }  // namespace xgboost::collective
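A minimal usage sketch for the new Allreduce overloads; the values are hypothetical and Op::kSum is assumed to be the sum enumerator:

    // Sketch only: requires an initialized global communicator.
    Context ctx;
    std::vector<double> grad{1.0, 2.0};               // per-worker partial sums
    SafeColl(Allreduce(&ctx, &grad, Op::kSum));       // in-place sum across workers

    double n_samples = 100.0;
    SafeColl(Allreduce(&ctx, &n_samples, Op::kSum));  // scalar specialization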
@@ -1,11 +1,15 @@
 /**
- * Copyright 2023, XGBoost Contributors
+ * Copyright 2023-2024, XGBoost Contributors
 */
 #pragma once
 #include <cstdint>  // for int32_t, int8_t

-#include "comm.h"                       // for Comm
-#include "xgboost/collective/result.h"  // for
+#include "../common/type.h"
+#include "comm.h"                       // for Comm, EraseType
+#include "comm_group.h"                 // for CommGroup
+#include "xgboost/collective/result.h"  // for Result
+#include "xgboost/context.h"            // for Context
+#include "xgboost/linalg.h"             // for VectorView
 #include "xgboost/span.h"               // for Span

 namespace xgboost::collective {
@@ -23,4 +27,21 @@ template <typename T>
       common::Span<std::int8_t>{reinterpret_cast<std::int8_t*>(data.data()), n_total_bytes};
   return cpu_impl::Broadcast(comm, erased, root);
 }
+
+template <typename T>
+[[nodiscard]] Result Broadcast(Context const* ctx, CommGroup const& comm,
+                               linalg::VectorView<T> data, std::int32_t root) {
+  if (!comm.IsDistributed()) {
+    return Success();
+  }
+  CHECK(data.Contiguous());
+  auto erased = common::EraseType(data.Values());
+  auto backend = comm.Backend(data.Device());
+  return backend->Broadcast(comm.Ctx(ctx, data.Device()), erased, root);
+}
+
+template <typename T>
+[[nodiscard]] Result Broadcast(Context const* ctx, linalg::VectorView<T> data, std::int32_t root) {
+  return Broadcast(ctx, *GlobalCommGroup(), data, root);
+}
 }  // namespace xgboost::collective
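Likewise for the new Broadcast overloads, a sketch assuming an initialized communicator; the buffer contents are placeholders:

    Context ctx;
    std::vector<std::int32_t> payload(8, 0);
    // Rank 0's buffer overwrites everyone else's copy.
    SafeColl(Broadcast(&ctx, linalg::MakeVec(payload.data(), payload.size()), /*root=*/0));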
@@ -42,6 +42,10 @@ bool constexpr IsFloatingPointV() {
   auto redop_fn = [](auto lhs, auto out, auto elem_op) {
     auto p_lhs = lhs.data();
     auto p_out = out.data();
+#if defined(__GNUC__) || defined(__clang__)
+    // For the sum op, one can verify the simd by: addps %xmm15, %xmm14
+#pragma omp simd
+#endif
     for (std::size_t i = 0; i < lhs.size(); ++i) {
       p_out[i] = elem_op(p_lhs[i], p_out[i]);
     }
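The guarded pragma asks GCC and Clang to vectorize the element-wise reduction loop; MSVC simply skips it. In spirit, an illustrative fragment (hypothetical float buffers, not the patch's code):

    // With elem_op being plain addition, the loop may lower to packed adds
    // such as addps, as the new comment notes.
    #pragma omp simd
    for (std::size_t i = 0; i < n; ++i) {
      p_out[i] = p_lhs[i] + p_out[i];
    }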
@@ -108,9 +112,8 @@ bool constexpr IsFloatingPointV() {
   return cpu_impl::Broadcast(comm, data, root);
 }

-[[nodiscard]] Result Coll::Allgather(Comm const& comm, common::Span<std::int8_t> data,
-                                     std::int64_t size) {
-  return RingAllgather(comm, data, size);
+[[nodiscard]] Result Coll::Allgather(Comm const& comm, common::Span<std::int8_t> data) {
+  return RingAllgather(comm, data);
 }

 [[nodiscard]] Result Coll::AllgatherV(Comm const& comm, common::Span<std::int8_t const> data,
@@ -1,10 +1,9 @@
 /**
- * Copyright 2023, XGBoost Contributors
+ * Copyright 2023-2024, XGBoost Contributors
 */
 #if defined(XGBOOST_USE_NCCL) || defined(XGBOOST_USE_RCCL)
 #include <cstdint>  // for int8_t, int64_t

-#include "../common/cuda_context.cuh"
 #include "../common/device_helpers.cuh"
 #include "../data/array_interface.h"
 #include "allgather.h"  // for AllgatherVOffset
@@ -166,14 +165,14 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) {
   } << [&] { return nccl->Block(); };
 }

-[[nodiscard]] Result NCCLColl::Allgather(Comm const& comm, common::Span<std::int8_t> data,
-                                         std::int64_t size) {
+[[nodiscard]] Result NCCLColl::Allgather(Comm const& comm, common::Span<std::int8_t> data) {
   if (!comm.IsDistributed()) {
     return Success();
   }
   auto nccl = dynamic_cast<NCCLComm const*>(&comm);
   CHECK(nccl);
   auto stub = nccl->Stub();
+  auto size = data.size_bytes() / comm.World();

   auto send = data.subspan(comm.Rank() * size, size);
   return Success() << [&] {
@@ -1,5 +1,5 @@
 /**
- * Copyright 2023, XGBoost Contributors
+ * Copyright 2023-2024, XGBoost Contributors
 */
 #pragma once

@@ -8,7 +8,6 @@
 #include "../data/array_interface.h"  // for ArrayInterfaceHandler
 #include "coll.h"                     // for Coll
 #include "comm.h"                     // for Comm
-#include "nccl_stub.h"
 #include "xgboost/span.h"             // for Span

 namespace xgboost::collective {
@@ -20,8 +19,7 @@ class NCCLColl : public Coll {
                                  ArrayInterfaceHandler::Type type, Op op) override;
   [[nodiscard]] Result Broadcast(Comm const& comm, common::Span<std::int8_t> data,
                                  std::int32_t root) override;
-  [[nodiscard]] Result Allgather(Comm const& comm, common::Span<std::int8_t> data,
-                                 std::int64_t size) override;
+  [[nodiscard]] Result Allgather(Comm const& comm, common::Span<std::int8_t> data) override;
   [[nodiscard]] Result AllgatherV(Comm const& comm, common::Span<std::int8_t const> data,
                                   common::Span<std::int64_t const> sizes,
                                   common::Span<std::int64_t> recv_segments,
@@ -48,10 +48,8 @@ class Coll : public std::enable_shared_from_this<Coll> {
    * @brief Allgather
    *
    * @param [in,out] data Data buffer for input and output.
-   * @param [in] size Size of data for each worker.
    */
-  [[nodiscard]] virtual Result Allgather(Comm const& comm, common::Span<std::int8_t> data,
-                                         std::int64_t size);
+  [[nodiscard]] virtual Result Allgather(Comm const& comm, common::Span<std::int8_t> data);
   /**
    * @brief Allgather with variable length.
    *
@@ -1,16 +1,19 @@
 /**
- * Copyright 2023, XGBoost Contributors
+ * Copyright 2023-2024, XGBoost Contributors
 */
 #include "comm.h"

 #include <algorithm>  // for copy
 #include <chrono>     // for seconds
+#include <cstdint>    // for int32_t
 #include <cstdlib>    // for exit
 #include <memory>     // for shared_ptr
 #include <string>     // for string
+#include <thread>     // for thread
 #include <utility>    // for move, forward

-#include "../common/common.h"  // for AssertGPUSupport
+#if !defined(XGBOOST_USE_NCCL)
+#include "../common/common.h"  // for AssertNCCLSupport
+#endif  // !defined(XGBOOST_USE_NCCL)
 #include "allgather.h"     // for RingAllgather
 #include "protocol.h"      // for kMagic
 #include "xgboost/base.h"  // for XGBOOST_STRICT_R_MODE
@@ -21,11 +24,7 @@
 namespace xgboost::collective {
 Comm::Comm(std::string const& host, std::int32_t port, std::chrono::seconds timeout,
            std::int32_t retry, std::string task_id)
-    : timeout_{timeout},
-      retry_{retry},
-      tracker_{host, port, -1},
-      task_id_{std::move(task_id)},
-      loop_{std::shared_ptr<Loop>{new Loop{timeout}}} {}
+    : timeout_{timeout}, retry_{retry}, tracker_{host, port, -1}, task_id_{std::move(task_id)} {}

 Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, std::int32_t retry,
                           std::string const& task_id, TCPSocket* out, std::int32_t rank,
@@ -187,12 +186,30 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, std::int32_t retry,
   return Success();
 }

-RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::seconds timeout,
-                     std::int32_t retry, std::string task_id, StringView nccl_path)
-    : HostComm{std::move(host), port, timeout, retry, std::move(task_id)},
+namespace {
+std::string InitLog(std::string task_id, std::int32_t rank) {
+  if (task_id.empty()) {
+    return "Rank " + std::to_string(rank);
+  }
+  return "Task " + task_id + " got rank " + std::to_string(rank);
+}
+}  // namespace
+
+RabitComm::RabitComm(std::string const& tracker_host, std::int32_t tracker_port,
+                     std::chrono::seconds timeout, std::int32_t retry, std::string task_id,
+                     StringView nccl_path)
+    : HostComm{tracker_host, tracker_port, timeout, retry, std::move(task_id)},
       nccl_path_{std::move(nccl_path)} {
+  if (this->TrackerInfo().host.empty()) {
+    // Not in a distributed environment.
+    LOG(CONSOLE) << InitLog(task_id_, rank_);
+    return;
+  }
+
+  loop_.reset(new Loop{std::chrono::seconds{timeout_}});  // NOLINT
   auto rc = this->Bootstrap(timeout_, retry_, task_id_);
   if (!rc.OK()) {
+    this->ResetState();
     SafeColl(Fail("Failed to bootstrap the communication group.", std::move(rc)));
   }
 }
@@ -219,20 +236,54 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {

   // Start command
   TCPSocket listener = TCPSocket::Create(tracker.Domain());
-  std::int32_t lport = listener.BindHost();
-  listener.Listen();
+  std::int32_t lport{0};
+  rc = std::move(rc) << [&] {
+    return listener.BindHost(&lport);
+  } << [&] {
+    return listener.Listen();
+  };
+  if (!rc.OK()) {
+    return rc;
+  }
+
   // create worker for listening to error notice.
   auto domain = tracker.Domain();
   std::shared_ptr<TCPSocket> error_sock{TCPSocket::CreatePtr(domain)};
-  auto eport = error_sock->BindHost();
-  error_sock->Listen();
+  std::int32_t eport{0};
+  rc = std::move(rc) << [&] {
+    return error_sock->BindHost(&eport);
+  } << [&] {
+    return error_sock->Listen();
+  };
+  if (!rc.OK()) {
+    return rc;
+  }
+  error_port_ = eport;
+
   error_worker_ = std::thread{[error_sock = std::move(error_sock)] {
-    auto conn = error_sock->Accept();
+    TCPSocket conn;
+    SockAddress addr;
+    auto rc = error_sock->Accept(&conn, &addr);
+    // On Linux, a shutdown causes an invalid argument error;
+    if (rc.Code() == std::errc::invalid_argument) {
+      return;
+    }
     // On Windows, accept returns a closed socket after finalize.
     if (conn.IsClosed()) {
       return;
     }
+    // The error signal is from the tracker, while shutdown signal is from the shutdown method
+    // of the RabitComm class (this).
+    bool is_error{false};
+    rc = proto::Error{}.RecvSignal(&conn, &is_error);
+    if (!rc.OK()) {
+      LOG(WARNING) << rc.Report();
+      return;
+    }
+    if (!is_error) {
+      return;  // shutdown
+    }
+
     LOG(WARNING) << "Another worker is running into error.";
 #if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE == 0
     // exit is nicer than abort as the former performs cleanups.
@@ -241,6 +292,9 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
     LOG(FATAL) << "abort";
 #endif
   }};
+  // The worker thread is detached here to avoid the need to handle it later during
+  // destruction. For C++, if a thread is not joined or detached, it will segfault during
+  // destruction.
   error_worker_.detach();

   proto::Start start;
@@ -253,7 +307,7 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {

   // get ring neighbors
   std::string snext;
-  tracker.Recv(&snext);
+  rc = tracker.Recv(&snext);
   if (!rc.OK()) {
     return Fail("Failed to receive the rank for the next worker.", std::move(rc));
   }
@@ -273,14 +327,21 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
   CHECK(this->channels_.empty());
   for (auto& w : workers) {
     if (w) {
-      rc = std::move(rc) << [&] { return w->SetNoDelay(); } << [&] { return w->NonBlocking(true); }
-                         << [&] { return w->SetKeepAlive(); };
+      rc = std::move(rc) << [&] {
+        return w->SetNoDelay();
+      } << [&] {
+        return w->NonBlocking(true);
+      } << [&] {
+        return w->SetKeepAlive();
+      };
     }
     if (!rc.OK()) {
       return rc;
     }
     this->channels_.emplace_back(std::make_shared<Channel>(*this, w));
   }
+
+  LOG(CONSOLE) << InitLog(task_id_, rank_);
   return rc;
 }
@@ -288,6 +349,8 @@ RabitComm::~RabitComm() noexcept(false) {
   if (!this->IsDistributed()) {
     return;
   }
+  LOG(WARNING) << "The communicator is being destroyed without a call to shutdown first. This can "
+                  "lead to undefined behaviour.";
   auto rc = this->Shutdown();
   if (!rc.OK()) {
     LOG(WARNING) << rc.Report();
@@ -295,24 +358,52 @@ RabitComm::~RabitComm() noexcept(false) {
 }

 [[nodiscard]] Result RabitComm::Shutdown() {
+  if (!this->IsDistributed()) {
+    return Success();
+  }
+  // Tell the tracker that this worker is shutting down.
   TCPSocket tracker;
+  // Tell the error handling thread that we are shutting down.
+  TCPSocket err_client;
+
   return Success() << [&] {
     return ConnectTrackerImpl(tracker_, timeout_, retry_, task_id_, &tracker, Rank(), World());
   } << [&] {
     return this->Block();
   } << [&] {
-    Json jcmd{Object{}};
-    jcmd["cmd"] = Integer{static_cast<std::int32_t>(proto::CMD::kShutdown)};
-    auto scmd = Json::Dump(jcmd);
-    auto n_bytes = tracker.Send(scmd);
-    if (n_bytes != scmd.size()) {
-      return Fail("Faled to send cmd.");
-    }
+    return proto::ShutdownCMD{}.Send(&tracker);
+  } << [&] {
+    this->channels_.clear();
     return Success();
+  } << [&] {
+    // Use tracker address to determine whether we want to use IPv6.
+    auto taddr = MakeSockAddress(xgboost::StringView{this->tracker_.host}, this->tracker_.port);
+    // Shutdown the error handling thread. We signal the thread through socket,
+    // alternatively, we can get the native handle and use pthread_cancel. But using a
+    // socket seems to be clearer as we know what's happening.
+    auto const& addr = taddr.IsV4() ? SockAddrV4::Loopback().Addr() : SockAddrV6::Loopback().Addr();
+    // We use hardcoded 10 seconds and 1 retry here since we are just connecting to a
+    // local socket. For a normal OS, this should be enough time to schedule the
+    // connection.
+    auto rc = Connect(StringView{addr}, this->error_port_, 1,
+                      std::min(std::chrono::seconds{10}, timeout_), &err_client);
+    this->ResetState();
+    if (!rc.OK()) {
+      return Fail("Failed to connect to the error socket.", std::move(rc));
+    }
+    return rc;
+  } << [&] {
+    // We put error thread shutdown at the end so that we have a better chance to finish
+    // the previous more important steps.
+    return proto::Error{}.SignalShutdown(&err_client);
   };
 }

 [[nodiscard]] Result RabitComm::LogTracker(std::string msg) const {
+  if (!this->IsDistributed()) {
+    LOG(CONSOLE) << msg;
+    return Success();
+  }
   TCPSocket out;
   proto::Print print;
   return Success() << [&] { return this->ConnectTracker(&out); }
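Worth noting: the detached error-listener created in Bootstrap blocks in Accept, so Shutdown wakes it by connecting to its own error_port_ over loopback and then sending the shutdown signal. The pattern in isolation, as a comment sketch assembled from the calls above (no new API):

    // listener thread:  Accept() -> proto::Error{}.RecvSignal(&conn, &is_error)
    //                   -> return quietly when is_error == false (shutdown)
    // shutdown path:    Connect(loopback, error_port_, ...) -> SignalShutdown(&err_client)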
@@ -320,8 +411,11 @@ RabitComm::~RabitComm() noexcept(false) {
 }

 [[nodiscard]] Result RabitComm::SignalError(Result const& res) {
-  TCPSocket out;
-  return Success() << [&] { return this->ConnectTracker(&out); }
-                   << [&] { return proto::ErrorCMD{}.WorkerSend(&out, res); };
+  TCPSocket tracker;
+  return Success() << [&] {
+    return this->ConnectTracker(&tracker);
+  } << [&] {
+    return proto::ErrorCMD{}.WorkerSend(&tracker, res);
+  };
 }
 }  // namespace xgboost::collective
@@ -27,7 +27,7 @@ Result GetUniqueId(Comm const& comm, std::shared_ptr<NcclStub> stub, std::shared_ptr<Coll> coll,
   ncclUniqueId id;
   if (comm.Rank() == kRootRank) {
     auto rc = stub->GetUniqueId(&id);
-    CHECK(rc.OK()) << rc.Report();
+    SafeColl(rc);
   }
   auto rc = coll->Broadcast(
       comm, common::Span{reinterpret_cast<std::int8_t*>(&id), sizeof(ncclUniqueId)}, kRootRank);
@@ -90,9 +90,8 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> pimpl,
   auto s_this_uuid = s_uuid.subspan(root.Rank() * kUuidLength, kUuidLength);
   GetCudaUUID(s_this_uuid, ctx->Device());

-  auto rc = pimpl->Allgather(root, common::EraseType(s_uuid), s_this_uuid.size_bytes());
-  CHECK(rc.OK()) << rc.Report();
+  auto rc = pimpl->Allgather(root, common::EraseType(s_uuid));
+  SafeColl(rc);

   std::vector<xgboost::common::Span<std::uint64_t, kUuidLength>> converted(root.World());
   std::size_t j = 0;
@@ -113,7 +112,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> pimpl,
       [&] {
         return this->stub_->CommInitRank(&nccl_comm_, root.World(), nccl_unique_id_, root.Rank());
       };
-  CHECK(rc.OK()) << rc.Report();
+  SafeColl(rc);

   for (std::int32_t r = 0; r < root.World(); ++r) {
     this->channels_.emplace_back(
@@ -124,7 +123,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> pimpl,
 NCCLComm::~NCCLComm() {
   if (nccl_comm_) {
     auto rc = stub_->CommDestroy(nccl_comm_);
-    CHECK(rc.OK()) << rc.Report();
+    SafeColl(rc);
   }
 }
 }  // namespace xgboost::collective
@@ -53,6 +53,10 @@ class NCCLComm : public Comm {
     auto rc = this->Stream().Sync(false);
     return GetCUDAResult(rc);
   }
+  [[nodiscard]] Result Shutdown() final {
+    this->ResetState();
+    return Success();
+  }
 };

 class NCCLChannel : public Channel {
@@ -1,10 +1,10 @@
 /**
- * Copyright 2023, XGBoost Contributors
+ * Copyright 2023-2024, XGBoost Contributors
 */
 #pragma once
 #include <chrono>   // for seconds
 #include <cstddef>  // for size_t
-#include <cstdint>  // for int32_t
+#include <cstdint>  // for int32_t, int64_t
 #include <memory>   // for shared_ptr
 #include <string>   // for string
 #include <thread>   // for thread
@@ -14,13 +14,13 @@
 #include "loop.h"                       // for Loop
 #include "protocol.h"                   // for PeerInfo
 #include "xgboost/collective/result.h"  // for Result
-#include "xgboost/collective/socket.h"  // for TCPSocket
+#include "xgboost/collective/socket.h"  // for TCPSocket, GetHostName
 #include "xgboost/context.h"            // for Context
 #include "xgboost/span.h"               // for Span

 namespace xgboost::collective {

-inline constexpr std::int32_t DefaultTimeoutSec() { return 300; }  // 5min
+inline constexpr std::int64_t DefaultTimeoutSec() { return 300; }  // 5min
 inline constexpr std::int32_t DefaultRetry() { return 3; }

 // indexing into the ring
@@ -51,11 +51,25 @@ class Comm : public std::enable_shared_from_this<Comm> {

   proto::PeerInfo tracker_;
   SockDomain domain_{SockDomain::kV4};

   std::thread error_worker_;
+  std::int32_t error_port_;
+
   std::string task_id_;
   std::vector<std::shared_ptr<Channel>> channels_;
-  std::shared_ptr<Loop> loop_{new Loop{std::chrono::seconds{
-      DefaultTimeoutSec()}}};  // fixme: require federated comm to have a timeout
+  std::shared_ptr<Loop> loop_{nullptr};  // fixme: require federated comm to have a timeout
+
+  void ResetState() {
+    this->world_ = -1;
+    this->rank_ = 0;
+    this->timeout_ = std::chrono::seconds{DefaultTimeoutSec()};
+
+    tracker_ = proto::PeerInfo{};
+    this->task_id_.clear();
+    channels_.clear();
+
+    loop_.reset();
+  }

  public:
   Comm() = default;
@@ -75,10 +89,13 @@ class Comm : public std::enable_shared_from_this<Comm> {
   [[nodiscard]] auto Retry() const { return retry_; }
   [[nodiscard]] auto TaskID() const { return task_id_; }

-  [[nodiscard]] auto Rank() const { return rank_; }
-  [[nodiscard]] auto World() const { return IsDistributed() ? world_ : 1; }
-  [[nodiscard]] bool IsDistributed() const { return world_ != -1; }
-  void Submit(Loop::Op op) const { loop_->Submit(op); }
+  [[nodiscard]] auto Rank() const noexcept { return rank_; }
+  [[nodiscard]] auto World() const noexcept { return IsDistributed() ? world_ : 1; }
+  [[nodiscard]] bool IsDistributed() const noexcept { return world_ != -1; }
+  void Submit(Loop::Op op) const {
+    CHECK(loop_);
+    loop_->Submit(op);
+  }
   [[nodiscard]] virtual Result Block() const { return loop_->Block(); }

   [[nodiscard]] virtual std::shared_ptr<Channel> Chan(std::int32_t rank) const {
@@ -88,6 +105,14 @@ class Comm : public std::enable_shared_from_this<Comm> {
   [[nodiscard]] virtual Result LogTracker(std::string msg) const = 0;

   [[nodiscard]] virtual Result SignalError(Result const&) { return Success(); }
+  /**
+   * @brief Get a string ID for the current process.
+   */
+  [[nodiscard]] virtual Result ProcessorName(std::string* out) const {
+    auto rc = GetHostName(out);
+    return rc;
+  }
+  [[nodiscard]] virtual Result Shutdown() = 0;
 };

 /**
@@ -105,20 +130,20 @@ class RabitComm : public HostComm {

   [[nodiscard]] Result Bootstrap(std::chrono::seconds timeout, std::int32_t retry,
                                  std::string task_id);
-  [[nodiscard]] Result Shutdown();

  public:
   // bootstrapping construction.
   RabitComm() = default;
-  // ctor for testing where environment is known.
-  RabitComm(std::string const& host, std::int32_t port, std::chrono::seconds timeout,
-            std::int32_t retry, std::string task_id, StringView nccl_path);
+  RabitComm(std::string const& tracker_host, std::int32_t tracker_port,
+            std::chrono::seconds timeout, std::int32_t retry, std::string task_id,
+            StringView nccl_path);
   ~RabitComm() noexcept(false) override;

   [[nodiscard]] bool IsFederated() const override { return false; }
   [[nodiscard]] Result LogTracker(std::string msg) const override;

   [[nodiscard]] Result SignalError(Result const&) override;
+  [[nodiscard]] Result Shutdown() final;

   [[nodiscard]] Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const override;
 };
@@ -1,20 +1,19 @@
 /**
- * Copyright 2023, XGBoost Contributors
+ * Copyright 2023-2024, XGBoost Contributors
 */
 #include "comm_group.h"

 #include <algorithm>  // for transform
+#include <cctype>     // for tolower
 #include <chrono>     // for seconds
 #include <cstdint>    // for int32_t
+#include <iterator>   // for back_inserter
 #include <memory>     // for shared_ptr, unique_ptr
 #include <string>     // for string
-#include <vector>     // for vector

 #include "../common/json_utils.h"  // for OptionalArg
 #include "coll.h"                  // for Coll
 #include "comm.h"                  // for Comm
-#include "tracker.h"               // for GetHostAddress
-#include "xgboost/collective/result.h"  // for Result
 #include "xgboost/context.h"  // for DeviceOrd
 #include "xgboost/json.h"     // for Json

@@ -65,6 +64,9 @@ CommGroup::CommGroup()

     auto const& obj = get<Object const>(config);
     auto it = obj.find(upper);
+    if (it != obj.cend() && obj.find(name) != obj.cend()) {
+      LOG(FATAL) << "Duplicated parameter:" << name;
+    }
     if (it != obj.cend()) {
       return OptionalArg<decltype(t)>(config, upper, dft);
     } else {
@@ -78,12 +80,12 @@ CommGroup::CommGroup()
   auto task_id = get_param("dmlc_task_id", std::string{}, String{});

   if (type == "rabit") {
-    auto host = get_param("dmlc_tracker_uri", std::string{}, String{});
-    auto port = get_param("dmlc_tracker_port", static_cast<std::int64_t>(0), Integer{});
+    auto tracker_host = get_param("dmlc_tracker_uri", std::string{}, String{});
+    auto tracker_port = get_param("dmlc_tracker_port", static_cast<std::int64_t>(0), Integer{});
     auto nccl = get_param("dmlc_nccl_path", std::string{DefaultNcclName()}, String{});
-    auto ptr =
-        new CommGroup{std::shared_ptr<RabitComm>{new RabitComm{  // NOLINT
-            host, static_cast<std::int32_t>(port), std::chrono::seconds{timeout},
+    auto ptr = new CommGroup{
+        std::shared_ptr<RabitComm>{new RabitComm{  // NOLINT
+            tracker_host, static_cast<std::int32_t>(tracker_port), std::chrono::seconds{timeout},
             static_cast<std::int32_t>(retry), task_id, nccl}},
         std::shared_ptr<Coll>(new Coll{})};  // NOLINT
     return ptr;
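For reference, a sketch of constructing the rabit communicator through CommGroup::Create; the host and port values are placeholders, and the communicator-type key name is an assumption since the selector is not shown in this hunk:

    Json config{Object{}};
    config["dmlc_communicator"] = String{"rabit"};      // assumed selector key
    config["dmlc_tracker_uri"] = String{"127.0.0.1"};   // placeholder tracker address
    config["dmlc_tracker_port"] = Integer{9091};        // placeholder port
    config["dmlc_task_id"] = String{"worker-0"};
    std::unique_ptr<CommGroup> group{CommGroup::Create(config)};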
@@ -117,6 +119,8 @@ void GlobalCommGroupInit(Json config) {

 void GlobalCommGroupFinalize() {
   auto& sptr = GlobalCommGroup();
+  auto rc = sptr->Finalize();
   sptr.reset();
+  SafeColl(rc);
 }
 }  // namespace xgboost::collective
@@ -9,7 +9,6 @@
 #include "coll.h"  // for Comm
 #include "comm.h"  // for Coll
 #include "xgboost/collective/result.h"  // for Result
-#include "xgboost/collective/socket.h"  // for GetHostName

 namespace xgboost::collective {
 /**
@@ -31,19 +30,35 @@ class CommGroup {
  public:
  CommGroup();

-  [[nodiscard]] auto World() const { return comm_->World(); }
-  [[nodiscard]] auto Rank() const { return comm_->Rank(); }
-  [[nodiscard]] bool IsDistributed() const { return comm_->IsDistributed(); }
+  [[nodiscard]] auto World() const noexcept { return comm_->World(); }
+  [[nodiscard]] auto Rank() const noexcept { return comm_->Rank(); }
+  [[nodiscard]] bool IsDistributed() const noexcept { return comm_->IsDistributed(); }
+
+  [[nodiscard]] Result Finalize() const {
+    return Success() << [this] {
+      if (gpu_comm_) {
+        return gpu_comm_->Shutdown();
+      }
+      return Success();
+    } << [&] {
+      return comm_->Shutdown();
+    };
+  }

   [[nodiscard]] static CommGroup* Create(Json config);

   [[nodiscard]] std::shared_ptr<Coll> Backend(DeviceOrd device) const;
+  /**
+   * @brief Decide the context to use for communication.
+   *
+   * @param ctx Global context, provides the CUDA stream and ordinal.
+   * @param device The device used by the data to be communicated.
+   */
   [[nodiscard]] Comm const& Ctx(Context const* ctx, DeviceOrd device) const;
   [[nodiscard]] Result SignalError(Result const& res) { return comm_->SignalError(res); }

   [[nodiscard]] Result ProcessorName(std::string* out) const {
-    auto rc = GetHostName(out);
-    return rc;
+    return this->comm_->ProcessorName(out);
   }
 };

@@ -32,7 +32,8 @@ class InMemoryHandler {
    *
    * This is used when the handler only needs to be initialized once with a known world size.
    */
-  explicit InMemoryHandler(std::size_t worldSize) : world_size_{worldSize} {}
+  explicit InMemoryHandler(std::int32_t worldSize)
+      : world_size_{static_cast<std::size_t>(worldSize)} {}

   /**
    * @brief Initialize the handler with the world size and rank.
@@ -18,9 +18,11 @@
 #include "xgboost/logging.h"  // for CHECK

 namespace xgboost::collective {
-Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
+Result Loop::ProcessQueue(std::queue<Op>* p_queue, bool blocking) const {
   timer_.Start(__func__);
-  auto error = [this] { timer_.Stop(__func__); };
+  auto error = [this] {
+    timer_.Stop(__func__);
+  };

   if (stop_) {
     timer_.Stop(__func__);
@@ -48,6 +50,9 @@ Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
         poll.WatchWrite(*op.sock);
         break;
       }
+      case Op::kSleep: {
+        break;
+      }
       default: {
         error();
         return Fail("Invalid socket operation.");
@ -59,12 +64,14 @@ Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
|
|||||||
|
|
||||||
// poll, work on fds that are ready.
|
// poll, work on fds that are ready.
|
||||||
timer_.Start("poll");
|
timer_.Start("poll");
|
||||||
|
if (!poll.fds.empty()) {
|
||||||
auto rc = poll.Poll(timeout_);
|
auto rc = poll.Poll(timeout_);
|
||||||
timer_.Stop("poll");
|
|
||||||
if (!rc.OK()) {
|
if (!rc.OK()) {
|
||||||
error();
|
error();
|
||||||
return rc;
|
return rc;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
timer_.Stop("poll");
|
||||||
|
|
||||||
// we wonldn't be here if the queue is empty.
|
// we wonldn't be here if the queue is empty.
|
||||||
CHECK(!qcopy.empty());
|
CHECK(!qcopy.empty());
|
||||||
@ -75,12 +82,20 @@ Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
|
|||||||
qcopy.pop();
|
qcopy.pop();
|
||||||
|
|
||||||
std::int32_t n_bytes_done{0};
|
std::int32_t n_bytes_done{0};
|
||||||
|
if (!op.sock) {
|
||||||
|
CHECK(op.code == Op::kSleep);
|
||||||
|
} else {
|
||||||
CHECK(op.sock->NonBlocking());
|
CHECK(op.sock->NonBlocking());
|
||||||
|
}
|
||||||
|
|
||||||
switch (op.code) {
|
switch (op.code) {
|
||||||
case Op::kRead: {
|
case Op::kRead: {
|
||||||
if (poll.CheckRead(*op.sock)) {
|
if (poll.CheckRead(*op.sock)) {
|
||||||
n_bytes_done = op.sock->Recv(op.ptr + op.off, op.n - op.off);
|
n_bytes_done = op.sock->Recv(op.ptr + op.off, op.n - op.off);
|
||||||
|
if (n_bytes_done == 0) {
|
||||||
|
error();
|
||||||
|
return Fail("Encountered EOF. The other end is likely closed.");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -90,6 +105,12 @@ Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
|
|||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case Op::kSleep: {
|
||||||
|
// For testing only.
|
||||||
|
std::this_thread::sleep_for(std::chrono::seconds{op.n});
|
||||||
|
n_bytes_done = op.n;
|
||||||
|
break;
|
||||||
|
}
|
||||||
default: {
|
default: {
|
||||||
error();
|
error();
|
||||||
return Fail("Invalid socket operation.");
|
return Fail("Invalid socket operation.");
|
||||||
@ -110,6 +131,10 @@ Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
|
|||||||
qcopy.push(op);
|
qcopy.push(op);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!blocking) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
timer_.Stop(__func__);
|
timer_.Stop(__func__);
|
||||||
@ -128,6 +153,15 @@ void Loop::Process() {
|
|||||||
while (true) {
|
while (true) {
|
||||||
try {
|
try {
|
||||||
std::unique_lock lock{mu_};
|
std::unique_lock lock{mu_};
|
||||||
|
// This can handle missed notification: wait(lock, predicate) is equivalent to:
|
||||||
|
//
|
||||||
|
// while (!predicate()) {
|
||||||
|
// cv.wait(lock);
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// As a result, if there's a missed notification, the queue wouldn't be empty, hence
|
||||||
|
// the predicate would be false and the actual wait wouldn't be invoked. Therefore,
|
||||||
|
// the blocking call can never go unanswered.
|
||||||
cv_.wait(lock, [this] { return !this->queue_.empty() || stop_; });
|
cv_.wait(lock, [this] { return !this->queue_.empty() || stop_; });
|
||||||
if (stop_) {
|
if (stop_) {
|
||||||
break; // only point where this loop can exit.
|
break; // only point where this loop can exit.
|
||||||
@ -142,26 +176,27 @@ void Loop::Process() {
|
|||||||
queue_.pop();
|
queue_.pop();
|
||||||
if (op.code == Op::kBlock) {
|
if (op.code == Op::kBlock) {
|
||||||
is_blocking = true;
|
is_blocking = true;
|
||||||
// Block must be the last op in the current batch since no further submit can be
|
|
||||||
// issued until the blocking call is finished.
|
|
||||||
CHECK(queue_.empty());
|
|
||||||
} else {
|
} else {
|
||||||
qcopy.push(op);
|
qcopy.push(op);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!is_blocking) {
|
|
||||||
// Unblock, we can write to the global queue again.
|
|
||||||
lock.unlock();
|
lock.unlock();
|
||||||
|
// Clear the local queue, if `is_blocking` is true, this is blocking the current
|
||||||
|
// worker thread (but not the client thread), wait until all operations are
|
||||||
|
// finished.
|
||||||
|
auto rc = this->ProcessQueue(&qcopy, is_blocking);
|
||||||
|
|
||||||
|
if (is_blocking && rc.OK()) {
|
||||||
|
CHECK(qcopy.empty());
|
||||||
|
}
|
||||||
|
// Push back the remaining operations.
|
||||||
|
if (rc.OK()) {
|
||||||
|
std::unique_lock lock{mu_};
|
||||||
|
while (!qcopy.empty()) {
|
||||||
|
queue_.push(qcopy.front());
|
||||||
|
qcopy.pop();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clear the local queue, this is blocking the current worker thread (but not the
|
|
||||||
// client thread), wait until all operations are finished.
|
|
||||||
auto rc = this->EmptyQueue(&qcopy);
|
|
||||||
|
|
||||||
if (is_blocking) {
|
|
||||||
// The unlock is delayed if this is a blocking call
|
|
||||||
lock.unlock();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Notify the client thread who called block after all error conditions are set.
|
// Notify the client thread who called block after all error conditions are set.
|
||||||
@ -228,7 +263,6 @@ Result Loop::Stop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
this->Submit(Op{Op::kBlock});
|
this->Submit(Op{Op::kBlock});
|
||||||
|
|
||||||
{
|
{
|
||||||
// Wait for the block call to finish.
|
// Wait for the block call to finish.
|
||||||
std::unique_lock lock{mu_};
|
std::unique_lock lock{mu_};
|
||||||
@ -243,8 +277,20 @@ Result Loop::Stop() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Loop::Submit(Op op) {
|
||||||
|
std::unique_lock lock{mu_};
|
||||||
|
if (op.code != Op::kBlock) {
|
||||||
|
CHECK_NE(op.n, 0);
|
||||||
|
}
|
||||||
|
queue_.push(op);
|
||||||
|
lock.unlock();
|
||||||
|
cv_.notify_one();
|
||||||
|
}
|
||||||
|
|
||||||
Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} {
|
Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} {
|
||||||
timer_.Init(__func__);
|
timer_.Init(__func__);
|
||||||
worker_ = std::thread{[this] { this->Process(); }};
|
worker_ = std::thread{[this] {
|
||||||
|
this->Process();
|
||||||
|
}};
|
||||||
}
|
}
|
||||||
} // namespace xgboost::collective
|
} // namespace xgboost::collective
|
||||||
|
|||||||
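The comment added to `Process()` argues that the predicate overload of `condition_variable::wait` cannot lose a wake-up: if the notification fires before the worker reaches `wait`, the predicate is already true and the wait is skipped. A self-contained sketch of the same pattern, assuming nothing beyond the standard library:

    #include <condition_variable>
    #include <iostream>
    #include <mutex>
    #include <queue>
    #include <thread>

    std::mutex mu;
    std::condition_variable cv;
    std::queue<int> ops;
    bool stop = false;

    void Worker() {
      std::unique_lock lock{mu};
      // Equivalent to: while (ops.empty() && !stop) cv.wait(lock);
      // A notification sent before this point is not lost: the predicate is
      // already true, so the thread never blocks in the first place.
      cv.wait(lock, [] { return !ops.empty() || stop; });
      if (!ops.empty()) {
        std::cout << "got op " << ops.front() << "\n";
      }
    }

    int main() {
      {
        std::lock_guard lock{mu};
        ops.push(42);         // producer runs first ...
      }
      cv.notify_one();        // ... and notifies before the worker even waits
      std::thread t{Worker};  // worker still sees the op: no missed wake-up
      t.join();
      return 0;
    }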
@@ -19,20 +19,27 @@ namespace xgboost::collective {
 class Loop {
  public:
  struct Op {
-    enum Code : std::int8_t { kRead = 0, kWrite = 1, kBlock = 2 } code;
+    // kSleep is only for testing
+    enum Code : std::int8_t { kRead = 0, kWrite = 1, kBlock = 2, kSleep = 4 } code;
     std::int32_t rank{-1};
     std::int8_t* ptr{nullptr};
     std::size_t n{0};
     TCPSocket* sock{nullptr};
     std::size_t off{0};
 
-    explicit Op(Code c) : code{c} { CHECK(c == kBlock); }
+    explicit Op(Code c) : code{c} { CHECK(c == kBlock || c == kSleep); }
     Op(Code c, std::int32_t rank, std::int8_t* ptr, std::size_t n, TCPSocket* sock, std::size_t off)
         : code{c}, rank{rank}, ptr{ptr}, n{n}, sock{sock}, off{off} {}
     Op(Op const&) = default;
     Op& operator=(Op const&) = default;
     Op(Op&&) = default;
     Op& operator=(Op&&) = default;
+    // For testing purpose only
+    [[nodiscard]] static Op Sleep(std::size_t seconds) {
+      Op op{kSleep};
+      op.n = seconds;
+      return op;
+    }
   };
 
  private:
@@ -54,7 +61,7 @@ class Loop {
   std::exception_ptr curr_exce_{nullptr};
   common::Monitor mutable timer_;
 
-  Result EmptyQueue(std::queue<Op>* p_queue) const;
+  Result ProcessQueue(std::queue<Op>* p_queue, bool blocking) const;
   // The cunsumer function that runs inside a worker thread.
   void Process();
 
@@ -64,12 +71,7 @@ class Loop {
   */
  Result Stop();
 
-  void Submit(Op op) {
-    std::unique_lock lock{mu_};
-    queue_.push(op);
-    lock.unlock();
-    cv_.notify_one();
-  }
+  void Submit(Op op);
 
  /**
   * @brief Block the event loop until all ops are finished. In the case of failure, this
@@ -2,6 +2,8 @@
  * Copyright 2023 XGBoost contributors
  */
 #if defined(XGBOOST_USE_NCCL) || defined(XGBOOST_USE_RCCL)
+#include <numeric>  // for accumulate
+
 #include "comm.cuh"
 #include "nccl_device_communicator.cuh"
 
@@ -1,5 +1,5 @@
 /**
- * Copyright 2023, XGBoost Contributors
+ * Copyright 2023-2024, XGBoost Contributors
  */
 #pragma once
 #include <cstdint>  // for int32_t
@@ -58,6 +58,7 @@ struct Magic {
   }
 };
 
+// Basic commands for communication between workers and the tracker.
 enum class CMD : std::int32_t {
   kInvalid = 0,
   kStart = 1,
@@ -84,7 +85,10 @@ struct Connect {
   [[nodiscard]] Result TrackerRecv(TCPSocket* sock, std::int32_t* world, std::int32_t* rank,
                                    std::string* task_id) const {
     std::string init;
-    sock->Recv(&init);
+    auto rc = sock->Recv(&init);
+    if (!rc.OK()) {
+      return Fail("Connect protocol failed.", std::move(rc));
+    }
     auto jinit = Json::Load(StringView{init});
     *world = get<Integer const>(jinit["world_size"]);
     *rank = get<Integer const>(jinit["rank"]);
@@ -122,9 +126,9 @@ class Start {
   }
   [[nodiscard]] Result WorkerRecv(TCPSocket* tracker, std::int32_t* p_world) const {
     std::string scmd;
-    auto n_bytes = tracker->Recv(&scmd);
-    if (n_bytes <= 0) {
-      return Fail("Failed to recv init command from tracker.");
+    auto rc = tracker->Recv(&scmd);
+    if (!rc.OK()) {
+      return Fail("Failed to recv init command from tracker.", std::move(rc));
     }
     auto jcmd = Json::Load(scmd);
     auto world = get<Integer const>(jcmd["world_size"]);
@@ -132,7 +136,7 @@ class Start {
       return Fail("Invalid world size.");
     }
     *p_world = world;
-    return Success();
+    return rc;
   }
   [[nodiscard]] Result TrackerHandle(Json jcmd, std::int32_t* recv_world, std::int32_t world,
                                      std::int32_t* p_port, TCPSocket* p_sock,
@@ -150,6 +154,7 @@ class Start {
   }
 };
 
+// Protocol for communicating with the tracker for printing message.
 struct Print {
   [[nodiscard]] Result WorkerSend(TCPSocket* tracker, std::string msg) const {
     Json jcmd{Object{}};
@@ -172,6 +177,7 @@ struct Print {
   }
 };
 
+// Protocol for communicating with the tracker during error.
 struct ErrorCMD {
   [[nodiscard]] Result WorkerSend(TCPSocket* tracker, Result const& res) const {
     auto msg = res.Report();
@@ -199,6 +205,7 @@ struct ErrorCMD {
   }
 };
 
+// Protocol for communicating with the tracker during shutdown.
 struct ShutdownCMD {
   [[nodiscard]] Result Send(TCPSocket* peer) const {
     Json jcmd{Object{}};
@@ -211,4 +218,40 @@ struct ShutdownCMD {
     return Success();
   }
 };
+
+// Protocol for communicating with the local error handler during error or shutdown. Only
+// one protocol that doesn't have the tracker involved.
+struct Error {
+  constexpr static std::int32_t ShutdownSignal() { return 0; }
+  constexpr static std::int32_t ErrorSignal() { return -1; }
+
+  [[nodiscard]] Result SignalError(TCPSocket* worker) const {
+    std::int32_t err{ErrorSignal()};
+    auto n_sent = worker->SendAll(&err, sizeof(err));
+    if (n_sent == sizeof(err)) {
+      return Success();
+    }
+    return Fail("Failed to send error signal");
+  }
+  // self is localhost, we are sending the signal to the error handling thread for it to
+  // close.
+  [[nodiscard]] Result SignalShutdown(TCPSocket* self) const {
+    std::int32_t err{ShutdownSignal()};
+    auto n_sent = self->SendAll(&err, sizeof(err));
+    if (n_sent == sizeof(err)) {
+      return Success();
+    }
+    return Fail("Failed to send shutdown signal");
+  }
+  // get signal, either for error or for shutdown.
+  [[nodiscard]] Result RecvSignal(TCPSocket* peer, bool* p_is_error) const {
+    std::int32_t err{ShutdownSignal()};
+    auto n_recv = peer->RecvAll(&err, sizeof(err));
+    if (n_recv == sizeof(err)) {
+      *p_is_error = err == 1;
+      return Success();
+    }
+    return Fail("Failed to receive error signal.");
+  }
+};
 }  // namespace xgboost::collective::proto
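The new `proto::Error` struct frames each message as a single fixed-width `int32` and counts the operation as successful only when all `sizeof(err)` bytes went through. A hedged POSIX sketch of that framing follows; the hand-rolled `send` loop stands in for whatever `TCPSocket::SendAll` guarantees internally, which this diff does not show.

    #include <cstddef>
    #include <cstdint>
    #include <sys/socket.h>
    #include <sys/types.h>

    // Keep writing until the whole fixed-width value is on the wire; a short
    // write is retried, an error or a closed peer aborts.
    bool SendAll(int fd, void const* data, std::size_t n) {
      auto const* p = static_cast<char const*>(data);
      std::size_t sent = 0;
      while (sent < n) {
        ssize_t k = ::send(fd, p + sent, n - sent, 0);
        if (k <= 0) { return false; }
        sent += static_cast<std::size_t>(k);
      }
      return true;
    }

    bool SignalShutdown(int fd) {
      std::int32_t sig = 0;  // 0 == shutdown, -1 == error, as in the protocol above
      return SendAll(fd, &sig, sizeof(sig));
    }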
86  src/collective/result.cc  Normal file
@@ -0,0 +1,86 @@
+/**
+ * Copyright 2024, XGBoost Contributors
+ */
+#include "xgboost/collective/result.h"
+
+#include <filesystem>  // for path
+#include <sstream>     // for stringstream
+#include <stack>       // for stack
+
+#include "xgboost/logging.h"
+
+namespace xgboost::collective {
+namespace detail {
+[[nodiscard]] std::string ResultImpl::Report() const {
+  std::stringstream ss;
+  ss << "\n- " << this->message;
+  if (this->errc != std::error_code{}) {
+    ss << " system error:" << this->errc.message();
+  }
+
+  auto ptr = prev.get();
+  while (ptr) {
+    ss << "\n- ";
+    ss << ptr->message;
+
+    if (ptr->errc != std::error_code{}) {
+      ss << " " << ptr->errc.message();
+    }
+    ptr = ptr->prev.get();
+  }
+
+  return ss.str();
+}
+
+[[nodiscard]] std::error_code ResultImpl::Code() const {
+  // Find the root error.
+  std::stack<ResultImpl const*> stack;
+  auto ptr = this;
+  while (ptr) {
+    stack.push(ptr);
+    if (ptr->prev) {
+      ptr = ptr->prev.get();
+    } else {
+      break;
+    }
+  }
+  while (!stack.empty()) {
+    auto frame = stack.top();
+    stack.pop();
+    if (frame->errc != std::error_code{}) {
+      return frame->errc;
+    }
+  }
+  return std::error_code{};
+}
+
+void ResultImpl::Concat(std::unique_ptr<ResultImpl> rhs) {
+  auto ptr = this;
+  while (ptr->prev) {
+    ptr = ptr->prev.get();
+  }
+  ptr->prev = std::move(rhs);
+}
+
+#if (!defined(__GNUC__) && !defined(__clang__)) || defined(__MINGW32__)
+std::string MakeMsg(std::string&& msg, char const*, std::int32_t) {
+  return std::forward<std::string>(msg);
+}
+#else
+std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line) {
+  auto name = std::filesystem::path{file}.filename();
+  if (file && line != -1) {
+    return "[" + name.string() + ":" + std::to_string(line) +  // NOLINT
+           "]: " + std::forward<std::string>(msg);
+  }
+  return std::forward<std::string>(msg);
+}
+#endif
+}  // namespace detail
+
+void SafeColl(Result const& rc) {
+  if (!rc.OK()) {
+    LOG(FATAL) << rc.Report();
+  }
+}
+}  // namespace xgboost::collective
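`result.cc` stores an error as a singly linked chain of causes, and `Report()` walks the chain from the newest frame down so the root cause prints last. A simplified standalone analogue of the idea (the real `ResultImpl` also carries `std::error_code`s, omitted here):

    #include <iostream>
    #include <memory>
    #include <sstream>
    #include <string>

    // Each error optionally wraps the error that caused it.
    struct Err {
      std::string message;
      std::unique_ptr<Err> prev;
    };

    std::string Report(Err const& e) {
      std::stringstream ss;
      for (auto const* ptr = &e; ptr != nullptr; ptr = ptr->prev.get()) {
        ss << "\n- " << ptr->message;
      }
      return ss.str();
    }

    int main() {
      auto root = std::make_unique<Err>(Err{"connection refused", nullptr});
      Err top{"Connect protocol failed.", std::move(root)};
      std::cout << Report(top) << "\n";  // newest frame first, root cause last
      return 0;
    }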
@@ -1,5 +1,5 @@
 /**
- * Copyright 2022-2023 by XGBoost Contributors
+ * Copyright 2022-2024, XGBoost Contributors
  */
 #include "xgboost/collective/socket.h"
 
@@ -8,7 +8,8 @@
 #include <cstdint>       // std::int32_t
 #include <cstring>       // std::memcpy, std::memset
 #include <filesystem>    // for path
-#include <system_error>  // std::error_code, std::system_category
+#include <system_error>  // for error_code, system_category
+#include <thread>        // for sleep_for
 
 #include "rabit/internal/socket.h"      // for PollHelper
 #include "xgboost/collective/result.h"  // for Result
@@ -65,14 +66,18 @@ std::size_t TCPSocket::Send(StringView str) {
   return bytes;
 }
 
-std::size_t TCPSocket::Recv(std::string *p_str) {
+[[nodiscard]] Result TCPSocket::Recv(std::string *p_str) {
   CHECK(!this->IsClosed());
   std::int32_t len;
-  CHECK_EQ(this->RecvAll(&len, sizeof(len)), sizeof(len)) << "Failed to recv string length.";
+  if (this->RecvAll(&len, sizeof(len)) != sizeof(len)) {
+    return Fail("Failed to recv string length.");
+  }
   p_str->resize(len);
   auto bytes = this->RecvAll(&(*p_str)[0], len);
-  CHECK_EQ(bytes, len) << "Failed to recv string.";
-  return bytes;
+  if (static_cast<decltype(len)>(bytes) != len) {
+    return Fail("Failed to recv string.");
+  }
+  return Success();
 }
 
 [[nodiscard]] Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry,
@@ -110,11 +115,7 @@ std::size_t TCPSocket::Recv(std::string *p_str) {
   for (std::int32_t attempt = 0; attempt < std::max(retry, 1); ++attempt) {
     if (attempt > 0) {
       LOG(WARNING) << "Retrying connection to " << host << " for the " << attempt << " time.";
-#if defined(_MSC_VER) || defined(__MINGW32__)
-      Sleep(attempt << 1);
-#else
-      sleep(attempt << 1);
-#endif
+      std::this_thread::sleep_for(std::chrono::seconds{attempt << 1});
     }
 
     auto rc = connect(conn.Handle(), addr_handle, addr_len);
@@ -158,8 +159,8 @@ std::size_t TCPSocket::Recv(std::string *p_str) {
 
   std::stringstream ss;
   ss << "Failed to connect to " << host << ":" << port;
-  conn.Close();
-  return Fail(ss.str(), std::move(last_error));
+  auto close_rc = conn.Close();
+  return Fail(ss.str(), std::move(close_rc) + std::move(last_error));
 }
 
 [[nodiscard]] Result GetHostName(std::string *p_out) {
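The retry hunk drops the `Sleep()`/`sleep()` platform split in favour of `std::this_thread::sleep_for`, which is portable and takes a typed `<chrono>` duration. A sketch of the resulting backoff loop, with a dummy predicate standing in for the real `connect` call:

    #include <chrono>
    #include <iostream>
    #include <thread>

    // Wait 2 * attempt seconds before each retry (attempt << 1), exactly as
    // the loop above does, but through one portable standard-library call.
    bool ConnectWithRetry(int retries) {
      for (int attempt = 0; attempt < retries; ++attempt) {
        if (attempt > 0) {
          std::cout << "retrying, attempt " << attempt << "\n";
          std::this_thread::sleep_for(std::chrono::seconds{attempt << 1});
        }
        bool connected = (attempt == 2);  // stand-in for the real connect()
        if (connected) { return true; }
      }
      return false;
    }

    int main() { return ConnectWithRetry(4) ? 0 : 1; }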
@@ -1,6 +1,7 @@
 /**
  * Copyright 2023-2024, XGBoost Contributors
  */
+#include "rabit/internal/socket.h"
 #if defined(__unix__) || defined(__APPLE__)
 #include <netdb.h>       // gethostbyname
 #include <sys/socket.h>  // socket, AF_INET6, AF_INET, connect, getsockname
@@ -70,10 +71,13 @@ RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockA
     return proto::Connect{}.TrackerRecv(&sock_, &world_, &rank, &task_id_);
   } << [&] {
     std::string cmd;
-    sock_.Recv(&cmd);
+    auto rc = sock_.Recv(&cmd);
+    if (!rc.OK()) {
+      return rc;
+    }
     jcmd = Json::Load(StringView{cmd});
     cmd_ = static_cast<proto::CMD>(get<Integer const>(jcmd["cmd"]));
-    return Success();
+    return rc;
   } << [&] {
     if (cmd_ == proto::CMD::kStart) {
       proto::Start start;
@@ -100,14 +104,18 @@ RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockA
 
 RabitTracker::RabitTracker(Json const& config) : Tracker{config} {
   std::string self;
-  auto rc = collective::GetHostAddress(&self);
+  auto rc = Success() << [&] {
+    return collective::GetHostAddress(&self);
+  } << [&] {
     host_ = OptionalArg<String>(config, "host", self);
 
     auto addr = MakeSockAddress(xgboost::StringView{host_}, 0);
     listener_ = TCPSocket::Create(addr.IsV4() ? SockDomain::kV4 : SockDomain::kV6);
-  rc = listener_.Bind(host_, &this->port_);
+    return listener_.Bind(host_, &this->port_);
+  } << [&] {
+    return listener_.Listen();
+  };
   SafeColl(rc);
-  listener_.Listen();
 }
 
 Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
@@ -220,9 +228,13 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
     //
     // retry is set to 1, just let the worker timeout or error. Otherwise the
     // tracker and the worker might be waiting for each other.
-    auto rc = Connect(w.first, w.second, 1, timeout_, &out);
+    auto rc = Success() << [&] {
+      return Connect(w.first, w.second, 1, timeout_, &out);
+    } << [&] {
+      return proto::Error{}.SignalError(&out);
+    };
     if (!rc.OK()) {
-      return Fail("Failed to inform workers to stop.");
+      return Fail("Failed to inform worker:" + w.first + " for error.", std::move(rc));
     }
   }
   return Success();
@@ -231,13 +243,37 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
   return std::async(std::launch::async, [this, handle_error] {
     State state{this->n_workers_};
+
+    auto select_accept = [&](TCPSocket* sock, auto* addr) {
+      // accept with poll so that we can enable timeout and interruption.
+      rabit::utils::PollHelper poll;
+      auto rc = Success() << [&] {
+        std::lock_guard lock{listener_mu_};
+        return listener_.NonBlocking(true);
+      } << [&] {
+        std::lock_guard lock{listener_mu_};
+        poll.WatchRead(listener_);
+        if (state.running) {
+          // Don't timeout if the communicator group is up and running.
+          return poll.Poll(std::chrono::seconds{-1});
+        } else {
+          // Have timeout for workers to bootstrap.
+          return poll.Poll(timeout_);
+        }
+      } << [&] {
+        // this->Stop() closes the socket with a lock. Therefore, when the accept returns
+        // due to shutdown, the state is still valid (closed).
+        return listener_.Accept(sock, addr);
+      };
+      return rc;
+    };
+
     while (state.ShouldContinue()) {
       TCPSocket sock;
       SockAddress addr;
       this->ready_ = true;
-      auto rc = listener_.Accept(&sock, &addr);
+      auto rc = select_accept(&sock, &addr);
       if (!rc.OK()) {
-        return Fail("Failed to accept connection.", std::move(rc));
+        return Fail("Failed to accept connection.", this->Stop() + std::move(rc));
       }
 
       auto worker = WorkerProxy{n_workers_, std::move(sock), std::move(addr)};
@@ -252,7 +288,7 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
         state.Error();
         rc = handle_error(worker);
         if (!rc.OK()) {
-          return Fail("Failed to handle abort.", std::move(rc));
+          return Fail("Failed to handle abort.", this->Stop() + std::move(rc));
         }
       }
 
@@ -262,7 +298,7 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
           state.Bootstrap();
         }
         if (!rc.OK()) {
-          return rc;
+          return this->Stop() + std::move(rc);
         }
         continue;
       }
@@ -289,12 +325,11 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
         }
         case proto::CMD::kInvalid:
         default: {
-          return Fail("Invalid command received.");
+          return Fail("Invalid command received.", this->Stop());
        }
      }
    }
-    ready_ = false;
-    return Success();
+    return this->Stop();
   });
 }
 
@@ -303,11 +338,30 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
   SafeColl(rc);
 
   Json args{Object{}};
-  args["DMLC_TRACKER_URI"] = String{host_};
-  args["DMLC_TRACKER_PORT"] = this->Port();
+  args["dmlc_tracker_uri"] = String{host_};
+  args["dmlc_tracker_port"] = this->Port();
   return args;
 }
 
+[[nodiscard]] Result RabitTracker::Stop() {
+  if (!this->Ready()) {
+    return Success();
+  }
+
+  ready_ = false;
+  std::lock_guard lock{listener_mu_};
+  if (this->listener_.IsClosed()) {
+    return Success();
+  }
+
+  return Success() << [&] {
+    // This should have the effect of stopping the `accept` call.
+    return this->listener_.Shutdown();
+  } << [&] {
+    return listener_.Close();
+  };
+}
+
 [[nodiscard]] Result GetHostAddress(std::string* out) {
   auto rc = GetHostName(out);
   if (!rc.OK()) {
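`select_accept` above waits on the listening socket with a poll before calling accept, so the tracker can enforce a bootstrap timeout and `Stop()` can interrupt the wait instead of the thread blocking in `accept` forever. A bare POSIX sketch of the same pattern (socket creation and binding omitted; `rabit::utils::PollHelper` wraps comparable calls):

    #include <poll.h>
    #include <sys/socket.h>

    // Poll the listening socket first: accept() is only called once a
    // connection is pending, so it cannot block past the timeout.
    int AcceptWithTimeout(int listen_fd, int timeout_ms) {
      pollfd pfd{};
      pfd.fd = listen_fd;
      pfd.events = POLLIN;
      int n = ::poll(&pfd, 1, timeout_ms);  // timeout_ms < 0 waits indefinitely
      if (n <= 0) {
        return -1;  // 0 == timed out, -1 == poll failed
      }
      return ::accept(listen_fd, nullptr, nullptr);  // ready: returns at once
    }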
@@ -36,15 +36,18 @@ namespace xgboost::collective {
  * signal an error to the tracker and the tracker will notify other workers.
  */
 class Tracker {
+ public:
+  enum class SortBy : std::int8_t {
+    kHost = 0,
+    kTask = 1,
+  };
+
  protected:
   // How to sort the workers, either by host name or by task ID. When using a multi-GPU
   // setting, multiple workers can occupy the same host, in which case one should sort
   // workers by task. Due to compatibility reason, the task ID is not always available, so
   // we use host as the default.
-  enum class SortBy : std::int8_t {
-    kHost = 0,
-    kTask = 1,
-  } sortby_;
+  SortBy sortby_;
 
  protected:
   std::int32_t n_workers_{0};
@@ -54,10 +57,7 @@ class Tracker {
 
  public:
   explicit Tracker(Json const& config);
-  Tracker(std::int32_t n_worders, std::int32_t port, std::chrono::seconds timeout)
-      : n_workers_{n_worders}, port_{port}, timeout_{timeout} {}
-
-  virtual ~Tracker() noexcept(false){};  // NOLINT
+  virtual ~Tracker() = default;
 
   [[nodiscard]] Result WaitUntilReady() const;
 
@@ -69,6 +69,11 @@ class Tracker {
   * @brief Flag to indicate whether the server is running.
   */
  [[nodiscard]] bool Ready() const { return ready_; }
+  /**
+   * @brief Shutdown the tracker, cannot be restarted again. Useful when the tracker hangs while
+   * calling accept.
+   */
+  virtual Result Stop() { return Success(); }
 };
 
 class RabitTracker : public Tracker {
@@ -127,28 +132,22 @@ class RabitTracker : public Tracker {
   // record for how to reach out to workers if error happens.
   std::vector<std::pair<std::string, std::int32_t>> worker_error_handles_;
   // listening socket for incoming workers.
-  //
-  // At the moment, the listener calls accept without first polling. We can add an
-  // additional unix domain socket to allow cancelling the accept.
   TCPSocket listener_;
+  // mutex for protecting the listener, used to prevent race when it's listening while
+  // another thread tries to shut it down.
+  std::mutex listener_mu_;
 
   Result Bootstrap(std::vector<WorkerProxy>* p_workers);
 
  public:
-  explicit RabitTracker(StringView host, std::int32_t n_worders, std::int32_t port,
-                        std::chrono::seconds timeout)
-      : Tracker{n_worders, port, timeout}, host_{host.c_str(), host.size()} {
-    listener_ = TCPSocket::Create(SockDomain::kV4);
-    auto rc = listener_.Bind(host, &this->port_);
-    CHECK(rc.OK()) << rc.Report();
-    listener_.Listen();
-  }
-
   explicit RabitTracker(Json const& config);
-  ~RabitTracker() noexcept(false) override = default;
+  ~RabitTracker() override = default;
 
   std::future<Result> Run() override;
   [[nodiscard]] Json WorkerArgs() const override;
+  // Stop the tracker without waiting. This is to prevent the tracker from hanging when
+  // one of the workers failes to start.
+  [[nodiscard]] Result Stop() override;
 };
 
 // Prob the public IP address of the host, need a better method.
@@ -14,7 +14,6 @@
 #include <thrust/iterator/transform_output_iterator.h>  // make_transform_output_iterator
 #include <thrust/logical.h>
 #include <thrust/sequence.h>
-#include <thrust/sort.h>
 #include <thrust/system/cuda/error.h>
 #include <thrust/system_error.h>
 #include <thrust/transform_scan.h>
@@ -301,22 +300,23 @@ class MemoryLogger {
   void RegisterAllocation(void *ptr, size_t n) {
     device_allocations[ptr] = n;
     currently_allocated_bytes += n;
-    peak_allocated_bytes =
-        std::max(peak_allocated_bytes, currently_allocated_bytes);
+    peak_allocated_bytes = std::max(peak_allocated_bytes, currently_allocated_bytes);
     num_allocations++;
     CHECK_GT(num_allocations, num_deallocations);
   }
   void RegisterDeallocation(void *ptr, size_t n, int current_device) {
     auto itr = device_allocations.find(ptr);
     if (itr == device_allocations.end()) {
-      LOG(WARNING) << "Attempting to deallocate " << n << " bytes on device "
-                   << current_device << " that was never allocated ";
-    }
+      LOG(WARNING) << "Attempting to deallocate " << n << " bytes on device " << current_device
+                   << " that was never allocated\n"
+                   << dmlc::StackTrace();
+    } else {
       num_deallocations++;
       CHECK_LE(num_deallocations, num_allocations);
       currently_allocated_bytes -= itr->second;
       device_allocations.erase(itr);
     }
+  }
 };
 DeviceStats stats_;
 std::mutex mutex_;
@@ -11,7 +11,7 @@
 #include "xgboost/logging.h"
 
 namespace xgboost::error {
-std::string DeprecatedFunc(StringView old, StringView since, StringView replacement) {
+[[nodiscard]] std::string DeprecatedFunc(StringView old, StringView since, StringView replacement) {
   std::stringstream ss;
   ss << "`" << old << "` is deprecated since" << since << ", use `" << replacement << "` instead.";
   return ss.str();
@@ -89,7 +89,7 @@ void WarnDeprecatedGPUId();
 
 void WarnEmptyDataset();
 
-std::string DeprecatedFunc(StringView old, StringView since, StringView replacement);
+[[nodiscard]] std::string DeprecatedFunc(StringView old, StringView since, StringView replacement);
 
 constexpr StringView InvalidCUDAOrdinal() {
   return "Invalid device. `device` is required to be CUDA and there must be at least one GPU "
@@ -8,6 +8,7 @@
 #define COMMON_HIST_UTIL_CUH_
 
 #include <thrust/host_vector.h>
+#include <thrust/sort.h>  // for sort
 
 #include <cstddef>  // for size_t
 
@@ -6,7 +6,6 @@
 
 #include <algorithm>
 #include <cstdint>
-#include <mutex>
 
 #include "xgboost/data.h"
 #include "xgboost/host_device_vector.h"
@@ -4,6 +4,7 @@
 #include "quantile.h"
 
 #include <limits>
+#include <numeric>  // for partial_sum
 #include <utility>
 
 #include "../collective/aggregator.h"
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020-2023 by XGBoost Contributors
+ * Copyright 2020-2024, XGBoost Contributors
  */
 #include <thrust/binary_search.h>
 #include <thrust/execution_policy.h>
@@ -9,7 +9,7 @@
 #include <thrust/unique.h>
 
 #include <limits>   // std::numeric_limits
-#include <memory>
+#include <numeric>  // for partial_sum
 #include <utility>
 
 #include "../collective/communicator-inl.cuh"
@@ -1,8 +1,9 @@
+/**
+ * Copyright 2020-2024, XGBoost Contributors
+ */
 #ifndef XGBOOST_COMMON_QUANTILE_CUH_
 #define XGBOOST_COMMON_QUANTILE_CUH_
 
-#include <memory>
-
 #include "xgboost/span.h"
 #include "xgboost/data.h"
 #include "device_helpers.cuh"
@@ -1,9 +1,8 @@
-/*!
- * Copyright by Contributors 2019
+/**
+ * Copyright 2019-2024, XGBoost Contributors
  */
 #include "timer.h"
 
-#include <sstream>
 #include <utility>
 
 #include "../collective/communicator-inl.h"
@@ -61,6 +60,9 @@ void Monitor::Print() const {
                                  kv.second.timer.elapsed)
                                  .count());
   }
+  if (stat_map.empty()) {
+    return;
+  }
   LOG(CONSOLE) << "======== Monitor (" << rank << "): " << label_ << " ========";
   this->PrintStatistics(stat_map);
 }
@@ -11,7 +11,6 @@
 #include <cmath>      // for abs
 #include <cstdint>    // for uint64_t, int32_t, uint8_t, uint32_t
 #include <cstring>    // for size_t, strcmp, memcpy
-#include <exception>  // for exception
 #include <iostream>   // for operator<<, basic_ostream, basic_ostream::op...
 #include <map>        // for map, operator!=
 #include <numeric>    // for accumulate, partial_sum
@@ -22,7 +21,6 @@
 #include "../collective/communicator.h"  // for Operation
 #include "../common/algorithm.h"         // for StableSort
 #include "../common/api_entry.h"         // for XGBAPIThreadLocalEntry
-#include "../common/common.h"            // for Split
 #include "../common/error_msg.h"         // for GroupSize, GroupWeight, InfInData
 #include "../common/group_data.h"        // for ParallelGroupBuilder
 #include "../common/io.h"                // for PeekableInStream
@@ -473,11 +471,11 @@ void MetaInfo::SetInfo(Context const& ctx, StringView key, StringView interface_
         << ", must have at least 1 column even if it's empty.";
     auto const& first = get<Object const>(array.front());
     auto ptr = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(first);
-    is_cuda = ArrayInterfaceHandler::IsCudaPtr(ptr);
+    is_cuda = first.find("stream") != first.cend() || ArrayInterfaceHandler::IsCudaPtr(ptr);
   } else {
     auto const& first = get<Object const>(j_interface);
     auto ptr = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(first);
-    is_cuda = ArrayInterfaceHandler::IsCudaPtr(ptr);
+    is_cuda = first.find("stream") != first.cend() || ArrayInterfaceHandler::IsCudaPtr(ptr);
   }
 
   if (is_cuda) {
@@ -567,46 +565,6 @@ void MetaInfo::SetInfoFromHost(Context const& ctx, StringView key, Json arr) {
   }
 }
 
-void MetaInfo::SetInfo(Context const& ctx, const char* key, const void* dptr, DataType dtype,
-                       size_t num) {
-  CHECK(key);
-  auto proc = [&](auto cast_d_ptr) {
-    using T = std::remove_pointer_t<decltype(cast_d_ptr)>;
-    auto t = linalg::TensorView<T, 1>(common::Span<T>{cast_d_ptr, num}, {num}, DeviceOrd::CPU());
-    CHECK(t.CContiguous());
-    Json interface {
-      linalg::ArrayInterface(t)
-    };
-    assert(ArrayInterface<1>{interface}.is_contiguous);
-    return interface;
-  };
-  // Legacy code using XGBoost dtype, which is a small subset of array interface types.
-  switch (dtype) {
-    case xgboost::DataType::kFloat32: {
-      auto cast_ptr = reinterpret_cast<const float*>(dptr);
-      this->SetInfoFromHost(ctx, key, proc(cast_ptr));
-      break;
-    }
-    case xgboost::DataType::kDouble: {
-      auto cast_ptr = reinterpret_cast<const double*>(dptr);
-      this->SetInfoFromHost(ctx, key, proc(cast_ptr));
-      break;
-    }
-    case xgboost::DataType::kUInt32: {
-      auto cast_ptr = reinterpret_cast<const uint32_t*>(dptr);
-      this->SetInfoFromHost(ctx, key, proc(cast_ptr));
-      break;
-    }
-    case xgboost::DataType::kUInt64: {
-      auto cast_ptr = reinterpret_cast<const uint64_t*>(dptr);
-      this->SetInfoFromHost(ctx, key, proc(cast_ptr));
-      break;
-    }
-    default:
-      LOG(FATAL) << "Unknown data type" << static_cast<uint8_t>(dtype);
-  }
-}
-
 void MetaInfo::GetInfo(char const* key, bst_ulong* out_len, DataType dtype,
                        const void** out_dptr) const {
   if (dtype == DataType::kFloat32) {
@@ -1,5 +1,5 @@
 /**
- * Copyright 2021-2023, XGBoost contributors
+ * Copyright 2021-2024, XGBoost contributors
  */
 #include "file_iterator.h"
 
@@ -11,6 +11,9 @@
 #include <vector>  // for vector
 
 #include "../common/common.h"  // for Split
+#include "xgboost/linalg.h"    // for ArrayInterfaceStr, MakeVec
+#include "xgboost/linalg.h"
+#include "xgboost/logging.h"   // for CHECK
 #include "xgboost/string_view.h"  // for operator<<, StringView
 
 namespace xgboost::data {
@@ -28,10 +31,10 @@ std::string ValidateFileFormat(std::string const& uri) {
   for (size_t i = 0; i < arg_list.size(); ++i) {
     std::istringstream is(arg_list[i]);
     std::pair<std::string, std::string> kv;
-    CHECK(std::getline(is, kv.first, '=')) << "Invalid uri argument format"
-                                           << " for key in arg " << i + 1;
-    CHECK(std::getline(is, kv.second)) << "Invalid uri argument format"
-                                       << " for value in arg " << i + 1;
+    CHECK(std::getline(is, kv.first, '='))
+        << "Invalid uri argument format" << " for key in arg " << i + 1;
+    CHECK(std::getline(is, kv.second))
+        << "Invalid uri argument format" << " for value in arg " << i + 1;
     args.insert(kv);
   }
   if (args.find("format") == args.cend()) {
@@ -48,4 +51,41 @@ std::string ValidateFileFormat(std::string const& uri) {
     return name_args[0] + "?" + name_args[1] + '#' + name_args_cache[1];
   }
 }
+
+int FileIterator::Next() {
+  CHECK(parser_);
+  if (parser_->Next()) {
+    row_block_ = parser_->Value();
+
+    indptr_ = linalg::Make1dInterface(row_block_.offset, row_block_.size + 1);
+    values_ = linalg::Make1dInterface(row_block_.value, row_block_.offset[row_block_.size]);
+    indices_ = linalg::Make1dInterface(row_block_.index, row_block_.offset[row_block_.size]);
+
+    size_t n_columns =
+        *std::max_element(row_block_.index, row_block_.index + row_block_.offset[row_block_.size]);
+    // dmlc parser converts 1-based indexing back to 0-based indexing so we can ignore
+    // this condition and just add 1 to n_columns
+    n_columns += 1;
+
+    XGProxyDMatrixSetDataCSR(proxy_, indptr_.c_str(), indices_.c_str(), values_.c_str(), n_columns);
+
+    if (row_block_.label) {
+      auto str = linalg::Make1dInterface(row_block_.label, row_block_.size);
+      XGDMatrixSetInfoFromInterface(proxy_, "label", str.c_str());
+    }
+    if (row_block_.qid) {
+      auto str = linalg::Make1dInterface(row_block_.qid, row_block_.size);
+      XGDMatrixSetInfoFromInterface(proxy_, "qid", str.c_str());
+    }
+    if (row_block_.weight) {
+      auto str = linalg::Make1dInterface(row_block_.weight, row_block_.size);
+      XGDMatrixSetInfoFromInterface(proxy_, "weight", str.c_str());
+    }
+    // Continue iteration
+    return true;
+  } else {
+    // Stop iteration
+    return false;
+  }
+}
 }  // namespace xgboost::data
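`FileIterator::Next` now passes labels, query IDs, and weights to the proxy DMatrix as array-interface JSON strings rather than through `XGDMatrixSetDenseInfo`. A sketch of what a 1-d interface string for a float buffer can look like, following the numpy `__array_interface__` convention of pointer, shape, and typestr; the exact output of `linalg::Make1dInterface` may differ in detail:

    #include <cstddef>
    #include <cstdint>
    #include <iostream>
    #include <sstream>
    #include <string>

    // Describe a float buffer by address, length, and dtype so the consumer
    // can read it in place instead of copying.
    std::string Make1dInterfaceF32(float const* data, std::size_t n) {
      std::stringstream ss;
      ss << R"({"data": [)" << reinterpret_cast<std::uintptr_t>(data)
         << R"(, true], "shape": [)" << n
         << R"(], "typestr": "<f4", "version": 3})";
      return ss.str();
    }

    int main() {
      float values[] = {0.5f, 1.5f, 2.5f};
      std::cout << Make1dInterfaceF32(values, 3) << "\n";
      return 0;
    }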
@@ -1,20 +1,16 @@
 /**
- * Copyright 2021-2023, XGBoost contributors
+ * Copyright 2021-2024, XGBoost contributors
  */
 #ifndef XGBOOST_DATA_FILE_ITERATOR_H_
 #define XGBOOST_DATA_FILE_ITERATOR_H_
 
-#include <algorithm>  // for max_element
-#include <cstddef>    // for size_t
 #include <cstdint>    // for uint32_t
 #include <memory>     // for unique_ptr
 #include <string>     // for string
 #include <utility>    // for move
 
 #include "dmlc/data.h"      // for RowBlock, Parser
-#include "xgboost/c_api.h"  // for XGDMatrixSetDenseInfo, XGDMatrixFree, XGProxyDMatrixCreate
-#include "xgboost/linalg.h"   // for ArrayInterfaceStr, MakeVec
-#include "xgboost/logging.h"  // for CHECK
+#include "xgboost/c_api.h"  // for XGDMatrixFree, XGProxyDMatrixCreate
 
 namespace xgboost::data {
 [[nodiscard]] std::string ValidateFileFormat(std::string const& uri);
@@ -53,41 +49,7 @@ class FileIterator {
     XGDMatrixFree(proxy_);
   }
 
-  int Next() {
-    CHECK(parser_);
-    if (parser_->Next()) {
-      row_block_ = parser_->Value();
-      using linalg::MakeVec;
-
-      indptr_ = ArrayInterfaceStr(MakeVec(row_block_.offset, row_block_.size + 1));
-      values_ = ArrayInterfaceStr(MakeVec(row_block_.value, row_block_.offset[row_block_.size]));
-      indices_ = ArrayInterfaceStr(MakeVec(row_block_.index, row_block_.offset[row_block_.size]));
-
-      size_t n_columns = *std::max_element(row_block_.index,
-                                           row_block_.index + row_block_.offset[row_block_.size]);
-      // dmlc parser converts 1-based indexing back to 0-based indexing so we can ignore
-      // this condition and just add 1 to n_columns
-      n_columns += 1;
-
-      XGProxyDMatrixSetDataCSR(proxy_, indptr_.c_str(), indices_.c_str(),
-                               values_.c_str(), n_columns);
-
-      if (row_block_.label) {
-        XGDMatrixSetDenseInfo(proxy_, "label", row_block_.label, row_block_.size, 1);
-      }
-      if (row_block_.qid) {
-        XGDMatrixSetDenseInfo(proxy_, "qid", row_block_.qid, row_block_.size, 1);
-      }
-      if (row_block_.weight) {
-        XGDMatrixSetDenseInfo(proxy_, "weight", row_block_.weight, row_block_.size, 1);
-      }
-      // Continue iteration
-      return true;
-    } else {
-      // Stop iteration
-      return false;
-    }
-  }
+  int Next();
 
   auto Proxy() -> decltype(proxy_) { return proxy_; }
@ -1,5 +1,5 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2014-2023, XGBoost Contributors
|
* Copyright 2014-2024, XGBoost Contributors
|
||||||
* \file sparse_page_source.h
|
* \file sparse_page_source.h
|
||||||
*/
|
*/
|
||||||
#ifndef XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_
|
#ifndef XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_
|
||||||
@ -7,23 +7,26 @@
|
|||||||
|
|
||||||
#include <algorithm> // for min
|
#include <algorithm> // for min
|
||||||
#include <atomic> // for atomic
|
#include <atomic> // for atomic
|
||||||
|
#include <cstdio> // for remove
|
||||||
#include <future> // for async
|
#include <future> // for async
|
||||||
#include <map>
|
#include <memory> // for unique_ptr
|
||||||
#include <memory>
|
|
||||||
#include <mutex> // for mutex
|
#include <mutex> // for mutex
|
||||||
#include <string>
|
#include <string> // for string
|
||||||
#include <thread>
|
|
||||||
#include <utility> // for pair, move
|
#include <utility> // for pair, move
|
||||||
#include <vector>
|
#include <vector> // for vector
|
||||||
|
|
||||||
|
#if !defined(XGBOOST_USE_CUDA)
|
||||||
|
#include "../common/common.h" // for AssertGPUSupport
|
||||||
|
#endif // !defined(XGBOOST_USE_CUDA)
|
||||||
|
|
||||||
#include "../common/common.h"
|
|
||||||
#include "../common/io.h" // for PrivateMmapConstStream
|
#include "../common/io.h" // for PrivateMmapConstStream
|
||||||
#include "../common/timer.h" // for Monitor, Timer
|
#include "../common/timer.h" // for Monitor, Timer
|
||||||
#include "adapter.h"
|
|
||||||
#include "proxy_dmatrix.h" // for DMatrixProxy
|
#include "proxy_dmatrix.h" // for DMatrixProxy
|
||||||
#include "sparse_page_writer.h" // for SparsePageFormat
|
#include "sparse_page_writer.h" // for SparsePageFormat
|
||||||
#include "xgboost/base.h"
|
#include "xgboost/base.h" // for bst_feature_t
|
||||||
#include "xgboost/data.h"
|
#include "xgboost/data.h" // for SparsePage, CSCPage
|
||||||
|
#include "xgboost/global_config.h" // for GlobalConfigThreadLocalStore
|
||||||
|
#include "xgboost/logging.h" // for CHECK_EQ
|
||||||
|
|
||||||
namespace xgboost::data {
|
namespace xgboost::data {
|
||||||
inline void TryDeleteCacheFile(const std::string& file) {
|
inline void TryDeleteCacheFile(const std::string& file) {
|
||||||
@@ -185,6 +188,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
 
     exce_.Rethrow();
 
+    auto const config = *GlobalConfigThreadLocalStore::Get();
     for (std::int32_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) {
       fetch_it %= n_batches_;  // ring
       if (ring_->at(fetch_it).valid()) {
@@ -192,7 +196,8 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
       }
       auto const* self = this;  // make sure it's const
       CHECK_LT(fetch_it, cache_info_->offset.size());
-      ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self, this]() {
+      ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self, config, this]() {
+        *GlobalConfigThreadLocalStore::Get() = config;
         auto page = std::make_shared<S>();
         this->exce_.Run([&] {
           std::unique_ptr<SparsePageFormat<S>> fmt{CreatePageFormat<S>("raw")};
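The two hunks above are the substantive fix in this file: std::async launches each prefetch task on a fresh thread, and XGBoost's global configuration (verbosity and similar settings) lives in a thread-local store, so worker threads would otherwise run with default settings. The calling thread therefore snapshots the configuration once, each lambda captures that snapshot by value, and the worker installs it into its own thread-local store as its first statement. A self-contained sketch of the same pattern, using a hypothetical GlobalConfig struct and global_config variable in place of XGBoost's GlobalConfigThreadLocalStore:

    #include <future>  // for std::async
    #include <iostream>

    // Hypothetical stand-in for the thread-local global configuration store:
    // every thread sees its own independent copy.
    struct GlobalConfig { int verbosity{1}; };
    thread_local GlobalConfig global_config;

    int main() {
      global_config.verbosity = 3;  // visible on the calling thread only

      // Snapshot the caller's configuration by value...
      auto const config = global_config;
      auto fut = std::async(std::launch::async, [config] {
        // ...and restore it first thing in the worker, whose own
        // thread_local copy still holds the defaults.
        global_config = config;
        std::cout << "worker verbosity: " << global_config.verbosity << "\n";
      });
      fut.wait();
      return 0;
    }

Without the restore line the worker would print 1 rather than 3, which is exactly the mismatch the captured config eliminates in the prefetcher.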
@@ -1,5 +1,5 @@
 /**
- * Copyright 2014-2023 by Contributors
+ * Copyright 2014-2024, XGBoost Contributors
  * \file gbtree.cc
  * \brief gradient boosted tree implementation.
  * \author Tianqi Chen
@@ -11,14 +11,12 @@
 
 #include <algorithm>
 #include <cstdint>  // std::int32_t
-#include <map>
 #include <memory>
+#include <numeric>  // for iota
 #include <string>
-#include <unordered_map>
 #include <utility>
 #include <vector>
 
-#include "../common/common.h"
 #include "../common/timer.h"
 #include "../tree/param.h"  // TrainParam
 #include "gbtree_model.h"
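gbtree.cc now pulls in <numeric> for std::iota. The call site is outside the hunks shown here, so the following only illustrates, with hypothetical values, the idiom the include enables: filling a vector with consecutive indices, e.g. when addressing a contiguous range of trees.

    #include <cstdint>  // for std::int32_t
    #include <iostream>
    #include <numeric>  // for std::iota
    #include <vector>

    int main() {
      // Build the index sequence {3, 4, 5, 6}; the start value is arbitrary.
      std::vector<std::int32_t> tree_idx(4);
      std::iota(tree_idx.begin(), tree_idx.end(), 3);
      for (auto i : tree_idx) {
        std::cout << i << ' ';
      }
      std::cout << '\n';  // prints: 3 4 5 6
      return 0;
    }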
@@ -10,15 +10,15 @@
 
 #include <array>
 #include <cmath>
+#include <numeric>  // for accumulate
 
-#include "../collective/communicator-inl.h"
-#include "../common/common.h"  // MetricNoCache
+#include "../common/common.h"  // for AssertGPUSupport
 #include "../common/math.h"
 #include "../common/optional_weight.h"      // OptionalWeights
 #include "../common/pseudo_huber.h"
 #include "../common/quantile_loss_utils.h"  // QuantileLossParam
 #include "../common/threading_utils.h"
-#include "metric_common.h"
+#include "metric_common.h"  // MetricNoCache
 #include "xgboost/collective/result.h"  // for SafeColl
 #include "xgboost/metric.h"
 
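This and the following metric hunks add <numeric> for std::accumulate while pruning includes the files no longer use directly. The accumulate call sites fall outside this truncated diff; a sketch of the kind of weighted reduction these metrics perform, with hypothetical per-example data:

    #include <cstddef>  // for std::size_t
    #include <iostream>
    #include <numeric>  // for std::accumulate, std::iota
    #include <vector>

    int main() {
      // Hypothetical per-example losses and instance weights.
      std::vector<double> loss{0.5, 0.1, 0.4};
      std::vector<double> weight{1.0, 2.0, 1.0};

      // Reduce over indices so each loss is scaled by its weight.
      std::vector<std::size_t> idx(loss.size());
      std::iota(idx.begin(), idx.end(), 0);
      double sum = std::accumulate(
          idx.begin(), idx.end(), 0.0,
          [&](double acc, std::size_t i) { return acc + weight[i] * loss[i]; });
      double wsum = std::accumulate(weight.begin(), weight.end(), 0.0);
      std::cout << "weighted mean loss: " << sum / wsum << "\n";  // 0.275
      return 0;
    }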
@@ -9,8 +9,6 @@
 #include <string>
 
 #include "../collective/aggregator.h"
-#include "../collective/communicator-inl.h"
-#include "../common/common.h"
 #include "xgboost/metric.h"
 
 namespace xgboost {
@@ -9,8 +9,8 @@
 #include <array>
 #include <atomic>
 #include <cmath>
+#include <numeric>  // for accumulate
 
-#include "../collective/communicator-inl.h"
 #include "../common/math.h"
 #include "../common/threading_utils.h"
 #include "metric_common.h"  // MetricNoCache
@@ -9,10 +9,9 @@
 
 #include <array>
 #include <memory>
+#include <numeric>  // for accumulate
 #include <vector>
 
-#include "../collective/communicator-inl.h"
-#include "../common/math.h"
 #include "../common/survival_util.h"
 #include "../common/threading_utils.h"
 #include "metric_common.h"  // MetricNoCache
@@ -1,13 +1,13 @@
 /**
- * Copyright 2019-2023 by XGBoost Contributors
+ * Copyright 2019-2024, XGBoost Contributors
  */
 #include <thrust/functional.h>
 #include <thrust/random.h>
+#include <thrust/sort.h>  // for sort
 #include <thrust/transform.h>
 #include <xgboost/host_device_vector.h>
 #include <xgboost/logging.h>
 
-#include <algorithm>
 #include <cstddef>  // for size_t
 #include <limits>
 #include <utility>
Some files were not shown because too many files have changed in this diff.