diff --git a/.github/workflows/jvm_tests.yml b/.github/workflows/jvm_tests.yml
index 8efcdc2ec..a2d8bb69a 100644
--- a/.github/workflows/jvm_tests.yml
+++ b/.github/workflows/jvm_tests.yml
@@ -40,7 +40,7 @@ jobs:
key: ${{ runner.os }}-m2-${{ hashFiles('./jvm-packages/pom.xml') }}
restore-keys: ${{ runner.os }}-m2-${{ hashFiles('./jvm-packages/pom.xml') }}
- - name: Test XGBoost4J
+ - name: Test XGBoost4J (Core)
run: |
cd jvm-packages
mvn test -B -pl :xgboost4j_2.12
@@ -67,7 +67,7 @@ jobs:
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY_IAM_S3_UPLOADER }}
- - name: Test XGBoost4J-Spark
+ - name: Test XGBoost4J (Core, Spark, Examples)
run: |
rm -rfv build/
cd jvm-packages
diff --git a/.github/workflows/python_tests.yml b/.github/workflows/python_tests.yml
index 0d8e6d653..78a17d3f7 100644
--- a/.github/workflows/python_tests.yml
+++ b/.github/workflows/python_tests.yml
@@ -65,7 +65,7 @@ jobs:
run: |
cd python-package
python --version
- python setup.py sdist
+ python -m build --sdist
pip install -v ./dist/xgboost-*.tar.gz
cd ..
python -c 'import xgboost'
@@ -92,6 +92,9 @@ jobs:
auto-update-conda: true
python-version: ${{ matrix.python-version }}
activate-environment: test
+ - name: Install build
+ run: |
+ conda install -c conda-forge python-build
- name: Display Conda env
run: |
conda info
@@ -100,7 +103,7 @@ jobs:
run: |
cd python-package
python --version
- python setup.py sdist
+ python -m build --sdist
pip install -v ./dist/xgboost-*.tar.gz
cd ..
python -c 'import xgboost'
@@ -147,7 +150,7 @@ jobs:
run: |
cd python-package
python --version
- python setup.py install
+ pip install -v .
- name: Test Python package
run: |
@@ -194,7 +197,7 @@ jobs:
run: |
cd python-package
python --version
- python setup.py bdist_wheel --universal
+ pip wheel -v . --wheel-dir dist/
pip install ./dist/*.whl
- name: Test Python package
@@ -238,7 +241,7 @@ jobs:
run: |
cd python-package
python --version
- python setup.py install
+ pip install -v .
- name: Test Python package
run: |
diff --git a/.github/workflows/r_tests.yml b/.github/workflows/r_tests.yml
index 0ec95ace1..640ebce81 100644
--- a/.github/workflows/r_tests.yml
+++ b/.github/workflows/r_tests.yml
@@ -54,7 +54,7 @@ jobs:
matrix:
config:
- {os: windows-latest, r: 'release', compiler: 'mingw', build: 'autotools'}
- - {os: windows-latest, r: 'release', compiler: 'msvc', build: 'cmake'}
+ - {os: windows-latest, r: '4.2.0', compiler: 'msvc', build: 'cmake'}
env:
R_REMOTES_NO_ERRORS_FROM_WARNINGS: true
RSPM: ${{ matrix.config.rspm }}
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4cc47fa6a..2d3fdc728 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -47,6 +47,7 @@ option(USE_NVTX "Build with cuda profiling annotations. Developers only." OFF)
set(NVTX_HEADER_DIR "" CACHE PATH "Path to the stand-alone nvtx header")
option(RABIT_MOCK "Build rabit with mock" OFF)
option(HIDE_CXX_SYMBOLS "Build shared library and hide all C++ symbols" OFF)
+option(KEEP_BUILD_ARTIFACTS_IN_BINARY_DIR "Output build artifacts in CMake binary dir" OFF)
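+# When ON, runxgboost and the xgboost library are written under the CMake binary
+# directory instead of the source tree, e.g. when configuring with
+#   cmake -S . -B build -DKEEP_BUILD_ARTIFACTS_IN_BINARY_DIR=ON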
## CUDA
option(USE_CUDA "Build with GPU acceleration" OFF)
option(USE_NCCL "Build with NCCL to enable distributed GPU support." OFF)
@@ -312,8 +313,13 @@ if (JVM_BINDINGS)
xgboost_target_defs(xgboost4j)
endif (JVM_BINDINGS)
-set_output_directory(runxgboost ${xgboost_SOURCE_DIR})
-set_output_directory(xgboost ${xgboost_SOURCE_DIR}/lib)
+if (KEEP_BUILD_ARTIFACTS_IN_BINARY_DIR)
+ set_output_directory(runxgboost ${xgboost_BINARY_DIR})
+ set_output_directory(xgboost ${xgboost_BINARY_DIR}/lib)
+else ()
+ set_output_directory(runxgboost ${xgboost_SOURCE_DIR})
+ set_output_directory(xgboost ${xgboost_SOURCE_DIR}/lib)
+endif ()
# Ensure these two targets do not build simultaneously, as they produce outputs with conflicting names
add_dependencies(xgboost runxgboost)
diff --git a/R-package/src/Makevars.in b/R-package/src/Makevars.in
index 743bf0a66..a84459db9 100644
--- a/R-package/src/Makevars.in
+++ b/R-package/src/Makevars.in
@@ -32,7 +32,7 @@ OBJECTS= \
$(PKGROOT)/src/objective/objective.o \
$(PKGROOT)/src/objective/regression_obj.o \
$(PKGROOT)/src/objective/multiclass_obj.o \
- $(PKGROOT)/src/objective/rank_obj.o \
+ $(PKGROOT)/src/objective/lambdarank_obj.o \
$(PKGROOT)/src/objective/hinge.o \
$(PKGROOT)/src/objective/aft_obj.o \
$(PKGROOT)/src/objective/adaptive.o \
diff --git a/R-package/src/Makevars.win b/R-package/src/Makevars.win
index a32d2fd2e..25c577e3a 100644
--- a/R-package/src/Makevars.win
+++ b/R-package/src/Makevars.win
@@ -32,7 +32,7 @@ OBJECTS= \
$(PKGROOT)/src/objective/objective.o \
$(PKGROOT)/src/objective/regression_obj.o \
$(PKGROOT)/src/objective/multiclass_obj.o \
- $(PKGROOT)/src/objective/rank_obj.o \
+ $(PKGROOT)/src/objective/lambdarank_obj.o \
$(PKGROOT)/src/objective/hinge.o \
$(PKGROOT)/src/objective/aft_obj.o \
$(PKGROOT)/src/objective/adaptive.o \
diff --git a/R-package/tests/testthat/test_dmatrix.R b/R-package/tests/testthat/test_dmatrix.R
index 1d8cb0f23..21d39f255 100644
--- a/R-package/tests/testthat/test_dmatrix.R
+++ b/R-package/tests/testthat/test_dmatrix.R
@@ -72,7 +72,7 @@ test_that("xgb.DMatrix: saving, loading", {
tmp <- c("0 1:1 2:1", "1 3:1", "0 1:1")
tmp_file <- tempfile(fileext = ".libsvm")
writeLines(tmp, tmp_file)
- dtest4 <- xgb.DMatrix(tmp_file, silent = TRUE)
+ dtest4 <- xgb.DMatrix(paste(tmp_file, "?format=libsvm", sep = ""), silent = TRUE)
expect_equal(dim(dtest4), c(3, 4))
expect_equal(getinfo(dtest4, 'label'), c(0, 1, 0))
diff --git a/demo/CLI/binary_classification/mushroom.conf b/demo/CLI/binary_classification/mushroom.conf
index 3cf865465..d78199cd7 100644
--- a/demo/CLI/binary_classification/mushroom.conf
+++ b/demo/CLI/binary_classification/mushroom.conf
@@ -20,10 +20,10 @@ num_round = 2
# 0 means do not save any model except the final round model
save_period = 2
# The path of training data
-data = "agaricus.txt.train"
+data = "agaricus.txt.train?format=libsvm"
# The path of validation data, used to monitor training process, here [test] sets name of the validation set
-eval[test] = "agaricus.txt.test"
+eval[test] = "agaricus.txt.test?format=libsvm"
# evaluate on training data as well each round
eval_train = 1
# The path of test data
-test:data = "agaricus.txt.test"
+test:data = "agaricus.txt.test?format=libsvm"
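+# The "?format=libsvm" suffix makes the text-file format explicit to XGBoost for all of the paths above.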
diff --git a/demo/CLI/regression/machine.conf b/demo/CLI/regression/machine.conf
index 4ba8437d5..42e2b1227 100644
--- a/demo/CLI/regression/machine.conf
+++ b/demo/CLI/regression/machine.conf
@@ -21,8 +21,8 @@ num_round = 2
# 0 means do not save any model except the final round model
save_period = 0
# The path of training data
-data = "machine.txt.train"
+data = "machine.txt.train?format=libsvm"
# The path of validation data, used to monitor training process, here [test] sets name of the validation set
-eval[test] = "machine.txt.test"
+eval[test] = "machine.txt.test?format=libsvm"
# The path of test data
-test:data = "machine.txt.test"
+test:data = "machine.txt.test?format=libsvm"
diff --git a/demo/c-api/basic/c-api-demo.c b/demo/c-api/basic/c-api-demo.c
index ca6e689aa..15a224e9e 100644
--- a/demo/c-api/basic/c-api-demo.c
+++ b/demo/c-api/basic/c-api-demo.c
@@ -42,8 +42,8 @@ int main() {
// load the data
DMatrixHandle dtrain, dtest;
- safe_xgboost(XGDMatrixCreateFromFile("../../data/agaricus.txt.train", silent, &dtrain));
- safe_xgboost(XGDMatrixCreateFromFile("../../data/agaricus.txt.test", silent, &dtest));
+ safe_xgboost(XGDMatrixCreateFromFile("../../data/agaricus.txt.train?format=libsvm", silent, &dtrain));
+ safe_xgboost(XGDMatrixCreateFromFile("../../data/agaricus.txt.test?format=libsvm", silent, &dtest));
// create the booster
BoosterHandle booster;
diff --git a/demo/guide-python/boost_from_prediction.py b/demo/guide-python/boost_from_prediction.py
index 53a45549a..13f91d7c8 100644
--- a/demo/guide-python/boost_from_prediction.py
+++ b/demo/guide-python/boost_from_prediction.py
@@ -7,15 +7,19 @@ import os
import xgboost as xgb
CURRENT_DIR = os.path.dirname(__file__)
-dtrain = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.train'))
-dtest = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.test'))
-watchlist = [(dtest, 'eval'), (dtrain, 'train')]
+dtrain = xgb.DMatrix(
+ os.path.join(CURRENT_DIR, "../data/agaricus.txt.train?format=libsvm")
+)
+dtest = xgb.DMatrix(
+ os.path.join(CURRENT_DIR, "../data/agaricus.txt.test?format=libsvm")
+)
+watchlist = [(dtest, "eval"), (dtrain, "train")]
###
# advanced: start from a initial base prediction
#
-print('start running example to start from a initial prediction')
+print("start running example to start from a initial prediction")
# specify parameters via map, definition are same as c++ version
-param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}
+param = {"max_depth": 2, "eta": 1, "objective": "binary:logistic"}
# train xgboost for 1 round
bst = xgb.train(param, dtrain, 1, watchlist)
# Note: we need the margin value instead of transformed prediction in
@@ -27,5 +31,5 @@ ptest = bst.predict(dtest, output_margin=True)
dtrain.set_base_margin(ptrain)
dtest.set_base_margin(ptest)
-print('this is result of running from initial prediction')
+print("this is result of running from initial prediction")
bst = xgb.train(param, dtrain, 1, watchlist)
diff --git a/demo/guide-python/cross_validation.py b/demo/guide-python/cross_validation.py
index 2565b02c9..4e537108a 100644
--- a/demo/guide-python/cross_validation.py
+++ b/demo/guide-python/cross_validation.py
@@ -10,27 +10,45 @@ import xgboost as xgb
# load data in do training
CURRENT_DIR = os.path.dirname(__file__)
-dtrain = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.train'))
-param = {'max_depth':2, 'eta':1, 'objective':'binary:logistic'}
+dtrain = xgb.DMatrix(
+ os.path.join(CURRENT_DIR, "../data/agaricus.txt.train?format=libsvm")
+)
+param = {"max_depth": 2, "eta": 1, "objective": "binary:logistic"}
num_round = 2
-print('running cross validation')
+print("running cross validation")
# do cross validation, this will print result out as
# [iteration] metric_name:mean_value+std_value
# std_value is standard deviation of the metric
-xgb.cv(param, dtrain, num_round, nfold=5,
- metrics={'error'}, seed=0,
- callbacks=[xgb.callback.EvaluationMonitor(show_stdv=True)])
+xgb.cv(
+ param,
+ dtrain,
+ num_round,
+ nfold=5,
+ metrics={"error"},
+ seed=0,
+ callbacks=[xgb.callback.EvaluationMonitor(show_stdv=True)],
+)
-print('running cross validation, disable standard deviation display')
+print("running cross validation, disable standard deviation display")
# do cross validation, this will print result out as
# [iteration] metric_name:mean_value
-res = xgb.cv(param, dtrain, num_boost_round=10, nfold=5,
- metrics={'error'}, seed=0,
- callbacks=[xgb.callback.EvaluationMonitor(show_stdv=False),
- xgb.callback.EarlyStopping(3)])
+res = xgb.cv(
+ param,
+ dtrain,
+ num_boost_round=10,
+ nfold=5,
+ metrics={"error"},
+ seed=0,
+ callbacks=[
+ xgb.callback.EvaluationMonitor(show_stdv=False),
+ xgb.callback.EarlyStopping(3),
+ ],
+)
print(res)
-print('running cross validation, with preprocessing function')
+print("running cross validation, with preprocessing function")
+
+
# define the preprocessing function
# used to return the preprocessed training, test data, and parameter
# we can use this to do weight rescale, etc.
@@ -38,32 +56,36 @@ print('running cross validation, with preprocessing function')
def fpreproc(dtrain, dtest, param):
label = dtrain.get_label()
ratio = float(np.sum(label == 0)) / np.sum(label == 1)
- param['scale_pos_weight'] = ratio
+ param["scale_pos_weight"] = ratio
return (dtrain, dtest, param)
+
# do cross validation, for each fold
# the dtrain, dtest, param will be passed into fpreproc
# then the return value of fpreproc will be used to generate
# results of that fold
-xgb.cv(param, dtrain, num_round, nfold=5,
- metrics={'auc'}, seed=0, fpreproc=fpreproc)
+xgb.cv(param, dtrain, num_round, nfold=5, metrics={"auc"}, seed=0, fpreproc=fpreproc)
###
# you can also do cross validation with customized loss function
# See custom_objective.py
##
-print('running cross validation, with customized loss function')
+print("running cross validation, with customized loss function")
+
+
def logregobj(preds, dtrain):
labels = dtrain.get_label()
preds = 1.0 / (1.0 + np.exp(-preds))
grad = preds - labels
hess = preds * (1.0 - preds)
return grad, hess
+
+
def evalerror(preds, dtrain):
labels = dtrain.get_label()
- return 'error', float(sum(labels != (preds > 0.0))) / len(labels)
+ return "error", float(sum(labels != (preds > 0.0))) / len(labels)
-param = {'max_depth':2, 'eta':1}
+
+param = {"max_depth": 2, "eta": 1}
# train with customized objective
-xgb.cv(param, dtrain, num_round, nfold=5, seed=0,
- obj=logregobj, feval=evalerror)
+xgb.cv(param, dtrain, num_round, nfold=5, seed=0, obj=logregobj, feval=evalerror)
diff --git a/demo/guide-python/evals_result.py b/demo/guide-python/evals_result.py
index bba8862f5..7b9da96da 100644
--- a/demo/guide-python/evals_result.py
+++ b/demo/guide-python/evals_result.py
@@ -7,28 +7,37 @@ import os
import xgboost as xgb
CURRENT_DIR = os.path.dirname(__file__)
-dtrain = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.train'))
-dtest = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.test'))
+dtrain = xgb.DMatrix(
+ os.path.join(CURRENT_DIR, "../data/agaricus.txt.train?format=libsvm")
+)
+dtest = xgb.DMatrix(
+ os.path.join(CURRENT_DIR, "../data/agaricus.txt.test?format=libsvm")
+)
-param = [('max_depth', 2), ('objective', 'binary:logistic'), ('eval_metric', 'logloss'), ('eval_metric', 'error')]
+param = [
+ ("max_depth", 2),
+ ("objective", "binary:logistic"),
+ ("eval_metric", "logloss"),
+ ("eval_metric", "error"),
+]
num_round = 2
-watchlist = [(dtest,'eval'), (dtrain,'train')]
+watchlist = [(dtest, "eval"), (dtrain, "train")]
evals_result = {}
bst = xgb.train(param, dtrain, num_round, watchlist, evals_result=evals_result)
-print('Access logloss metric directly from evals_result:')
-print(evals_result['eval']['logloss'])
+print("Access logloss metric directly from evals_result:")
+print(evals_result["eval"]["logloss"])
-print('')
-print('Access metrics through a loop:')
+print("")
+print("Access metrics through a loop:")
for e_name, e_mtrs in evals_result.items():
- print('- {}'.format(e_name))
+ print("- {}".format(e_name))
for e_mtr_name, e_mtr_vals in e_mtrs.items():
- print(' - {}'.format(e_mtr_name))
- print(' - {}'.format(e_mtr_vals))
+ print(" - {}".format(e_mtr_name))
+ print(" - {}".format(e_mtr_vals))
-print('')
-print('Access complete dictionary:')
+print("")
+print("Access complete dictionary:")
print(evals_result)
diff --git a/demo/guide-python/generalized_linear_model.py b/demo/guide-python/generalized_linear_model.py
index 976428f13..3387b1982 100644
--- a/demo/guide-python/generalized_linear_model.py
+++ b/demo/guide-python/generalized_linear_model.py
@@ -11,14 +11,22 @@ import xgboost as xgb
# basically, we are using linear model, instead of tree for our boosters
##
CURRENT_DIR = os.path.dirname(__file__)
-dtrain = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.train'))
-dtest = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.test'))
+dtrain = xgb.DMatrix(
+ os.path.join(CURRENT_DIR, "../data/agaricus.txt.train?format=libsvm")
+)
+dtest = xgb.DMatrix(
+ os.path.join(CURRENT_DIR, "../data/agaricus.txt.test?format=libsvm")
+)
# change booster to gblinear, so that we are fitting a linear model
# alpha is the L1 regularizer
# lambda is the L2 regularizer
# you can also set lambda_bias which is L2 regularizer on the bias term
-param = {'objective':'binary:logistic', 'booster':'gblinear',
- 'alpha': 0.0001, 'lambda': 1}
+param = {
+ "objective": "binary:logistic",
+ "booster": "gblinear",
+ "alpha": 0.0001,
+ "lambda": 1,
+}
# normally, you do not need to set eta (step_size)
# XGBoost uses a parallel coordinate descent algorithm (shotgun),
@@ -29,9 +37,15 @@ param = {'objective':'binary:logistic', 'booster':'gblinear',
##
# the rest of settings are the same
##
-watchlist = [(dtest, 'eval'), (dtrain, 'train')]
+watchlist = [(dtest, "eval"), (dtrain, "train")]
num_round = 4
bst = xgb.train(param, dtrain, num_round, watchlist)
preds = bst.predict(dtest)
labels = dtest.get_label()
-print('error=%f' % (sum(1 for i in range(len(preds)) if int(preds[i] > 0.5) != labels[i]) / float(len(preds))))
+print(
+ "error=%f"
+ % (
+ sum(1 for i in range(len(preds)) if int(preds[i] > 0.5) != labels[i])
+ / float(len(preds))
+ )
+)
diff --git a/demo/guide-python/predict_first_ntree.py b/demo/guide-python/predict_first_ntree.py
index 55f7c61af..78137b4e1 100644
--- a/demo/guide-python/predict_first_ntree.py
+++ b/demo/guide-python/predict_first_ntree.py
@@ -16,8 +16,8 @@ test = os.path.join(CURRENT_DIR, "../data/agaricus.txt.test")
def native_interface():
# load data in do training
- dtrain = xgb.DMatrix(train)
- dtest = xgb.DMatrix(test)
+ dtrain = xgb.DMatrix(train + "?format=libsvm")
+ dtest = xgb.DMatrix(test + "?format=libsvm")
param = {"max_depth": 2, "eta": 1, "objective": "binary:logistic"}
watchlist = [(dtest, "eval"), (dtrain, "train")]
num_round = 3
diff --git a/demo/guide-python/predict_leaf_indices.py b/demo/guide-python/predict_leaf_indices.py
index 45cc8fa7f..627619724 100644
--- a/demo/guide-python/predict_leaf_indices.py
+++ b/demo/guide-python/predict_leaf_indices.py
@@ -8,14 +8,18 @@ import xgboost as xgb
# load data in do training
CURRENT_DIR = os.path.dirname(__file__)
-dtrain = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.train'))
-dtest = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.test'))
-param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}
-watchlist = [(dtest, 'eval'), (dtrain, 'train')]
+dtrain = xgb.DMatrix(
+ os.path.join(CURRENT_DIR, "../data/agaricus.txt.train?format=libsvm")
+)
+dtest = xgb.DMatrix(
+ os.path.join(CURRENT_DIR, "../data/agaricus.txt.test?format=libsvm")
+)
+param = {"max_depth": 2, "eta": 1, "objective": "binary:logistic"}
+watchlist = [(dtest, "eval"), (dtrain, "train")]
num_round = 3
bst = xgb.train(param, dtrain, num_round, watchlist)
-print('start testing predict the leaf indices')
+print("start testing predict the leaf indices")
# predict using first 2 tree
leafindex = bst.predict(
dtest, iteration_range=(0, 2), pred_leaf=True, strict_shape=True
diff --git a/demo/nvflare/README.md b/demo/nvflare/README.md
index 328dd7212..93f388208 100644
--- a/demo/nvflare/README.md
+++ b/demo/nvflare/README.md
@@ -3,61 +3,12 @@
This directory contains a demo of Federated Learning using
[NVFlare](https://nvidia.github.io/NVFlare/).
-## Training with CPU only
+## Horizontal Federated XGBoost
-To run the demo, first build XGBoost with the federated learning plugin enabled (see the
-[README](../../plugin/federated/README.md)).
+For horizontal federated learning using XGBoost (data is split row-wise), check out the `horizontal` directory
+(see the [README](horizontal/README.md)).
-Install NVFlare (note that currently NVFlare only supports Python 3.8):
-```shell
-pip install nvflare
-```
+## Vertical Federated XGBoost
-Prepare the data:
-```shell
-./prepare_data.sh
-```
-
-Start the NVFlare federated server:
-```shell
-/tmp/nvflare/poc/server/startup/start.sh
-```
-
-In another terminal, start the first worker:
-```shell
-/tmp/nvflare/poc/site-1/startup/start.sh
-```
-
-And the second worker:
-```shell
-/tmp/nvflare/poc/site-2/startup/start.sh
-```
-
-Then start the admin CLI:
-```shell
-/tmp/nvflare/poc/admin/startup/fl_admin.sh
-```
-
-In the admin CLI, run the following command:
-```shell
-submit_job hello-xgboost
-```
-
-Once the training finishes, the model file should be written into
-`/tmp/nvlfare/poc/site-1/run_1/test.model.json` and `/tmp/nvflare/poc/site-2/run_1/test.model.json`
-respectively.
-
-Finally, shutdown everything from the admin CLI, using `admin` as password:
-```shell
-shutdown client
-shutdown server
-```
-
-## Training with GPUs
-
-To demo with Federated Learning using GPUs, make sure your machine has at least 2 GPUs.
-Build XGBoost with the federated learning plugin enabled along with CUDA, but with NCCL
-turned off (see the [README](../../plugin/federated/README.md)).
-
-Modify `config/config_fed_client.json` and set `use_gpus` to `true`, then repeat the steps
-above.
+For vertical federated learning using XGBoost (data is split column-wise), check out the `vertical` directory
+(see the [README](vertical/README.md)).
diff --git a/demo/nvflare/config/config_fed_client.json b/demo/nvflare/config/config_fed_client.json
deleted file mode 100755
index c15a1997c..000000000
--- a/demo/nvflare/config/config_fed_client.json
+++ /dev/null
@@ -1,23 +0,0 @@
-{
- "format_version": 2,
- "executors": [
- {
- "tasks": [
- "train"
- ],
- "executor": {
- "path": "trainer.XGBoostTrainer",
- "args": {
- "server_address": "localhost:9091",
- "world_size": 2,
- "server_cert_path": "server-cert.pem",
- "client_key_path": "client-key.pem",
- "client_cert_path": "client-cert.pem",
- "use_gpus": "false"
- }
- }
- }
- ],
- "task_result_filters": [],
- "task_data_filters": []
-}
diff --git a/demo/nvflare/config/config_fed_server.json b/demo/nvflare/config/config_fed_server.json
deleted file mode 100755
index 32993b652..000000000
--- a/demo/nvflare/config/config_fed_server.json
+++ /dev/null
@@ -1,22 +0,0 @@
-{
- "format_version": 2,
- "server": {
- "heart_beat_timeout": 600
- },
- "task_data_filters": [],
- "task_result_filters": [],
- "workflows": [
- {
- "id": "server_workflow",
- "path": "controller.XGBoostController",
- "args": {
- "port": 9091,
- "world_size": 2,
- "server_key_path": "server-key.pem",
- "server_cert_path": "server-cert.pem",
- "client_cert_path": "client-cert.pem"
- }
- }
- ],
- "components": []
-}
diff --git a/demo/nvflare/horizontal/README.md b/demo/nvflare/horizontal/README.md
new file mode 100644
index 000000000..93ea3794c
--- /dev/null
+++ b/demo/nvflare/horizontal/README.md
@@ -0,0 +1,63 @@
+# Experimental Support of Horizontal Federated XGBoost using NVFlare
+
+This directory contains a demo of Horizontal Federated Learning using
+[NVFlare](https://nvidia.github.io/NVFlare/).
+
+## Training with CPU only
+
+To run the demo, first build XGBoost with the federated learning plugin enabled (see the
+[README](../../plugin/federated/README.md)).
+
+Install NVFlare (note that currently NVFlare only supports Python 3.8):
+```shell
+pip install nvflare
+```
+
+Prepare the data:
+```shell
+./prepare_data.sh
+```
+
+Start the NVFlare federated server:
+```shell
+/tmp/nvflare/poc/server/startup/start.sh
+```
+
+In another terminal, start the first worker:
+```shell
+/tmp/nvflare/poc/site-1/startup/start.sh
+```
+
+And the second worker:
+```shell
+/tmp/nvflare/poc/site-2/startup/start.sh
+```
+
+Then start the admin CLI:
+```shell
+/tmp/nvflare/poc/admin/startup/fl_admin.sh
+```
+
+In the admin CLI, run the following command:
+```shell
+submit_job horizontal-xgboost
+```
+
+Once the training finishes, the model file should be written into
+`/tmp/nvflare/poc/site-1/run_1/test.model.json` and `/tmp/nvflare/poc/site-2/run_1/test.model.json`
+respectively.
+
+Finally, shutdown everything from the admin CLI, using `admin` as password:
+```shell
+shutdown client
+shutdown server
+```
+
+## Training with GPUs
+
+To demo with Federated Learning using GPUs, make sure your machine has at least 2 GPUs.
+Build XGBoost with the federated learning plugin enabled along with CUDA, but with NCCL
+turned off (see the [README](../../plugin/federated/README.md)).
+
+Modify `config/config_fed_client.json` and set `use_gpus` to `true`, then repeat the steps
+above.
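+
+For example, in the executor `args` block of `config/config_fed_client.json` (the same structure used by this demo's client config), change `"use_gpus": "false"` to `"use_gpus": "true"` before submitting the job.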
diff --git a/demo/nvflare/custom/controller.py b/demo/nvflare/horizontal/custom/controller.py
similarity index 100%
rename from demo/nvflare/custom/controller.py
rename to demo/nvflare/horizontal/custom/controller.py
diff --git a/demo/nvflare/custom/trainer.py b/demo/nvflare/horizontal/custom/trainer.py
similarity index 100%
rename from demo/nvflare/custom/trainer.py
rename to demo/nvflare/horizontal/custom/trainer.py
diff --git a/demo/nvflare/prepare_data.sh b/demo/nvflare/horizontal/prepare_data.sh
similarity index 88%
rename from demo/nvflare/prepare_data.sh
rename to demo/nvflare/horizontal/prepare_data.sh
index 1c88c65fe..6a32008f8 100755
--- a/demo/nvflare/prepare_data.sh
+++ b/demo/nvflare/horizontal/prepare_data.sh
@@ -15,8 +15,8 @@ split -n l/${world_size} --numeric-suffixes=1 -a 1 ../data/agaricus.txt.train ag
split -n l/${world_size} --numeric-suffixes=1 -a 1 ../data/agaricus.txt.test agaricus.txt.test-site-
nvflare poc -n 2 --prepare
-mkdir -p /tmp/nvflare/poc/admin/transfer/hello-xgboost
-cp -fr config custom /tmp/nvflare/poc/admin/transfer/hello-xgboost
+mkdir -p /tmp/nvflare/poc/admin/transfer/horizontal-xgboost
+cp -fr config custom /tmp/nvflare/poc/admin/transfer/horizontal-xgboost
cp server-*.pem client-cert.pem /tmp/nvflare/poc/server/
for id in $(eval echo "{1..$world_size}"); do
cp server-cert.pem client-*.pem /tmp/nvflare/poc/site-"$id"/
diff --git a/demo/nvflare/vertical/README.md b/demo/nvflare/vertical/README.md
new file mode 100644
index 000000000..83c3111b6
--- /dev/null
+++ b/demo/nvflare/vertical/README.md
@@ -0,0 +1,59 @@
+# Experimental Support of Vertical Federated XGBoost using NVFlare
+
+This directory contains a demo of Vertical Federated Learning using
+[NVFlare](https://nvidia.github.io/NVFlare/).
+
+## Training with CPU only
+
+To run the demo, first build XGBoost with the federated learning plugin enabled (see the
+[README](../../plugin/federated/README.md)).
+
+Install NVFlare (note that currently NVFlare only supports Python 3.8):
+```shell
+pip install nvflare
+```
+
+Prepare the data (note that this step downloads the HIGGS dataset, which is 2.6GB compressed and 7.5GB
+uncompressed, so make sure you have enough disk space and a fast internet connection):
+```shell
+./prepare_data.sh
+```
+
+Start the NVFlare federated server:
+```shell
+/tmp/nvflare/poc/server/startup/start.sh
+```
+
+In another terminal, start the first worker:
+```shell
+/tmp/nvflare/poc/site-1/startup/start.sh
+```
+
+And the second worker:
+```shell
+/tmp/nvflare/poc/site-2/startup/start.sh
+```
+
+Then start the admin CLI:
+```shell
+/tmp/nvflare/poc/admin/startup/fl_admin.sh
+```
+
+In the admin CLI, run the following command:
+```shell
+submit_job vertical-xgboost
+```
+
+Once the training finishes, the model file should be written into
+`/tmp/nvflare/poc/site-1/run_1/test.model.json` and `/tmp/nvflare/poc/site-2/run_1/test.model.json`
+respectively.
+
+Finally, shutdown everything from the admin CLI, using `admin` as password:
+```shell
+shutdown client
+shutdown server
+```
+
+## Training with GPUs
+
+Currently GPUs are not yet supported by vertical federated XGBoost.
diff --git a/demo/nvflare/vertical/custom/controller.py b/demo/nvflare/vertical/custom/controller.py
new file mode 100644
index 000000000..dd3e39f46
--- /dev/null
+++ b/demo/nvflare/vertical/custom/controller.py
@@ -0,0 +1,68 @@
+"""
+Example of training controller with NVFlare
+===========================================
+"""
+import multiprocessing
+
+from nvflare.apis.client import Client
+from nvflare.apis.fl_context import FLContext
+from nvflare.apis.impl.controller import Controller, Task
+from nvflare.apis.shareable import Shareable
+from nvflare.apis.signal import Signal
+from trainer import SupportedTasks
+
+import xgboost.federated
+
+
+class XGBoostController(Controller):
+ def __init__(self, port: int, world_size: int, server_key_path: str,
+ server_cert_path: str, client_cert_path: str):
+ """Controller for federated XGBoost.
+
+ Args:
+ port: the port for the gRPC server to listen on.
+ world_size: the number of sites.
+ server_key_path: the path to the server key file.
+ server_cert_path: the path to the server certificate file.
+ client_cert_path: the path to the client certificate file.
+ """
+ super().__init__()
+ self._port = port
+ self._world_size = world_size
+ self._server_key_path = server_key_path
+ self._server_cert_path = server_cert_path
+ self._client_cert_path = client_cert_path
+ self._server = None
+
+ def start_controller(self, fl_ctx: FLContext):
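+ # Run the federated gRPC server in a separate process; stop_controller() terminates it when the workflow ends.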
+ self._server = multiprocessing.Process(
+ target=xgboost.federated.run_federated_server,
+ args=(self._port, self._world_size, self._server_key_path,
+ self._server_cert_path, self._client_cert_path))
+ self._server.start()
+
+ def stop_controller(self, fl_ctx: FLContext):
+ if self._server:
+ self._server.terminate()
+
+ def process_result_of_unknown_task(self, client: Client, task_name: str,
+ client_task_id: str, result: Shareable,
+ fl_ctx: FLContext):
+ self.log_warning(fl_ctx, f"Unknown task: {task_name} from client {client.name}.")
+
+ def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
+ self.log_info(fl_ctx, "XGBoost training control flow started.")
+ if abort_signal.triggered:
+ return
+ task = Task(name=SupportedTasks.TRAIN, data=Shareable())
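+ # Broadcast the training task to every site and wait for all world_size responses (or an abort signal).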
+ self.broadcast_and_wait(
+ task=task,
+ min_responses=self._world_size,
+ fl_ctx=fl_ctx,
+ wait_time_after_min_received=1,
+ abort_signal=abort_signal,
+ )
+ if abort_signal.triggered:
+ return
+
+ self.log_info(fl_ctx, "XGBoost training control flow finished.")
diff --git a/demo/nvflare/vertical/custom/trainer.py b/demo/nvflare/vertical/custom/trainer.py
new file mode 100644
index 000000000..cd420129c
--- /dev/null
+++ b/demo/nvflare/vertical/custom/trainer.py
@@ -0,0 +1,97 @@
+import os
+
+from nvflare.apis.executor import Executor
+from nvflare.apis.fl_constant import FLContextKey, ReturnCode
+from nvflare.apis.fl_context import FLContext
+from nvflare.apis.shareable import Shareable, make_reply
+from nvflare.apis.signal import Signal
+
+import xgboost as xgb
+from xgboost import callback
+
+
+class SupportedTasks(object):
+ TRAIN = "train"
+
+
+class XGBoostTrainer(Executor):
+ def __init__(self, server_address: str, world_size: int, server_cert_path: str,
+ client_key_path: str, client_cert_path: str):
+ """Trainer for federated XGBoost.
+
+ Args:
+ server_address: address for the gRPC server to connect to.
+ world_size: the number of sites.
+ server_cert_path: the path to the server certificate file.
+ client_key_path: the path to the client key file.
+ client_cert_path: the path to the client certificate file.
+ """
+ super().__init__()
+ self._server_address = server_address
+ self._world_size = world_size
+ self._server_cert_path = server_cert_path
+ self._client_key_path = client_key_path
+ self._client_cert_path = client_cert_path
+
+ def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext,
+ abort_signal: Signal) -> Shareable:
+ self.log_info(fl_ctx, f"Executing {task_name}")
+ try:
+ if task_name == SupportedTasks.TRAIN:
+ self._do_training(fl_ctx)
+ return make_reply(ReturnCode.OK)
+ else:
+ self.log_error(fl_ctx, f"{task_name} is not a supported task.")
+ return make_reply(ReturnCode.TASK_UNKNOWN)
+ except BaseException as e:
+ self.log_exception(fl_ctx,
+ f"Task {task_name} failed. Exception: {e.__str__()}")
+ return make_reply(ReturnCode.EXECUTION_EXCEPTION)
+
+ def _do_training(self, fl_ctx: FLContext):
+ client_name = fl_ctx.get_prop(FLContextKey.CLIENT_NAME)
+ rank = int(client_name.split('-')[1]) - 1
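+ # Communicator settings for the federated plugin: gRPC server address, TLS certificate/key paths, and this site's zero-based rank.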
+ communicator_env = {
+ 'xgboost_communicator': 'federated',
+ 'federated_server_address': self._server_address,
+ 'federated_world_size': self._world_size,
+ 'federated_rank': rank,
+ 'federated_server_cert': self._server_cert_path,
+ 'federated_client_key': self._client_key_path,
+ 'federated_client_cert': self._client_cert_path
+ }
+ with xgb.collective.CommunicatorContext(**communicator_env):
+ # Load file, file will not be sharded in federated mode.
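+ # data_split_mode=1 loads the file as a column-wise (vertical) split; only rank 0's file carries the label column in this demo.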
+ if rank == 0:
+ label = '&label_column=0'
+ else:
+ label = ''
+ dtrain = xgb.DMatrix(f'higgs.train.csv?format=csv{label}', data_split_mode=1)
+ dtest = xgb.DMatrix(f'higgs.test.csv?format=csv{label}', data_split_mode=1)
+
+ # specify parameters via map
+ param = {
+ 'validate_parameters': True,
+ 'eta': 0.1,
+ 'gamma': 1.0,
+ 'max_depth': 8,
+ 'min_child_weight': 100,
+ 'tree_method': 'approx',
+ 'grow_policy': 'depthwise',
+ 'objective': 'binary:logistic',
+ 'eval_metric': 'auc',
+ }
+
+ # specify validations set to watch performance
+ watchlist = [(dtest, "eval"), (dtrain, "train")]
+ # number of boosting rounds
+ num_round = 10
+
+ bst = xgb.train(param, dtrain, num_round, evals=watchlist, early_stopping_rounds=2)
+
+ # Save the model.
+ workspace = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT)
+ run_number = fl_ctx.get_prop(FLContextKey.CURRENT_RUN)
+ run_dir = workspace.get_run_dir(run_number)
+ bst.save_model(os.path.join(run_dir, "higgs.model.federated.vertical.json"))
+ xgb.collective.communicator_print("Finished training\n")
diff --git a/demo/nvflare/vertical/prepare_data.sh b/demo/nvflare/vertical/prepare_data.sh
new file mode 100755
index 000000000..86ec3dfa2
--- /dev/null
+++ b/demo/nvflare/vertical/prepare_data.sh
@@ -0,0 +1,65 @@
+#!/bin/bash
+
+set -e
+
+rm -fr ./*.pem /tmp/nvflare/poc
+
+world_size=2
+
+# Generate server and client certificates.
+openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout server-key.pem -out server-cert.pem -subj "/C=US/CN=localhost"
+openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout client-key.pem -out client-cert.pem -subj "/C=US/CN=localhost"
+
+# Download HIGGS dataset.
+if [ -f "HIGGS.csv" ]; then
+ echo "HIGGS.csv exists, skipping download."
+else
+ echo "Downloading HIGGS dataset."
+ wget https://archive.ics.uci.edu/ml/machine-learning-databases/00280/HIGGS.csv.gz
+ gunzip HIGGS.csv.gz
+fi
+
+# Split into train/test.
+if [[ -f higgs.train.csv && -f higgs.test.csv ]]; then
+ echo "higgs.train.csv and higgs.test.csv exist, skipping split."
+else
+ echo "Splitting HIGGS dataset into train/test."
+ head -n 10450000 HIGGS.csv > higgs.train.csv
+ tail -n 550000 HIGGS.csv > higgs.test.csv
+fi
+
+# Split train and test files by column to simulate a federated environment.
+site_files=(higgs.{train,test}.csv-site-*)
+if [ ${#site_files[@]} -eq $((world_size*2)) ]; then
+ echo "Site files exist, skipping split."
+else
+ echo "Splitting train/test into site files."
+ total_cols=28 # plus label
+ cols=$((total_cols/world_size))
+ echo "Columns per site: $cols"
+ for (( site=1; site<=world_size; site++ )); do
+ if (( site == 1 )); then
+ start=$((cols*(site-1)+1))
+ else
+ start=$((cols*(site-1)+2))
+ fi
+ if (( site == world_size )); then
+ end=$((total_cols+1))
+ else
+ end=$((cols*site+1))
+ fi
+ echo "Site $site, columns $start-$end"
+ cut -d, -f${start}-${end} higgs.train.csv > higgs.train.csv-site-"${site}"
+ cut -d, -f${start}-${end} higgs.test.csv > higgs.test.csv-site-"${site}"
+ done
+fi
+
+nvflare poc -n 2 --prepare
+mkdir -p /tmp/nvflare/poc/admin/transfer/vertical-xgboost
+cp -fr config custom /tmp/nvflare/poc/admin/transfer/vertical-xgboost
+cp server-*.pem client-cert.pem /tmp/nvflare/poc/server/
+for (( site=1; site<=world_size; site++ )); do
+ cp server-cert.pem client-*.pem /tmp/nvflare/poc/site-"${site}"/
+ ln -s "${PWD}"/higgs.train.csv-site-"${site}" /tmp/nvflare/poc/site-"${site}"/higgs.train.csv
+ ln -s "${PWD}"/higgs.test.csv-site-"${site}" /tmp/nvflare/poc/site-"${site}"/higgs.test.csv
+done
diff --git a/dev/release-artifacts.py b/dev/release-artifacts.py
index 18c317a91..eab64ff0c 100644
--- a/dev/release-artifacts.py
+++ b/dev/release-artifacts.py
@@ -105,7 +105,7 @@ def make_pysrc_wheel(release: str, outdir: str) -> None:
os.mkdir(dist)
with DirectoryExcursion(os.path.join(ROOT, "python-package")):
- subprocess.check_call(["python", "setup.py", "sdist"])
+ subprocess.check_call(["python", "-m", "build", "--sdist"])
src = os.path.join(DIST, f"xgboost-{release}.tar.gz")
subprocess.check_call(["twine", "check", src])
shutil.move(src, os.path.join(dist, f"xgboost-{release}.tar.gz"))
diff --git a/doc/Doxyfile.in b/doc/Doxyfile.in
index b159ef172..e24d67282 100644
--- a/doc/Doxyfile.in
+++ b/doc/Doxyfile.in
@@ -1,4 +1,4 @@
-# Doxyfile 1.8.8
+# Doxyfile 1.9.1
# This file describes the settings to be used by the documentation system
# doxygen (www.doxygen.org) for a project.
@@ -17,11 +17,11 @@
# Project related configuration options
#---------------------------------------------------------------------------
-# This tag specifies the encoding used for all characters in the config file
-# that follow. The default is UTF-8 which is also the encoding used for all text
-# before the first occurrence of this tag. Doxygen uses libiconv (or the iconv
-# built into libc) for the transcoding. See http://www.gnu.org/software/libiconv
-# for the list of possible encodings.
+# This tag specifies the encoding used for all characters in the configuration
+# file that follow. The default is UTF-8 which is also the encoding used for all
+# text before the first occurrence of this tag. Doxygen uses libiconv (or the
+# iconv built into libc) for the transcoding. See
+# https://www.gnu.org/software/libiconv/ for the list of possible encodings.
# The default value is: UTF-8.
DOXYFILE_ENCODING = UTF-8
@@ -32,7 +32,7 @@ DOXYFILE_ENCODING = UTF-8
# title of most generated pages and in a few other places.
# The default value is: My Project.
-PROJECT_NAME = "xgboost"
+PROJECT_NAME = xgboost
# The PROJECT_NUMBER tag can be used to enter a project or revision number. This
# could be handy for archiving the generated documentation or if some version
@@ -46,10 +46,10 @@ PROJECT_NUMBER = @XGBOOST_VERSION@
PROJECT_BRIEF =
-# With the PROJECT_LOGO tag one can specify an logo or icon that is included in
-# the documentation. The maximum height of the logo should not exceed 55 pixels
-# and the maximum width should not exceed 200 pixels. Doxygen will copy the logo
-# to the output directory.
+# With the PROJECT_LOGO tag one can specify a logo or an icon that is included
+# in the documentation. The maximum height of the logo should not exceed 55
+# pixels and the maximum width should not exceed 200 pixels. Doxygen will copy
+# the logo to the output directory.
PROJECT_LOGO =
@@ -60,7 +60,7 @@ PROJECT_LOGO =
OUTPUT_DIRECTORY = @PROJECT_BINARY_DIR@/doc_doxygen
-# If the CREATE_SUBDIRS tag is set to YES, then doxygen will create 4096 sub-
+# If the CREATE_SUBDIRS tag is set to YES then doxygen will create 4096 sub-
# directories (in 2 levels) under the output directory of each output format and
# will distribute the generated files over these directories. Enabling this
# option can be useful when feeding doxygen a huge amount of source files, where
@@ -76,7 +76,7 @@ CREATE_SUBDIRS = NO
# U+3044.
# The default value is: NO.
-#ALLOW_UNICODE_NAMES = NO
+ALLOW_UNICODE_NAMES = NO
# The OUTPUT_LANGUAGE tag is used to specify the language in which all
# documentation generated by doxygen is written. Doxygen will use this
@@ -93,14 +93,22 @@ CREATE_SUBDIRS = NO
OUTPUT_LANGUAGE = English
-# If the BRIEF_MEMBER_DESC tag is set to YES doxygen will include brief member
+# The OUTPUT_TEXT_DIRECTION tag is used to specify the direction in which all
+# documentation generated by doxygen is written. Doxygen will use this
+# information to generate all generated output in the proper direction.
+# Possible values are: None, LTR, RTL and Context.
+# The default value is: None.
+
+OUTPUT_TEXT_DIRECTION = None
+
+# If the BRIEF_MEMBER_DESC tag is set to YES, doxygen will include brief member
# descriptions after the members that are listed in the file and class
# documentation (similar to Javadoc). Set to NO to disable this.
# The default value is: YES.
BRIEF_MEMBER_DESC = YES
-# If the REPEAT_BRIEF tag is set to YES doxygen will prepend the brief
+# If the REPEAT_BRIEF tag is set to YES, doxygen will prepend the brief
# description of a member or function before the detailed description
#
# Note: If both HIDE_UNDOC_MEMBERS and BRIEF_MEMBER_DESC are set to NO, the
@@ -135,7 +143,7 @@ ALWAYS_DETAILED_SEC = NO
INLINE_INHERITED_MEMB = NO
-# If the FULL_PATH_NAMES tag is set to YES doxygen will prepend the full path
+# If the FULL_PATH_NAMES tag is set to YES, doxygen will prepend the full path
# before files name in the file list and in the header files. If set to NO the
# shortest path that makes the file name unique will be used
# The default value is: YES.
@@ -179,6 +187,16 @@ SHORT_NAMES = NO
JAVADOC_AUTOBRIEF = NO
+# If the JAVADOC_BANNER tag is set to YES then doxygen will interpret a line
+# such as
+# /***************
+# as being the beginning of a Javadoc-style comment "banner". If set to NO, the
+# Javadoc-style will behave just like regular comments and it will not be
+# interpreted by doxygen.
+# The default value is: NO.
+
+JAVADOC_BANNER = NO
+
# If the QT_AUTOBRIEF tag is set to YES then doxygen will interpret the first
# line (until the first dot) of a Qt-style comment as the brief description. If
# set to NO, the Qt-style will behave just like regular Qt-style comments (thus
@@ -199,15 +217,23 @@ QT_AUTOBRIEF = NO
MULTILINE_CPP_IS_BRIEF = NO
+# By default Python docstrings are displayed as preformatted text and doxygen's
+# special commands cannot be used. By setting PYTHON_DOCSTRING to NO the
+# doxygen's special commands can be used and the contents of the docstring
+# documentation blocks is shown as doxygen documentation.
+# The default value is: YES.
+
+PYTHON_DOCSTRING = YES
+
# If the INHERIT_DOCS tag is set to YES then an undocumented member inherits the
# documentation from any documented member that it re-implements.
# The default value is: YES.
INHERIT_DOCS = YES
-# If the SEPARATE_MEMBER_PAGES tag is set to YES, then doxygen will produce a
-# new page for each member. If set to NO, the documentation of a member will be
-# part of the file/class/namespace that contains it.
+# If the SEPARATE_MEMBER_PAGES tag is set to YES then doxygen will produce a new
+# page for each member. If set to NO, the documentation of a member will be part
+# of the file/class/namespace that contains it.
# The default value is: NO.
SEPARATE_MEMBER_PAGES = NO
@@ -226,16 +252,15 @@ TAB_SIZE = 8
# will allow you to put the command \sideeffect (or @sideeffect) in the
# documentation, which will result in a user-defined paragraph with heading
# "Side Effects:". You can put \n's in the value part of an alias to insert
-# newlines.
+# newlines (in the resulting output). You can put ^^ in the value part of an
+# alias to insert a newline as if a physical newline was in the original file.
+# When you need a literal { or } or , in the value part of an alias you have to
+# escape them by means of a backslash (\), this can lead to conflicts with the
+# commands \{ and \} for these it is advised to use the version @{ and @} or use
+# a double escape (\\{ and \\})
ALIASES =
-# This tag can be used to specify a number of word-keyword mappings (TCL only).
-# A mapping has the form "name=value". For example adding "class=itcl::class"
-# will allow you to use the command class in the itcl::class meaning.
-
-TCL_SUBST =
-
# Set the OPTIMIZE_OUTPUT_FOR_C tag to YES if your project consists of C sources
# only. Doxygen will then generate output that is more tailored for C. For
# instance, some of the names that are used will be different. The list of all
@@ -264,42 +289,63 @@ OPTIMIZE_FOR_FORTRAN = NO
OPTIMIZE_OUTPUT_VHDL = NO
+# Set the OPTIMIZE_OUTPUT_SLICE tag to YES if your project consists of Slice
+# sources only. Doxygen will then generate output that is more tailored for that
+# language. For instance, namespaces will be presented as modules, types will be
+# separated into more groups, etc.
+# The default value is: NO.
+
+OPTIMIZE_OUTPUT_SLICE = NO
+
# Doxygen selects the parser to use depending on the extension of the files it
# parses. With this tag you can assign which parser to use for a given
# extension. Doxygen has a built-in mapping, but you can override or extend it
# using this tag. The format is ext=language, where ext is a file extension, and
-# language is one of the parsers supported by doxygen: IDL, Java, Javascript,
-# C#, C, C++, D, PHP, Objective-C, Python, Fortran (fixed format Fortran:
-# FortranFixed, free formatted Fortran: FortranFree, unknown formatted Fortran:
-# Fortran. In the later case the parser tries to guess whether the code is fixed
-# or free formatted code, this is the default for Fortran type files), VHDL. For
-# instance to make doxygen treat .inc files as Fortran files (default is PHP),
-# and .f files as C (default is Fortran), use: inc=Fortran f=C.
+# language is one of the parsers supported by doxygen: IDL, Java, JavaScript,
+# Csharp (C#), C, C++, D, PHP, md (Markdown), Objective-C, Python, Slice, VHDL,
+# Fortran (fixed format Fortran: FortranFixed, free formatted Fortran:
+# FortranFree, unknown formatted Fortran: Fortran. In the later case the parser
+# tries to guess whether the code is fixed or free formatted code, this is the
+# default for Fortran type files). For instance to make doxygen treat .inc files
+# as Fortran files (default is PHP), and .f files as C (default is Fortran),
+# use: inc=Fortran f=C.
#
-# Note For files without extension you can use no_extension as a placeholder.
+# Note: For files without extension you can use no_extension as a placeholder.
#
# Note that for custom extensions you also need to set FILE_PATTERNS otherwise
-# the files are not read by doxygen.
+# the files are not read by doxygen. When specifying no_extension you should add
+# * to the FILE_PATTERNS.
+#
+# Note see also the list of default file extension mappings.
EXTENSION_MAPPING =
# If the MARKDOWN_SUPPORT tag is enabled then doxygen pre-processes all comments
# according to the Markdown format, which allows for more readable
-# documentation. See http://daringfireball.net/projects/markdown/ for details.
+# documentation. See https://daringfireball.net/projects/markdown/ for details.
# The output of markdown processing is further processed by doxygen, so you can
# mix doxygen, HTML, and XML commands with Markdown formatting. Disable only in
# case of backward compatibilities issues.
# The default value is: YES.
-#MARKDOWN_SUPPORT = YES
+MARKDOWN_SUPPORT = YES
+
+# When the TOC_INCLUDE_HEADINGS tag is set to a non-zero value, all headings up
+# to that level are automatically included in the table of contents, even if
+# they do not have an id attribute.
+# Note: This feature currently applies only to Markdown headings.
+# Minimum value: 0, maximum value: 99, default value: 5.
+# This tag requires that the tag MARKDOWN_SUPPORT is set to YES.
+
+TOC_INCLUDE_HEADINGS = 5
# When enabled doxygen tries to link words that correspond to documented
# classes, or namespaces to their corresponding documentation. Such a link can
-# be prevented in individual cases by by putting a % sign in front of the word
-# or globally by setting AUTOLINK_SUPPORT to NO.
+# be prevented in individual cases by putting a % sign in front of the word or
+# globally by setting AUTOLINK_SUPPORT to NO.
# The default value is: YES.
-#AUTOLINK_SUPPORT = YES
+AUTOLINK_SUPPORT = YES
# If you use STL classes (i.e. std::string, std::vector, etc.) but do not want
# to include (a tag file for) the STL sources as input, then you should set this
@@ -318,7 +364,7 @@ BUILTIN_STL_SUPPORT = NO
CPP_CLI_SUPPORT = NO
# Set the SIP_SUPPORT tag to YES if your project consists of sip (see:
-# http://www.riverbankcomputing.co.uk/software/sip/intro) sources only. Doxygen
+# https://www.riverbankcomputing.com/software/sip/intro) sources only. Doxygen
# will parse them like normal C++ but will assume all classes use public instead
# of private inheritance when no explicit protection keyword is present.
# The default value is: NO.
@@ -336,13 +382,20 @@ SIP_SUPPORT = NO
IDL_PROPERTY_SUPPORT = YES
# If member grouping is used in the documentation and the DISTRIBUTE_GROUP_DOC
-# tag is set to YES, then doxygen will reuse the documentation of the first
+# tag is set to YES then doxygen will reuse the documentation of the first
# member in the group (if any) for the other members of the group. By default
# all members of a group must be documented explicitly.
# The default value is: NO.
DISTRIBUTE_GROUP_DOC = NO
+# If one adds a struct or class to a group and this option is enabled, then also
+# any nested class or struct is added to the same group. By default this option
+# is disabled and one has to add nested compounds explicitly via \ingroup.
+# The default value is: NO.
+
+GROUP_NESTED_COMPOUNDS = NO
+
# Set the SUBGROUPING tag to YES to allow class member groups of the same type
# (for instance a group of public functions) to be put as a subgroup of that
# type (e.g. under the Public Functions section). Set it to NO to prevent
@@ -397,11 +450,24 @@ TYPEDEF_HIDES_STRUCT = NO
LOOKUP_CACHE_SIZE = 0
+# The NUM_PROC_THREADS specifies the number threads doxygen is allowed to use
+# during processing. When set to 0 doxygen will based this on the number of
+# cores available in the system. You can set it explicitly to a value larger
+# than 0 to get more control over the balance between CPU load and processing
+# speed. At this moment only the input processing can be done using multiple
+# threads. Since this is still an experimental feature the default is set to 1,
+# which efficively disables parallel processing. Please report any issues you
+# encounter. Generating dot graphs in parallel is controlled by the
+# DOT_NUM_THREADS setting.
+# Minimum value: 0, maximum value: 32, default value: 1.
+
+NUM_PROC_THREADS = 1
+
#---------------------------------------------------------------------------
# Build related configuration options
#---------------------------------------------------------------------------
-# If the EXTRACT_ALL tag is set to YES doxygen will assume all entities in
+# If the EXTRACT_ALL tag is set to YES, doxygen will assume all entities in
# documentation are documented, even if no documentation was available. Private
# class members and static file members will be hidden unless the
# EXTRACT_PRIVATE respectively EXTRACT_STATIC tags are set to YES.
@@ -411,35 +477,41 @@ LOOKUP_CACHE_SIZE = 0
EXTRACT_ALL = YES
-# If the EXTRACT_PRIVATE tag is set to YES all private members of a class will
+# If the EXTRACT_PRIVATE tag is set to YES, all private members of a class will
# be included in the documentation.
# The default value is: NO.
EXTRACT_PRIVATE = NO
-# If the EXTRACT_PACKAGE tag is set to YES all members with package or internal
+# If the EXTRACT_PRIV_VIRTUAL tag is set to YES, documented private virtual
+# methods of a class will be included in the documentation.
+# The default value is: NO.
+
+EXTRACT_PRIV_VIRTUAL = NO
+
+# If the EXTRACT_PACKAGE tag is set to YES, all members with package or internal
# scope will be included in the documentation.
# The default value is: NO.
-#EXTRACT_PACKAGE = NO
+EXTRACT_PACKAGE = NO
-# If the EXTRACT_STATIC tag is set to YES all static members of a file will be
+# If the EXTRACT_STATIC tag is set to YES, all static members of a file will be
# included in the documentation.
# The default value is: NO.
EXTRACT_STATIC = NO
-# If the EXTRACT_LOCAL_CLASSES tag is set to YES classes (and structs) defined
-# locally in source files will be included in the documentation. If set to NO
+# If the EXTRACT_LOCAL_CLASSES tag is set to YES, classes (and structs) defined
+# locally in source files will be included in the documentation. If set to NO,
# only classes defined in header files are included. Does not have any effect
# for Java sources.
# The default value is: YES.
EXTRACT_LOCAL_CLASSES = YES
-# This flag is only useful for Objective-C code. When set to YES local methods,
+# This flag is only useful for Objective-C code. If set to YES, local methods,
# which are defined in the implementation section but not in the interface are
-# included in the documentation. If set to NO only methods in the interface are
+# included in the documentation. If set to NO, only methods in the interface are
# included.
# The default value is: NO.
@@ -454,6 +526,13 @@ EXTRACT_LOCAL_METHODS = NO
EXTRACT_ANON_NSPACES = NO
+# If this flag is set to YES, the name of an unnamed parameter in a declaration
+# will be determined by the corresponding definition. By default unnamed
+# parameters remain unnamed in the output.
+# The default value is: YES.
+
+RESOLVE_UNNAMED_PARAMS = YES
+
# If the HIDE_UNDOC_MEMBERS tag is set to YES, doxygen will hide all
# undocumented members inside documented classes or files. If set to NO these
# members will be included in the various overviews, but no documentation
@@ -464,21 +543,21 @@ HIDE_UNDOC_MEMBERS = NO
# If the HIDE_UNDOC_CLASSES tag is set to YES, doxygen will hide all
# undocumented classes that are normally visible in the class hierarchy. If set
-# to NO these classes will be included in the various overviews. This option has
-# no effect if EXTRACT_ALL is enabled.
+# to NO, these classes will be included in the various overviews. This option
+# has no effect if EXTRACT_ALL is enabled.
# The default value is: NO.
HIDE_UNDOC_CLASSES = NO
# If the HIDE_FRIEND_COMPOUNDS tag is set to YES, doxygen will hide all friend
-# (class|struct|union) declarations. If set to NO these declarations will be
-# included in the documentation.
+# declarations. If set to NO, these declarations will be included in the
+# documentation.
# The default value is: NO.
HIDE_FRIEND_COMPOUNDS = NO
# If the HIDE_IN_BODY_DOCS tag is set to YES, doxygen will hide any
-# documentation blocks found inside the body of a function. If set to NO these
+# documentation blocks found inside the body of a function. If set to NO, these
# blocks will be appended to the function's detailed documentation block.
# The default value is: NO.
@@ -491,22 +570,36 @@ HIDE_IN_BODY_DOCS = NO
INTERNAL_DOCS = NO
-# If the CASE_SENSE_NAMES tag is set to NO then doxygen will only generate file
-# names in lower-case letters. If set to YES upper-case letters are also
-# allowed. This is useful if you have classes or files whose names only differ
-# in case and if your file system supports case sensitive file names. Windows
-# and Mac users are advised to set this option to NO.
+# With the correct setting of option CASE_SENSE_NAMES doxygen will better be
+# able to match the capabilities of the underlying filesystem. In case the
+# filesystem is case sensitive (i.e. it supports files in the same directory
+# whose names only differ in casing), the option must be set to YES to properly
+# deal with such files in case they appear in the input. For filesystems that
+# are not case sensitive the option should be be set to NO to properly deal with
+# output files written for symbols that only differ in casing, such as for two
+# classes, one named CLASS and the other named Class, and to also support
+# references to files without having to specify the exact matching casing. On
+# Windows (including Cygwin) and MacOS, users should typically set this option
+# to NO, whereas on Linux or other Unix flavors it should typically be set to
+# YES.
# The default value is: system dependent.
CASE_SENSE_NAMES = YES
# If the HIDE_SCOPE_NAMES tag is set to NO then doxygen will show members with
-# their full class and namespace scopes in the documentation. If set to YES the
+# their full class and namespace scopes in the documentation. If set to YES, the
# scope will be hidden.
# The default value is: NO.
HIDE_SCOPE_NAMES = NO
+# If the HIDE_COMPOUND_REFERENCE tag is set to NO (default) then doxygen will
+# append additional text to a page's title, such as Class Reference. If set to
+# YES the compound reference will be hidden.
+# The default value is: NO.
+
+HIDE_COMPOUND_REFERENCE= NO
+
# If the SHOW_INCLUDE_FILES tag is set to YES then doxygen will put a list of
# the files that are included by a file in the documentation of that file.
# The default value is: YES.
@@ -518,7 +611,7 @@ SHOW_INCLUDE_FILES = YES
# which file to include in order to use the member.
# The default value is: NO.
-#SHOW_GROUPED_MEMB_INC = NO
+SHOW_GROUPED_MEMB_INC = NO
# If the FORCE_LOCAL_INCLUDES tag is set to YES then doxygen will list include
# files with double quotes in the documentation rather than with sharp brackets.
@@ -534,14 +627,14 @@ INLINE_INFO = YES
# If the SORT_MEMBER_DOCS tag is set to YES then doxygen will sort the
# (detailed) documentation of file and class members alphabetically by member
-# name. If set to NO the members will appear in declaration order.
+# name. If set to NO, the members will appear in declaration order.
# The default value is: YES.
SORT_MEMBER_DOCS = YES
# If the SORT_BRIEF_DOCS tag is set to YES then doxygen will sort the brief
# descriptions of file, namespace and class members alphabetically by member
-# name. If set to NO the members will appear in declaration order. Note that
+# name. If set to NO, the members will appear in declaration order. Note that
# this will also influence the order of the classes in the class list.
# The default value is: NO.
@@ -586,27 +679,25 @@ SORT_BY_SCOPE_NAME = NO
STRICT_PROTO_MATCHING = NO
-# The GENERATE_TODOLIST tag can be used to enable ( YES) or disable ( NO) the
-# todo list. This list is created by putting \todo commands in the
-# documentation.
+# The GENERATE_TODOLIST tag can be used to enable (YES) or disable (NO) the todo
+# list. This list is created by putting \todo commands in the documentation.
# The default value is: YES.
GENERATE_TODOLIST = YES
-# The GENERATE_TESTLIST tag can be used to enable ( YES) or disable ( NO) the
-# test list. This list is created by putting \test commands in the
-# documentation.
+# The GENERATE_TESTLIST tag can be used to enable (YES) or disable (NO) the test
+# list. This list is created by putting \test commands in the documentation.
# The default value is: YES.
GENERATE_TESTLIST = YES
-# The GENERATE_BUGLIST tag can be used to enable ( YES) or disable ( NO) the bug
+# The GENERATE_BUGLIST tag can be used to enable (YES) or disable (NO) the bug
# list. This list is created by putting \bug commands in the documentation.
# The default value is: YES.
GENERATE_BUGLIST = YES
-# The GENERATE_DEPRECATEDLIST tag can be used to enable ( YES) or disable ( NO)
+# The GENERATE_DEPRECATEDLIST tag can be used to enable (YES) or disable (NO)
# the deprecated list. This list is created by putting \deprecated commands in
# the documentation.
# The default value is: YES.
@@ -631,8 +722,8 @@ ENABLED_SECTIONS =
MAX_INITIALIZER_LINES = 30
# Set the SHOW_USED_FILES tag to NO to disable the list of files generated at
-# the bottom of the documentation of classes and structs. If set to YES the list
-# will mention the files that were used to generate the documentation.
+# the bottom of the documentation of classes and structs. If set to YES, the
+# list will mention the files that were used to generate the documentation.
# The default value is: YES.
SHOW_USED_FILES = YES
@@ -677,7 +768,7 @@ LAYOUT_FILE =
# The CITE_BIB_FILES tag can be used to specify one or more bib files containing
# the reference definitions. This must be a list of .bib files. The .bib
# extension is automatically appended if omitted. This requires the bibtex tool
-# to be installed. See also http://en.wikipedia.org/wiki/BibTeX for more info.
+# to be installed. See also https://en.wikipedia.org/wiki/BibTeX for more info.
# For LaTeX the style of the bibliography can be controlled using
# LATEX_BIB_STYLE. To use this feature you need bibtex and perl available in the
# search path. See also \cite for info how to create references.
@@ -696,7 +787,7 @@ CITE_BIB_FILES =
QUIET = NO
# The WARNINGS tag can be used to turn on/off the warning messages that are
-# generated to standard error ( stderr) by doxygen. If WARNINGS is set to YES
+# generated to standard error (stderr) by doxygen. If WARNINGS is set to YES
# this implies that the warnings are on.
#
# Tip: Turn warnings on while writing the documentation.
@@ -704,7 +795,7 @@ QUIET = NO
WARNINGS = YES
-# If the WARN_IF_UNDOCUMENTED tag is set to YES, then doxygen will generate
+# If the WARN_IF_UNDOCUMENTED tag is set to YES then doxygen will generate
# warnings for undocumented members. If EXTRACT_ALL is set to YES then this flag
# will automatically be disabled.
# The default value is: YES.
@@ -721,12 +812,22 @@ WARN_IF_DOC_ERROR = YES
# This WARN_NO_PARAMDOC option can be enabled to get warnings for functions that
# are documented, but have no documentation for their parameters or return
-# value. If set to NO doxygen will only warn about wrong or incomplete parameter
-# documentation, but not about the absence of documentation.
+# value. If set to NO, doxygen will only warn about wrong or incomplete
+# parameter documentation, but not about the absence of documentation. If
+# EXTRACT_ALL is set to YES then this flag will automatically be disabled.
# The default value is: NO.
WARN_NO_PARAMDOC = YES
+# If the WARN_AS_ERROR tag is set to YES then doxygen will immediately stop when
+# a warning is encountered. If the WARN_AS_ERROR tag is set to FAIL_ON_WARNINGS
+# then doxygen will continue running as if WARN_AS_ERROR tag is set to NO, but
+# at the end of the doxygen process doxygen will return with a non-zero status.
+# Possible values are: NO, YES and FAIL_ON_WARNINGS.
+# The default value is: NO.
+
+WARN_AS_ERROR = NO
+
# The WARN_FORMAT tag determines the format of the warning messages that doxygen
# can produce. The string should contain the $file, $line, and $text tags, which
# will be replaced by the file and line number from which the warning originated
@@ -750,7 +851,7 @@ WARN_LOGFILE =
# The INPUT tag is used to specify the files and/or directories that contain
# documented source files. You may enter file names like myfile.cpp or
# directories like /usr/src/myproject. Separate the files or directories with
-# spaces.
+# spaces. See also FILE_PATTERNS and EXTENSION_MAPPING
# Note: If this tag is empty the current directory is searched.
INPUT = @PROJECT_SOURCE_DIR@/include
@@ -758,20 +859,29 @@ INPUT = @PROJECT_SOURCE_DIR@/include
# This tag can be used to specify the character encoding of the source files
# that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses
# libiconv (or the iconv built into libc) for the transcoding. See the libiconv
-# documentation (see: http://www.gnu.org/software/libiconv) for the list of
-# possible encodings.
+# documentation (see:
+# https://www.gnu.org/software/libiconv/) for the list of possible encodings.
# The default value is: UTF-8.
INPUT_ENCODING = UTF-8
# If the value of the INPUT tag contains directories, you can use the
# FILE_PATTERNS tag to specify one or more wildcard patterns (like *.cpp and
-# *.h) to filter out the source-files in the directories. If left blank the
-# following patterns are tested:*.c, *.cc, *.cxx, *.cpp, *.c++, *.java, *.ii,
-# *.ixx, *.ipp, *.i++, *.inl, *.idl, *.ddl, *.odl, *.h, *.hh, *.hxx, *.hpp,
-# *.h++, *.cs, *.d, *.php, *.php4, *.php5, *.phtml, *.inc, *.m, *.markdown,
-# *.md, *.mm, *.dox, *.py, *.f90, *.f, *.for, *.tcl, *.vhd, *.vhdl, *.ucf,
-# *.qsf, *.as and *.js.
+# *.h) to filter out the source-files in the directories.
+#
+# Note that for custom extensions or not directly supported extensions you also
+# need to set EXTENSION_MAPPING for the extension otherwise the files are not
+# read by doxygen.
+#
+# Note the list of default checked file patterns might differ from the list of
+# default file extension mappings.
+#
+# If left blank the following patterns are tested:*.c, *.cc, *.cxx, *.cpp,
+# *.c++, *.java, *.ii, *.ixx, *.ipp, *.i++, *.inl, *.idl, *.ddl, *.odl, *.h,
+# *.hh, *.hxx, *.hpp, *.h++, *.cs, *.d, *.php, *.php4, *.php5, *.phtml, *.inc,
+# *.m, *.markdown, *.md, *.mm, *.dox (to be provided as doxygen C comment),
+# *.py, *.pyw, *.f90, *.f95, *.f03, *.f08, *.f18, *.f, *.for, *.vhd, *.vhdl,
+# *.ucf, *.qsf and *.ice.
FILE_PATTERNS = *.h
@@ -858,6 +968,10 @@ IMAGE_PATH =
# Note that the filter must not add or remove lines; it is applied before the
# code is scanned, but not when the output code is generated. If lines are added
# or removed, the anchors will not be placed correctly.
+#
+# Note that for custom extensions or not directly supported extensions you also
+# need to set EXTENSION_MAPPING for the extension otherwise the files are not
+# properly processed by doxygen.
INPUT_FILTER =
@@ -867,11 +981,15 @@ INPUT_FILTER =
# (like *.cpp=my_cpp_filter). See INPUT_FILTER for further information on how
# filters are used. If the FILTER_PATTERNS tag is empty or if none of the
# patterns match the file name, INPUT_FILTER is applied.
+#
+# Note that for custom extensions or not directly supported extensions you also
+# need to set EXTENSION_MAPPING for the extension otherwise the files are not
+# properly processed by doxygen.
FILTER_PATTERNS =
# If the FILTER_SOURCE_FILES tag is set to YES, the input filter (if set using
-# INPUT_FILTER ) will also be used to filter the input files that are used for
+# INPUT_FILTER) will also be used to filter the input files that are used for
# producing the source files to browse (i.e. when SOURCE_BROWSER is set to YES).
# The default value is: NO.
@@ -890,7 +1008,7 @@ FILTER_SOURCE_PATTERNS =
# (index.html). This can be useful if you have a project on for instance GitHub
# and want to reuse the introduction page also for the doxygen output.
-#USE_MDFILE_AS_MAINPAGE =
+USE_MDFILE_AS_MAINPAGE =
#---------------------------------------------------------------------------
# Configuration options related to source browsing
@@ -919,7 +1037,7 @@ INLINE_SOURCES = NO
STRIP_CODE_COMMENTS = YES
# If the REFERENCED_BY_RELATION tag is set to YES then for each documented
-# function all documented functions referencing it will be listed.
+# entity all documented functions referencing it will be listed.
# The default value is: NO.
REFERENCED_BY_RELATION = NO
@@ -931,7 +1049,7 @@ REFERENCED_BY_RELATION = NO
REFERENCES_RELATION = NO
# If the REFERENCES_LINK_SOURCE tag is set to YES and SOURCE_BROWSER tag is set
-# to YES, then the hyperlinks from functions in REFERENCES_RELATION and
+# to YES then the hyperlinks from functions in REFERENCES_RELATION and
# REFERENCED_BY_RELATION lists will link to the source code. Otherwise they will
# link to the documentation.
# The default value is: YES.
@@ -946,17 +1064,17 @@ REFERENCES_LINK_SOURCE = YES
# The default value is: YES.
# This tag requires that the tag SOURCE_BROWSER is set to YES.
-#SOURCE_TOOLTIPS = YES
+SOURCE_TOOLTIPS = YES
# If the USE_HTAGS tag is set to YES then the references to source code will
# point to the HTML generated by the htags(1) tool instead of doxygen built-in
# source browser. The htags tool is part of GNU's global source tagging system
-# (see http://www.gnu.org/software/global/global.html). You will need version
+# (see https://www.gnu.org/software/global/global.html). You will need version
# 4.8.6 or higher.
#
# To use it do the following:
# - Install the latest version of global
-# - Enable SOURCE_BROWSER and USE_HTAGS in the config file
+# - Enable SOURCE_BROWSER and USE_HTAGS in the configuration file
# - Make sure the INPUT points to the root of the source tree
# - Run doxygen as normal
#
@@ -978,16 +1096,22 @@ USE_HTAGS = NO
VERBATIM_HEADERS = YES
-# If the CLANG_ASSISTED_PARSING tag is set to YES, then doxygen will use the
-# clang parser (see: http://clang.llvm.org/) for more accurate parsing at the
-# cost of reduced performance. This can be particularly helpful with template
-# rich C++ code for which doxygen's built-in parser lacks the necessary type
-# information.
+# If the CLANG_ASSISTED_PARSING tag is set to YES then doxygen will use the
+# clang parser (see:
+# http://clang.llvm.org/) for more accurate parsing at the cost of reduced
+# performance. This can be particularly helpful with template rich C++ code for
+# which doxygen's built-in parser lacks the necessary type information.
# Note: The availability of this option depends on whether or not doxygen was
-# compiled with the --with-libclang option.
+# generated with the -Duse_libclang=ON option for CMake.
# The default value is: NO.
-#CLANG_ASSISTED_PARSING = NO
+CLANG_ASSISTED_PARSING = NO
+
+# If clang assisted parsing is enabled and the CLANG_ADD_INC_PATHS tag is set to
+# YES then doxygen will add the directory of each input to the include path.
+# The default value is: YES.
+
+CLANG_ADD_INC_PATHS = YES
# If clang assisted parsing is enabled you can provide the compiler with command
# line options that you would normally use when invoking the compiler. Note that
@@ -995,7 +1119,20 @@ VERBATIM_HEADERS = YES
# specified with INPUT and INCLUDE_PATH.
# This tag requires that the tag CLANG_ASSISTED_PARSING is set to YES.
-#CLANG_OPTIONS =
+CLANG_OPTIONS =
+
+# If clang assisted parsing is enabled you can provide the clang parser with the
+# path to the directory containing a file called compile_commands.json. This
+# file is the compilation database (see:
+# http://clang.llvm.org/docs/HowToSetupToolingForLLVM.html) containing the
+# options used when the source files were built. This is equivalent to
+# specifying the -p option to a clang tool, such as clang-check. These options
+# will then be passed to the parser. Any options specified with CLANG_OPTIONS
+# will be added as well.
+# Note: The availability of this option depends on whether or not doxygen was
+# generated with the -Duse_libclang=ON option for CMake.
+
+CLANG_DATABASE_PATH =
#---------------------------------------------------------------------------
# Configuration options related to the alphabetical class index
@@ -1008,13 +1145,6 @@ VERBATIM_HEADERS = YES
ALPHABETICAL_INDEX = YES
-# The COLS_IN_ALPHA_INDEX tag can be used to specify the number of columns in
-# which the alphabetical index list will be split.
-# Minimum value: 1, maximum value: 20, default value: 5.
-# This tag requires that the tag ALPHABETICAL_INDEX is set to YES.
-
-COLS_IN_ALPHA_INDEX = 5
-
# In case all classes in a project start with a common prefix, all classes will
# be put under the same header in the alphabetical index. The IGNORE_PREFIX tag
# can be used to specify a prefix (or a list of prefixes) that should be ignored
@@ -1027,7 +1157,7 @@ IGNORE_PREFIX =
# Configuration options related to the HTML output
#---------------------------------------------------------------------------
-# If the GENERATE_HTML tag is set to YES doxygen will generate HTML output
+# If the GENERATE_HTML tag is set to YES, doxygen will generate HTML output
# The default value is: YES.
GENERATE_HTML = YES
@@ -1093,14 +1223,14 @@ HTML_STYLESHEET =
# cascading style sheets that are included after the standard style sheets
# created by doxygen. Using this option one can overrule certain style aspects.
# This is preferred over using HTML_STYLESHEET since it does not replace the
-# standard style sheet and is therefor more robust against future updates.
+# standard style sheet and is therefore more robust against future updates.
# Doxygen will copy the style sheet files to the output directory.
-# Note: The order of the extra stylesheet files is of importance (e.g. the last
-# stylesheet in the list overrules the setting of the previous ones in the
+# Note: The order of the extra style sheet files is of importance (e.g. the last
+# style sheet in the list overrules the setting of the previous ones in the
# list). For an example see the documentation.
# This tag requires that the tag GENERATE_HTML is set to YES.
-#HTML_EXTRA_STYLESHEET =
+HTML_EXTRA_STYLESHEET =
# The HTML_EXTRA_FILES tag can be used to specify one or more extra images or
# other source files which should be copied to the HTML output directory. Note
@@ -1113,9 +1243,9 @@ HTML_STYLESHEET =
HTML_EXTRA_FILES =
# The HTML_COLORSTYLE_HUE tag controls the color of the HTML output. Doxygen
-# will adjust the colors in the stylesheet and background images according to
+# will adjust the colors in the style sheet and background images according to
# this color. Hue is specified as an angle on a colorwheel, see
-# http://en.wikipedia.org/wiki/Hue for more information. For instance the value
+# https://en.wikipedia.org/wiki/Hue for more information. For instance the value
# 0 represents red, 60 is yellow, 120 is green, 180 is cyan, 240 is blue, 300
# purple, and 360 is red again.
# Minimum value: 0, maximum value: 359, default value: 220.
@@ -1144,12 +1274,24 @@ HTML_COLORSTYLE_GAMMA = 80
# If the HTML_TIMESTAMP tag is set to YES then the footer of each generated HTML
# page will contain the date and time when the page was generated. Setting this
-# to NO can help when comparing the output of multiple runs.
-# The default value is: YES.
+# to YES can help to show when doxygen was last run and thus if the
+# documentation is up to date.
+# The default value is: NO.
# This tag requires that the tag GENERATE_HTML is set to YES.
HTML_TIMESTAMP = YES
+# If the HTML_DYNAMIC_MENUS tag is set to YES then the generated HTML
+# documentation will contain a main index with vertical navigation menus that
+# are dynamically created via JavaScript. If disabled, the navigation index will
+# consists of multiple levels of tabs that are statically embedded in every HTML
+# page. Disable this option to support browsers that do not have JavaScript,
+# like the Qt help browser.
+# The default value is: YES.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_DYNAMIC_MENUS = YES
+
# If the HTML_DYNAMIC_SECTIONS tag is set to YES then the generated HTML
# documentation will contain sections that can be hidden and shown after the
# page has loaded.
@@ -1169,17 +1311,18 @@ HTML_DYNAMIC_SECTIONS = NO
# Minimum value: 0, maximum value: 9999, default value: 100.
# This tag requires that the tag GENERATE_HTML is set to YES.
-#HTML_INDEX_NUM_ENTRIES = 100
+HTML_INDEX_NUM_ENTRIES = 100
# If the GENERATE_DOCSET tag is set to YES, additional index files will be
# generated that can be used as input for Apple's Xcode 3 integrated development
-# environment (see: http://developer.apple.com/tools/xcode/), introduced with
-# OSX 10.5 (Leopard). To create a documentation set, doxygen will generate a
-# Makefile in the HTML output directory. Running make will produce the docset in
-# that directory and running make install will install the docset in
+# environment (see:
+# https://developer.apple.com/xcode/), introduced with OSX 10.5 (Leopard). To
+# create a documentation set, doxygen will generate a Makefile in the HTML
+# output directory. Running make will produce the docset in that directory and
+# running make install will install the docset in
# ~/Library/Developer/Shared/Documentation/DocSets so that Xcode will find it at
-# startup. See http://developer.apple.com/tools/creatingdocsetswithdoxygen.html
-# for more information.
+# startup. See https://developer.apple.com/library/archive/featuredarticles/Doxy
+# genXcode/_index.html for more information.
# The default value is: NO.
# This tag requires that the tag GENERATE_HTML is set to YES.
@@ -1218,8 +1361,8 @@ DOCSET_PUBLISHER_NAME = Publisher
# If the GENERATE_HTMLHELP tag is set to YES then doxygen generates three
# additional HTML index files: index.hhp, index.hhc, and index.hhk. The
# index.hhp is a project file that can be read by Microsoft's HTML Help Workshop
-# (see: http://www.microsoft.com/en-us/download/details.aspx?id=21138) on
-# Windows.
+# (see:
+# https://www.microsoft.com/en-us/download/details.aspx?id=21138) on Windows.
#
# The HTML Help Workshop contains a compiler that can convert all HTML output
# generated by doxygen into a single compiled HTML file (.chm). Compiled HTML
@@ -1241,28 +1384,28 @@ GENERATE_HTMLHELP = NO
CHM_FILE =
# The HHC_LOCATION tag can be used to specify the location (absolute path
-# including file name) of the HTML help compiler ( hhc.exe). If non-empty
+# including file name) of the HTML help compiler (hhc.exe). If non-empty,
# doxygen will try to run the HTML help compiler on the generated index.hhp.
# The file has to be specified with full path.
# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
HHC_LOCATION =
-# The GENERATE_CHI flag controls if a separate .chi index file is generated (
-# YES) or that it should be included in the master .chm file ( NO).
+# The GENERATE_CHI flag controls if a separate .chi index file is generated
+# (YES) or that it should be included in the main .chm file (NO).
# The default value is: NO.
# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
GENERATE_CHI = NO
-# The CHM_INDEX_ENCODING is used to encode HtmlHelp index ( hhk), content ( hhc)
+# The CHM_INDEX_ENCODING is used to encode HtmlHelp index (hhk), content (hhc)
# and project file content.
# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
CHM_INDEX_ENCODING =
-# The BINARY_TOC flag controls whether a binary table of contents is generated (
-# YES) or a normal table of contents ( NO) in the .chm file. Furthermore it
+# The BINARY_TOC flag controls whether a binary table of contents is generated
+# (YES) or a normal table of contents (NO) in the .chm file. Furthermore it
# enables the Previous and Next buttons.
# The default value is: NO.
# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
@@ -1294,7 +1437,8 @@ QCH_FILE =
# The QHP_NAMESPACE tag specifies the namespace to use when generating Qt Help
# Project output. For more information please see Qt Help Project / Namespace
-# (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#namespace).
+# (see:
+# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#namespace).
# The default value is: org.doxygen.Project.
# This tag requires that the tag GENERATE_QHP is set to YES.
@@ -1302,8 +1446,8 @@ QHP_NAMESPACE = org.doxygen.Project
# The QHP_VIRTUAL_FOLDER tag specifies the namespace to use when generating Qt
# Help Project output. For more information please see Qt Help Project / Virtual
-# Folders (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#virtual-
-# folders).
+# Folders (see:
+# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#virtual-folders).
# The default value is: doc.
# This tag requires that the tag GENERATE_QHP is set to YES.
@@ -1311,30 +1455,30 @@ QHP_VIRTUAL_FOLDER = doc
# If the QHP_CUST_FILTER_NAME tag is set, it specifies the name of a custom
# filter to add. For more information please see Qt Help Project / Custom
-# Filters (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#custom-
-# filters).
+# Filters (see:
+# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#custom-filters).
# This tag requires that the tag GENERATE_QHP is set to YES.
QHP_CUST_FILTER_NAME =
# The QHP_CUST_FILTER_ATTRS tag specifies the list of the attributes of the
# custom filter to add. For more information please see Qt Help Project / Custom
-# Filters (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#custom-
-# filters).
+# Filters (see:
+# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#custom-filters).
# This tag requires that the tag GENERATE_QHP is set to YES.
QHP_CUST_FILTER_ATTRS =
# The QHP_SECT_FILTER_ATTRS tag specifies the list of the attributes this
# project's filter section matches. Qt Help Project / Filter Attributes (see:
-# http://qt-project.org/doc/qt-4.8/qthelpproject.html#filter-attributes).
+# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#filter-attributes).
# This tag requires that the tag GENERATE_QHP is set to YES.
QHP_SECT_FILTER_ATTRS =
-# The QHG_LOCATION tag can be used to specify the location of Qt's
-# qhelpgenerator. If non-empty doxygen will try to run qhelpgenerator on the
-# generated .qhp file.
+# The QHG_LOCATION tag can be used to specify the location (absolute path
+# including file name) of Qt's qhelpgenerator. If non-empty doxygen will try to
+# run qhelpgenerator on the generated .qhp file.
# This tag requires that the tag GENERATE_QHP is set to YES.
QHG_LOCATION =
@@ -1376,7 +1520,7 @@ DISABLE_INDEX = NO
# index structure (just like the one that is generated for HTML Help). For this
# to work a browser that supports JavaScript, DHTML, CSS and frames is required
# (i.e. any modern browser). Windows users are probably better off using the
-# HTML help feature. Via custom stylesheets (see HTML_EXTRA_STYLESHEET) one can
+# HTML help feature. Via custom style sheets (see HTML_EXTRA_STYLESHEET) one can
# further fine-tune the look of the index. As an example, the default style
# sheet generated by doxygen has an example that shows how to put an image at
# the root of the tree instead of the PROJECT_NAME. Since the tree basically has
@@ -1404,13 +1548,24 @@ ENUM_VALUES_PER_LINE = 4
TREEVIEW_WIDTH = 250
-# When the EXT_LINKS_IN_WINDOW option is set to YES doxygen will open links to
+# If the EXT_LINKS_IN_WINDOW option is set to YES, doxygen will open links to
# external symbols imported via tag files in a separate window.
# The default value is: NO.
# This tag requires that the tag GENERATE_HTML is set to YES.
EXT_LINKS_IN_WINDOW = NO
+# If the HTML_FORMULA_FORMAT option is set to svg, doxygen will use the pdf2svg
+# tool (see https://github.com/dawbarton/pdf2svg) or inkscape (see
+# https://inkscape.org) to generate formulas as SVG images instead of PNGs for
+# the HTML output. These images will generally look nicer at scaled resolutions.
+# Possible values are: png (the default) and svg (looks nicer but requires the
+# pdf2svg or inkscape tool).
+# The default value is: png.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_FORMULA_FORMAT = png
+
# Use this tag to change the font size of LaTeX formulas included as images in
# the HTML documentation. When you change the font size after a successful
# doxygen run you need to manually remove any form_*.png images from the HTML
@@ -1420,7 +1575,7 @@ EXT_LINKS_IN_WINDOW = NO
FORMULA_FONTSIZE = 10
-# Use the FORMULA_TRANPARENT tag to determine whether or not the images
+# Use the FORMULA_TRANSPARENT tag to determine whether or not the images
# generated for formulas are transparent PNGs. Transparent PNGs are not
# supported properly for IE 6.0, but are supported on all modern browsers.
#
@@ -1431,9 +1586,15 @@ FORMULA_FONTSIZE = 10
FORMULA_TRANSPARENT = YES
+# The FORMULA_MACROFILE can contain LaTeX \newcommand and \renewcommand commands
+# to create new LaTeX commands to be used in formulas as building blocks. See
+# the section "Including formulas" for details.
+
+FORMULA_MACROFILE =
+
# Enable the USE_MATHJAX option to render LaTeX formulas using MathJax (see
-# http://www.mathjax.org) which uses client side Javascript for the rendering
-# instead of using prerendered bitmaps. Use this if you do not have LaTeX
+# https://www.mathjax.org) which uses client side JavaScript for the rendering
+# instead of using pre-rendered bitmaps. Use this if you do not have LaTeX
# installed or if you want to formulas look prettier in the HTML output. When
# enabled you may also need to install MathJax separately and configure the path
# to it using the MATHJAX_RELPATH option.
@@ -1444,13 +1605,13 @@ USE_MATHJAX = NO
# When MathJax is enabled you can set the default output format to be used for
# the MathJax output. See the MathJax site (see:
-# http://docs.mathjax.org/en/latest/output.html) for more details.
+# http://docs.mathjax.org/en/v2.7-latest/output.html) for more details.
# Possible values are: HTML-CSS (which is slower, but has the best
# compatibility), NativeMML (i.e. MathML) and SVG.
# The default value is: HTML-CSS.
# This tag requires that the tag USE_MATHJAX is set to YES.
-#MATHJAX_FORMAT = HTML-CSS
+MATHJAX_FORMAT = HTML-CSS
# When MathJax is enabled you need to specify the location relative to the HTML
# output directory using the MATHJAX_RELPATH option. The destination directory
@@ -1459,8 +1620,8 @@ USE_MATHJAX = NO
# MATHJAX_RELPATH should be ../mathjax. The default value points to the MathJax
# Content Delivery Network so you can quickly see the result without installing
# MathJax. However, it is strongly recommended to install a local copy of
-# MathJax from http://www.mathjax.org before deployment.
-# The default value is: http://cdn.mathjax.org/mathjax/latest.
+# MathJax from https://www.mathjax.org before deployment.
+# The default value is: https://cdn.jsdelivr.net/npm/mathjax@2.
# This tag requires that the tag USE_MATHJAX is set to YES.
MATHJAX_RELPATH = http://www.mathjax.org/mathjax
@@ -1474,11 +1635,12 @@ MATHJAX_EXTENSIONS =
# The MATHJAX_CODEFILE tag can be used to specify a file with javascript pieces
# of code that will be used on startup of the MathJax code. See the MathJax site
-# (see: http://docs.mathjax.org/en/latest/output.html) for more details. For an
+# (see:
+# http://docs.mathjax.org/en/v2.7-latest/output.html) for more details. For an
# example see the documentation.
# This tag requires that the tag USE_MATHJAX is set to YES.
-#MATHJAX_CODEFILE =
+MATHJAX_CODEFILE =
# When the SEARCHENGINE tag is enabled doxygen will generate a search box for
# the HTML output. The underlying search engine uses javascript and DHTML and
@@ -1502,7 +1664,7 @@ MATHJAX_EXTENSIONS =
SEARCHENGINE = YES
# When the SERVER_BASED_SEARCH tag is enabled the search engine will be
-# implemented using a web server instead of a web client using Javascript. There
+# implemented using a web server instead of a web client using JavaScript. There
# are two flavors of web server based searching depending on the EXTERNAL_SEARCH
# setting. When disabled, doxygen will generate a PHP script for searching and
# an index file used by the script. When EXTERNAL_SEARCH is enabled the indexing
@@ -1519,26 +1681,28 @@ SERVER_BASED_SEARCH = NO
# external search engine pointed to by the SEARCHENGINE_URL option to obtain the
# search results.
#
-# Doxygen ships with an example indexer ( doxyindexer) and search engine
+# Doxygen ships with an example indexer (doxyindexer) and search engine
# (doxysearch.cgi) which are based on the open source search engine library
-# Xapian (see: http://xapian.org/).
+# Xapian (see:
+# https://xapian.org/).
#
# See the section "External Indexing and Searching" for details.
# The default value is: NO.
# This tag requires that the tag SEARCHENGINE is set to YES.
-#EXTERNAL_SEARCH = NO
+EXTERNAL_SEARCH = NO
# The SEARCHENGINE_URL should point to a search engine hosted by a web server
# which will return the search results when EXTERNAL_SEARCH is enabled.
#
-# Doxygen ships with an example indexer ( doxyindexer) and search engine
+# Doxygen ships with an example indexer (doxyindexer) and search engine
# (doxysearch.cgi) which are based on the open source search engine library
-# Xapian (see: http://xapian.org/). See the section "External Indexing and
-# Searching" for details.
+# Xapian (see:
+# https://xapian.org/). See the section "External Indexing and Searching" for
+# details.
# This tag requires that the tag SEARCHENGINE is set to YES.
-#SEARCHENGINE_URL =
+SEARCHENGINE_URL =
# When SERVER_BASED_SEARCH and EXTERNAL_SEARCH are both enabled the unindexed
# search data is written to a file for indexing by an external tool. With the
@@ -1546,7 +1710,7 @@ SERVER_BASED_SEARCH = NO
# The default file is: searchdata.xml.
# This tag requires that the tag SEARCHENGINE is set to YES.
-#SEARCHDATA_FILE = searchdata.xml
+SEARCHDATA_FILE = searchdata.xml
# When SERVER_BASED_SEARCH and EXTERNAL_SEARCH are both enabled the
# EXTERNAL_SEARCH_ID tag can be used as an identifier for the project. This is
@@ -1554,7 +1718,7 @@ SERVER_BASED_SEARCH = NO
# projects and redirect the results back to the right project.
# This tag requires that the tag SEARCHENGINE is set to YES.
-#EXTERNAL_SEARCH_ID =
+EXTERNAL_SEARCH_ID =
# The EXTRA_SEARCH_MAPPINGS tag can be used to enable searching through doxygen
# projects other than the one defined by this configuration file, but that are
@@ -1564,13 +1728,13 @@ SERVER_BASED_SEARCH = NO
# EXTRA_SEARCH_MAPPINGS = tagname1=loc1 tagname2=loc2 ...
# This tag requires that the tag SEARCHENGINE is set to YES.
-#EXTRA_SEARCH_MAPPINGS =
+EXTRA_SEARCH_MAPPINGS =
#---------------------------------------------------------------------------
# Configuration options related to the LaTeX output
#---------------------------------------------------------------------------
-# If the GENERATE_LATEX tag is set to YES doxygen will generate LaTeX output.
+# If the GENERATE_LATEX tag is set to YES, doxygen will generate LaTeX output.
# The default value is: YES.
GENERATE_LATEX = YES
@@ -1586,22 +1750,36 @@ LATEX_OUTPUT = latex
# The LATEX_CMD_NAME tag can be used to specify the LaTeX command name to be
# invoked.
#
-# Note that when enabling USE_PDFLATEX this option is only used for generating
-# bitmaps for formulas in the HTML output, but not in the Makefile that is
-# written to the output directory.
-# The default file is: latex.
+# Note that when not enabling USE_PDFLATEX the default is latex; when enabling
+# USE_PDFLATEX the default is pdflatex, and when in the latter case latex is
+# chosen this is overwritten by pdflatex. For specific output languages the
+# default can have been set differently; this depends on the implementation of
+# the output language.
# This tag requires that the tag GENERATE_LATEX is set to YES.
LATEX_CMD_NAME = latex
# The MAKEINDEX_CMD_NAME tag can be used to specify the command name to generate
# index for LaTeX.
+# Note: This tag is used in the Makefile / make.bat.
+# See also: LATEX_MAKEINDEX_CMD for the part in the generated output file
+# (.tex).
# The default file is: makeindex.
# This tag requires that the tag GENERATE_LATEX is set to YES.
MAKEINDEX_CMD_NAME = makeindex
-# If the COMPACT_LATEX tag is set to YES doxygen generates more compact LaTeX
+# The LATEX_MAKEINDEX_CMD tag can be used to specify the command name to
+# generate index for LaTeX. In case there is no backslash (\) as first character
+# it will be automatically added in the LaTeX code.
+# Note: This tag is used in the generated output file (.tex).
+# See also: MAKEINDEX_CMD_NAME for the part in the Makefile / make.bat.
+# The default value is: makeindex.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_MAKEINDEX_CMD = makeindex
+
+# If the COMPACT_LATEX tag is set to YES, doxygen generates more compact LaTeX
# documents. This may be useful for small projects and may help to save some
# trees in general.
# The default value is: NO.
@@ -1619,9 +1797,12 @@ COMPACT_LATEX = NO
PAPER_TYPE = a4
# The EXTRA_PACKAGES tag can be used to specify one or more LaTeX package names
-# that should be included in the LaTeX output. To get the times font for
-# instance you can specify
-# EXTRA_PACKAGES=times
+# that should be included in the LaTeX output. The package can be specified just
+# by its name or with the correct syntax as to be used with the LaTeX
+# \usepackage command. To get the times font for instance you can specify:
+# EXTRA_PACKAGES=times or EXTRA_PACKAGES={times}
+# To use the option intlimits with the amsmath package you can specify:
+# EXTRA_PACKAGES=[intlimits]{amsmath}
# If left blank no extra packages will be included.
# This tag requires that the tag GENERATE_LATEX is set to YES.
@@ -1636,9 +1817,9 @@ EXTRA_PACKAGES =
# Note: Only use a user-defined header if you know what you are doing! The
# following commands have a special meaning inside the header: $title,
# $datetime, $date, $doxygenversion, $projectname, $projectnumber,
-# $projectbrief, $projectlogo. Doxygen will replace $title with the empy string,
-# for the replacement values of the other commands the user is refered to
-# HTML_HEADER.
+# $projectbrief, $projectlogo. Doxygen will replace $title with the empty
+# string, for the replacement values of the other commands the user is referred
+# to HTML_HEADER.
# This tag requires that the tag GENERATE_LATEX is set to YES.
LATEX_HEADER =
@@ -1654,13 +1835,24 @@ LATEX_HEADER =
LATEX_FOOTER =
+# The LATEX_EXTRA_STYLESHEET tag can be used to specify additional user-defined
+# LaTeX style sheets that are included after the standard style sheets created
+# by doxygen. Using this option one can overrule certain style aspects. Doxygen
+# will copy the style sheet files to the output directory.
+# Note: The order of the extra style sheet files is of importance (e.g. the last
+# style sheet in the list overrules the setting of the previous ones in the
+# list).
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_EXTRA_STYLESHEET =
+
# The LATEX_EXTRA_FILES tag can be used to specify one or more extra images or
# other source files which should be copied to the LATEX_OUTPUT output
# directory. Note that the files will be copied as-is; there are no commands or
# markers available.
# This tag requires that the tag GENERATE_LATEX is set to YES.
-#LATEX_EXTRA_FILES =
+LATEX_EXTRA_FILES =
# If the PDF_HYPERLINKS tag is set to YES, the LaTeX that is generated is
# prepared for conversion to PDF (using ps2pdf or pdflatex). The PDF file will
@@ -1671,9 +1863,11 @@ LATEX_FOOTER =
PDF_HYPERLINKS = YES
-# If the USE_PDFLATEX tag is set to YES, doxygen will use pdflatex to generate
-# the PDF file directly from the LaTeX files. Set this option to YES to get a
-# higher quality PDF documentation.
+# If the USE_PDFLATEX tag is set to YES, doxygen will use the engine as
+# specified with LATEX_CMD_NAME to generate the PDF file directly from the LaTeX
+# files. Set this option to YES to get a higher quality PDF documentation.
+#
+# See also section LATEX_CMD_NAME for selecting the engine.
# The default value is: YES.
# This tag requires that the tag GENERATE_LATEX is set to YES.
@@ -1707,17 +1901,33 @@ LATEX_SOURCE_CODE = NO
# The LATEX_BIB_STYLE tag can be used to specify the style to use for the
# bibliography, e.g. plainnat, or ieeetr. See
-# http://en.wikipedia.org/wiki/BibTeX and \cite for more info.
+# https://en.wikipedia.org/wiki/BibTeX and \cite for more info.
# The default value is: plain.
# This tag requires that the tag GENERATE_LATEX is set to YES.
LATEX_BIB_STYLE = plain
+# If the LATEX_TIMESTAMP tag is set to YES then the footer of each generated
+# page will contain the date and time when the page was generated. Setting this
+# to NO can help when comparing the output of multiple runs.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_TIMESTAMP = NO
+
+# The LATEX_EMOJI_DIRECTORY tag is used to specify the (relative or absolute)
+# path from which the emoji images will be read. If a relative path is entered,
+# it will be relative to the LATEX_OUTPUT directory. If left blank the
+# LATEX_OUTPUT directory will be used.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_EMOJI_DIRECTORY =
+
#---------------------------------------------------------------------------
# Configuration options related to the RTF output
#---------------------------------------------------------------------------
-# If the GENERATE_RTF tag is set to YES doxygen will generate RTF output. The
+# If the GENERATE_RTF tag is set to YES, doxygen will generate RTF output. The
# RTF output is optimized for Word 97 and may not look too pretty with other RTF
# readers/editors.
# The default value is: NO.
@@ -1732,7 +1942,7 @@ GENERATE_RTF = NO
RTF_OUTPUT = rtf
-# If the COMPACT_RTF tag is set to YES doxygen generates more compact RTF
+# If the COMPACT_RTF tag is set to YES, doxygen generates more compact RTF
# documents. This may be useful for small projects and may help to save some
# trees in general.
# The default value is: NO.
@@ -1752,9 +1962,9 @@ COMPACT_RTF = NO
RTF_HYPERLINKS = NO
-# Load stylesheet definitions from file. Syntax is similar to doxygen's config
-# file, i.e. a series of assignments. You only have to provide replacements,
-# missing definitions are set to their default value.
+# Load stylesheet definitions from file. Syntax is similar to doxygen's
+# configuration file, i.e. a series of assignments. You only have to provide
+# replacements, missing definitions are set to their default value.
#
# See also section "Doxygen usage" for information on how to generate the
# default style sheet that doxygen normally uses.
@@ -1763,17 +1973,27 @@ RTF_HYPERLINKS = NO
RTF_STYLESHEET_FILE =
# Set optional variables used in the generation of an RTF document. Syntax is
-# similar to doxygen's config file. A template extensions file can be generated
-# using doxygen -e rtf extensionFile.
+# similar to doxygen's configuration file. A template extensions file can be
+# generated using doxygen -e rtf extensionFile.
# This tag requires that the tag GENERATE_RTF is set to YES.
RTF_EXTENSIONS_FILE =
+# If the RTF_SOURCE_CODE tag is set to YES then doxygen will include source code
+# with syntax highlighting in the RTF output.
+#
+# Note that which sources are shown also depends on other settings such as
+# SOURCE_BROWSER.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_RTF is set to YES.
+
+RTF_SOURCE_CODE = NO
+
#---------------------------------------------------------------------------
# Configuration options related to the man page output
#---------------------------------------------------------------------------
-# If the GENERATE_MAN tag is set to YES doxygen will generate man pages for
+# If the GENERATE_MAN tag is set to YES, doxygen will generate man pages for
# classes and files.
# The default value is: NO.
@@ -1802,7 +2022,7 @@ MAN_EXTENSION = .3
# MAN_EXTENSION with the initial . removed.
# This tag requires that the tag GENERATE_MAN is set to YES.
-#MAN_SUBDIR =
+MAN_SUBDIR =
# If the MAN_LINKS tag is set to YES and doxygen generates man output, then it
# will generate one additional man file for each entity documented in the real
@@ -1817,7 +2037,7 @@ MAN_LINKS = NO
# Configuration options related to the XML output
#---------------------------------------------------------------------------
-# If the GENERATE_XML tag is set to YES doxygen will generate an XML file that
+# If the GENERATE_XML tag is set to YES, doxygen will generate an XML file that
# captures the structure of the code including all documentation.
# The default value is: NO.
@@ -1831,7 +2051,7 @@ GENERATE_XML = YES
XML_OUTPUT = xml
-# If the XML_PROGRAMLISTING tag is set to YES doxygen will dump the program
+# If the XML_PROGRAMLISTING tag is set to YES, doxygen will dump the program
# listings (including syntax highlighting and cross-referencing information) to
# the XML output. Note that enabling this will significantly increase the size
# of the XML output.
@@ -1840,15 +2060,22 @@ XML_OUTPUT = xml
XML_PROGRAMLISTING = YES
+# If the XML_NS_MEMB_FILE_SCOPE tag is set to YES, doxygen will include
+# namespace members in file scope as well, matching the HTML output.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_XML is set to YES.
+
+XML_NS_MEMB_FILE_SCOPE = NO
+
#---------------------------------------------------------------------------
# Configuration options related to the DOCBOOK output
#---------------------------------------------------------------------------
-# If the GENERATE_DOCBOOK tag is set to YES doxygen will generate Docbook files
+# If the GENERATE_DOCBOOK tag is set to YES, doxygen will generate Docbook files
# that can be used to generate PDF.
# The default value is: NO.
-#GENERATE_DOCBOOK = NO
+GENERATE_DOCBOOK = NO
# The DOCBOOK_OUTPUT tag is used to specify where the Docbook pages will be put.
# If a relative path is entered the value of OUTPUT_DIRECTORY will be put in
@@ -1856,25 +2083,25 @@ XML_PROGRAMLISTING = YES
# The default directory is: docbook.
# This tag requires that the tag GENERATE_DOCBOOK is set to YES.
-#DOCBOOK_OUTPUT = docbook
+DOCBOOK_OUTPUT = docbook
-# If the DOCBOOK_PROGRAMLISTING tag is set to YES doxygen will include the
+# If the DOCBOOK_PROGRAMLISTING tag is set to YES, doxygen will include the
# program listings (including syntax highlighting and cross-referencing
# information) to the DOCBOOK output. Note that enabling this will significantly
# increase the size of the DOCBOOK output.
# The default value is: NO.
# This tag requires that the tag GENERATE_DOCBOOK is set to YES.
-#DOCBOOK_PROGRAMLISTING = NO
+DOCBOOK_PROGRAMLISTING = NO
#---------------------------------------------------------------------------
# Configuration options for the AutoGen Definitions output
#---------------------------------------------------------------------------
-# If the GENERATE_AUTOGEN_DEF tag is set to YES doxygen will generate an AutoGen
-# Definitions (see http://autogen.sf.net) file that captures the structure of
-# the code including all documentation. Note that this feature is still
-# experimental and incomplete at the moment.
+# If the GENERATE_AUTOGEN_DEF tag is set to YES, doxygen will generate an
+# AutoGen Definitions (see http://autogen.sourceforge.net/) file that captures
+# the structure of the code including all documentation. Note that this feature
+# is still experimental and incomplete at the moment.
# The default value is: NO.
GENERATE_AUTOGEN_DEF = NO
@@ -1883,7 +2110,7 @@ GENERATE_AUTOGEN_DEF = NO
# Configuration options related to the Perl module output
#---------------------------------------------------------------------------
-# If the GENERATE_PERLMOD tag is set to YES doxygen will generate a Perl module
+# If the GENERATE_PERLMOD tag is set to YES, doxygen will generate a Perl module
# file that captures the structure of the code including all documentation.
#
# Note that this feature is still experimental and incomplete at the moment.
@@ -1891,7 +2118,7 @@ GENERATE_AUTOGEN_DEF = NO
GENERATE_PERLMOD = NO
-# If the PERLMOD_LATEX tag is set to YES doxygen will generate the necessary
+# If the PERLMOD_LATEX tag is set to YES, doxygen will generate the necessary
# Makefile rules, Perl scripts and LaTeX code to be able to generate PDF and DVI
# output from the Perl module output.
# The default value is: NO.
@@ -1899,9 +2126,9 @@ GENERATE_PERLMOD = NO
PERLMOD_LATEX = NO
-# If the PERLMOD_PRETTY tag is set to YES the Perl module output will be nicely
+# If the PERLMOD_PRETTY tag is set to YES, the Perl module output will be nicely
# formatted so it can be parsed by a human reader. This is useful if you want to
-# understand what is going on. On the other hand, if this tag is set to NO the
+# understand what is going on. On the other hand, if this tag is set to NO, the
# size of the Perl module output will be much smaller and Perl will parse it
# just the same.
# The default value is: YES.
@@ -1921,14 +2148,14 @@ PERLMOD_MAKEVAR_PREFIX =
# Configuration options related to the preprocessor
#---------------------------------------------------------------------------
-# If the ENABLE_PREPROCESSING tag is set to YES doxygen will evaluate all
+# If the ENABLE_PREPROCESSING tag is set to YES, doxygen will evaluate all
# C-preprocessor directives found in the sources and include files.
# The default value is: YES.
ENABLE_PREPROCESSING = YES
-# If the MACRO_EXPANSION tag is set to YES doxygen will expand all macro names
-# in the source code. If set to NO only conditional compilation will be
+# If the MACRO_EXPANSION tag is set to YES, doxygen will expand all macro names
+# in the source code. If set to NO, only conditional compilation will be
# performed. Macro expansion can be done in a controlled way by setting
# EXPAND_ONLY_PREDEF to YES.
# The default value is: NO.
@@ -1944,7 +2171,7 @@ MACRO_EXPANSION = YES
EXPAND_ONLY_PREDEF = YES
-# If the SEARCH_INCLUDES tag is set to YES the includes files in the
+# If the SEARCH_INCLUDES tag is set to YES, the include files in the
# INCLUDE_PATH will be searched if a #include is found.
# The default value is: YES.
# This tag requires that the tag ENABLE_PREPROCESSING is set to YES.
@@ -1975,8 +2202,8 @@ INCLUDE_FILE_PATTERNS =
# This tag requires that the tag ENABLE_PREPROCESSING is set to YES.
PREDEFINED = DMLC_USE_CXX11 \
- "XGB_DLL=" \
- "XGB_EXTERN_C="
+ XGB_DLL= \
+ XGB_EXTERN_C=
# If the MACRO_EXPANSION and EXPAND_ONLY_PREDEF tags are set to YES then this
# tag can be used to specify a list of macro names that should be expanded. The
@@ -2022,37 +2249,32 @@ TAGFILES =
GENERATE_TAGFILE =
-# If the ALLEXTERNALS tag is set to YES all external class will be listed in the
-# class index. If set to NO only the inherited external classes will be listed.
+# If the ALLEXTERNALS tag is set to YES, all external classes will be listed
+# in the class index. If set to NO, only the inherited external classes will
+# be listed.
# The default value is: NO.
ALLEXTERNALS = NO
-# If the EXTERNAL_GROUPS tag is set to YES all external groups will be listed in
-# the modules index. If set to NO, only the current project's groups will be
+# If the EXTERNAL_GROUPS tag is set to YES, all external groups will be listed
+# in the modules index. If set to NO, only the current project's groups will be
# listed.
# The default value is: YES.
EXTERNAL_GROUPS = YES
-# If the EXTERNAL_PAGES tag is set to YES all external pages will be listed in
+# If the EXTERNAL_PAGES tag is set to YES, all external pages will be listed in
# the related pages index. If set to NO, only the current project's pages will
# be listed.
# The default value is: YES.
-#EXTERNAL_PAGES = YES
-
-# The PERL_PATH should be the absolute path and name of the perl script
-# interpreter (i.e. the result of 'which perl').
-# The default file (with absolute path) is: /usr/bin/perl.
-
-PERL_PATH = /usr/bin/perl
+EXTERNAL_PAGES = YES
#---------------------------------------------------------------------------
# Configuration options related to the dot tool
#---------------------------------------------------------------------------
-# If the CLASS_DIAGRAMS tag is set to YES doxygen will generate a class diagram
+# If the CLASS_DIAGRAMS tag is set to YES, doxygen will generate a class diagram
# (in HTML and LaTeX) for classes with base or super classes. Setting the tag to
# NO turns the diagrams off. Note that this option also works with HAVE_DOT
# disabled, but it is recommended to install and use dot, since it yields more
@@ -2061,23 +2283,14 @@ PERL_PATH = /usr/bin/perl
CLASS_DIAGRAMS = YES
-# You can define message sequence charts within doxygen comments using the \msc
-# command. Doxygen will then run the mscgen tool (see:
-# http://www.mcternan.me.uk/mscgen/)) to produce the chart and insert it in the
-# documentation. The MSCGEN_PATH tag allows you to specify the directory where
-# the mscgen tool resides. If left empty the tool is assumed to be found in the
-# default search path.
-
-MSCGEN_PATH =
-
# You can include diagrams made with dia in doxygen documentation. Doxygen will
# then run dia to produce the diagram and insert it in the documentation. The
# DIA_PATH tag allows you to specify the directory where the dia binary resides.
# If left empty dia is assumed to be found in the default search path.
-#DIA_PATH =
+DIA_PATH =
-# If set to YES, the inheritance and collaboration graphs will hide inheritance
+# If set to YES the inheritance and collaboration graphs will hide inheritance
# and usage relations if the target is undocumented or is not a class.
# The default value is: YES.
@@ -2150,7 +2363,7 @@ COLLABORATION_GRAPH = YES
GROUP_GRAPHS = YES
-# If the UML_LOOK tag is set to YES doxygen will generate inheritance and
+# If the UML_LOOK tag is set to YES, doxygen will generate inheritance and
# collaboration diagrams in a style similar to the OMG's Unified Modeling
# Language.
# The default value is: NO.
@@ -2167,9 +2380,31 @@ UML_LOOK = YES
# but if the number exceeds 15, the total amount of fields shown is limited to
# 10.
# Minimum value: 0, maximum value: 100, default value: 10.
+# This tag requires that the tag UML_LOOK is set to YES.
+
+UML_LIMIT_NUM_FIELDS = 10
+
+# If the DOT_UML_DETAILS tag is set to NO, doxygen will show attributes and
+# methods without types and arguments in the UML graphs. If the DOT_UML_DETAILS
+# tag is set to YES, doxygen will add type and arguments for attributes and
+# methods in the UML graphs. If the DOT_UML_DETAILS tag is set to NONE, doxygen
+# will not generate fields with class member information in the UML graphs. The
+# class diagrams will look similar to the default class diagrams but using UML
+# notation for the relationships.
+# Possible values are: NO, YES and NONE.
+# The default value is: NO.
+# This tag requires that the tag UML_LOOK is set to YES.
+
+DOT_UML_DETAILS = NO
+
+# The DOT_WRAP_THRESHOLD tag can be used to set the maximum number of characters
+# to display on a single line. If the actual line length exceeds this threshold
+# significantly it will be wrapped across multiple lines. Some heuristics are
+# applied to avoid ugly line breaks.
+# Minimum value: 0, maximum value: 1000, default value: 17.
# This tag requires that the tag HAVE_DOT is set to YES.
-#UML_LIMIT_NUM_FIELDS = 10
+DOT_WRAP_THRESHOLD = 17
# If the TEMPLATE_RELATIONS tag is set to YES then the inheritance and
# collaboration graphs will show the relations between templates and their
@@ -2202,7 +2437,8 @@ INCLUDED_BY_GRAPH = YES
#
# Note that enabling this option will significantly increase the time of a run.
# So in most cases it will be better to enable call graphs for selected
-# functions only using the \callgraph command.
+# functions only using the \callgraph command. Disabling a call graph can be
+# accomplished by means of the command \hidecallgraph.
# The default value is: NO.
# This tag requires that the tag HAVE_DOT is set to YES.
@@ -2213,7 +2449,8 @@ CALL_GRAPH = NO
#
# Note that enabling this option will significantly increase the time of a run.
# So in most cases it will be better to enable caller graphs for selected
-# functions only using the \callergraph command.
+# functions only using the \callergraph command. Disabling a caller graph can be
+# accomplished by means of the command \hidecallergraph.
# The default value is: NO.
# This tag requires that the tag HAVE_DOT is set to YES.
@@ -2236,13 +2473,17 @@ GRAPHICAL_HIERARCHY = YES
DIRECTORY_GRAPH = YES
# The DOT_IMAGE_FORMAT tag can be used to set the image format of the images
-# generated by dot.
+# generated by dot. For an explanation of the image formats see the section
+# output formats in the documentation of the dot tool (Graphviz (see:
+# http://www.graphviz.org/)).
# Note: If you choose svg you need to set HTML_FILE_EXTENSION to xhtml in order
# to make the SVG files visible in IE 9+ (other browsers do not have this
# requirement).
# Possible values are: png, png:cairo, png:cairo:cairo, png:cairo:gd, png:gd,
# png:gd:gd, jpg, jpg:cairo, jpg:cairo:gd, jpg:gd, jpg:gd:gd, gif, gif:cairo,
-# gif:cairo:gd, gif:gd, gif:gd:gd and svg.
+# gif:cairo:gd, gif:gd, gif:gd:gd, svg, png:gd, png:gd:gd, png:cairo,
+# png:cairo:gd, png:cairo:cairo, png:cairo:gdiplus, png:gdiplus and
+# png:gdiplus:gdiplus.
# The default value is: png.
# This tag requires that the tag HAVE_DOT is set to YES.
@@ -2283,16 +2524,25 @@ MSCFILE_DIRS =
# contain dia files that are included in the documentation (see the \diafile
# command).
-#DIAFILE_DIRS =
+DIAFILE_DIRS =
# When using plantuml, the PLANTUML_JAR_PATH tag should be used to specify the
# path where java can find the plantuml.jar file. If left blank, it is assumed
# PlantUML is not used or called during a preprocessing step. Doxygen will
# generate a warning when it encounters a \startuml command in this case and
# will not generate output for the diagram.
-# This tag requires that the tag HAVE_DOT is set to YES.
-#PLANTUML_JAR_PATH =
+PLANTUML_JAR_PATH =
+
+# When using plantuml, the PLANTUML_CFG_FILE tag can be used to specify a
+# configuration file for plantuml.
+
+PLANTUML_CFG_FILE =
+
+# When using plantuml, the specified paths are searched for files specified by
+# the !include statement in a plantuml block.
+
+PLANTUML_INCLUDE_PATH =
# The DOT_GRAPH_MAX_NODES tag can be used to set the maximum number of nodes
# that will be shown in the graph. If the number of nodes in a graph becomes
@@ -2330,7 +2580,7 @@ MAX_DOT_GRAPH_DEPTH = 0
DOT_TRANSPARENT = NO
-# Set the DOT_MULTI_TARGETS tag to YES allow dot to generate multiple output
+# Set the DOT_MULTI_TARGETS tag to YES to allow dot to generate multiple output
# files in one run (i.e. multiple -o and -T options on the command line). This
# makes dot run faster, but since only newer versions of dot (>1.8.10) support
# this, this feature is disabled by default.
@@ -2347,9 +2597,11 @@ DOT_MULTI_TARGETS = YES
GENERATE_LEGEND = YES
-# If the DOT_CLEANUP tag is set to YES doxygen will remove the intermediate dot
+# If the DOT_CLEANUP tag is set to YES, doxygen will remove the intermediate
# files that are used to generate the various graphs.
+#
+# Note: This setting is not only used for dot files but also for msc and
+# plantuml temporary files.
# The default value is: YES.
-# This tag requires that the tag HAVE_DOT is set to YES.
DOT_CLEANUP = YES
diff --git a/doc/build.rst b/doc/build.rst
index 53d9a3209..e78d2d2f4 100644
--- a/doc/build.rst
+++ b/doc/build.rst
@@ -12,6 +12,7 @@ systems. If the instructions do not work for you, please feel free to ask quest
Consider installing XGBoost from a pre-built binary, to avoid the trouble of building XGBoost from the source. Checkout :doc:`Installation Guide `.
.. contents:: Contents
+ :local:
.. _get_source:
@@ -152,11 +153,11 @@ On Windows, run CMake as follows:
mkdir build
cd build
- cmake .. -G"Visual Studio 14 2015 Win64" -DUSE_CUDA=ON
+ cmake .. -G"Visual Studio 17 2022" -A x64 -DUSE_CUDA=ON
(Change the ``-G`` option appropriately if you have a different version of Visual Studio installed.)
-The above cmake configuration run will create an ``xgboost.sln`` solution file in the build directory. Build this solution in release mode as a x64 build, either from Visual studio or from command line:
+The above cmake configuration run will create an ``xgboost.sln`` solution file in the build directory. Build this solution in Release mode, either from Visual Studio or from the command line:
.. code-block:: bash
@@ -176,111 +177,104 @@ Building Python Package with Default Toolchains
===============================================
There are several ways to build and install the package from source:
-1. Use Python setuptools directly
+1. Build C++ core with CMake first
- The XGBoost Python package supports most of the setuptools commands, here is a list of tested commands:
+ You can first build the C++ library using CMake as described in :ref:`build_shared_lib`.
+ After compilation, a shared library will appear in the ``lib/`` directory.
+ On Linux distributions, the shared library is ``lib/libxgboost.so``.
+ The install script ``pip install .`` will reuse the shared library instead of compiling
+ it from scratch, making it quite fast to run.
+
+ .. code-block:: console
+
+ $ cd python-package/
+ $ pip install . # Will re-use lib/libxgboost.so
+
+2. Install the Python package directly
+
+ You can navigate to the ``python-package/`` directory and install the Python package directly
+ by running
+
+ .. code-block:: console
+
+ $ cd python-package/
+ $ pip install -v .
+
+ which will compile XGBoost's native (C++) code using default CMake flags.
+ To enable additional compilation options, pass the corresponding ``--config-settings``:
+
+ .. code-block:: console
+
+ $ pip install -v . --config-settings use_cuda=True --config-settings use_nccl=True
+
+ Use Pip 22.1 or later to use the ``--config-settings`` option.
+
+ Here are the available options for ``--config-settings``:
+
+ .. literalinclude:: ../python-package/packager/build_config.py
+ :language: python
+ :start-at: @dataclasses.dataclass
+ :end-before: def _set_config_setting(
+
+ ``use_system_libxgboost`` is a special option. See Item 4 below for a
+ detailed description.
+
+ .. note:: Verbose flag recommended
+
+ As ``pip install .`` will build C++ code, it will take a while to complete.
+ To ensure that the build is progressing successfully, we suggest that
+ you add the verbose flag (``-v``) when invoking ``pip install``.
+
+
+3. Editable installation
+
+ To further enable rapid development and iteration, we provide an **editable installation**.
+ In an editable installation, the installed package is simply a symbolic link to your
+ working copy of the XGBoost source code. So any changes you make to your source
+ directory will be immediately visible to the Python interpreter. Here is how to
+ install XGBoost as an editable installation:
.. code-block:: bash
- python setup.py install # Install the XGBoost to your current Python environment.
- python setup.py build # Build the Python package.
- python setup.py build_ext # Build only the C++ core.
- python setup.py sdist # Create a source distribution
- python setup.py bdist # Create a binary distribution
- python setup.py bdist_wheel # Create a binary distribution with wheel format
-
- Running ``python setup.py install`` will compile XGBoost using default CMake flags. For
- passing additional compilation options, append the flags to the command. For example,
- to enable CUDA acceleration and NCCL (distributed GPU) support:
-
- .. code-block:: bash
-
- python setup.py install --use-cuda --use-nccl
-
- Please refer to ``setup.py`` for a complete list of available options. Some other
- options used for development are only available for using CMake directly. See next
- section on how to use CMake with setuptools manually.
-
- You can install the created distribution packages using pip. For example, after running
- ``sdist`` setuptools command, a tar ball similar to ``xgboost-1.0.0.tar.gz`` will be
- created under the ``dist`` directory. Then you can install it by invoking the following
- command under ``dist`` directory:
-
- .. code-block:: bash
-
- # under python-package directory
- cd dist
- pip install ./xgboost-1.0.0.tar.gz
-
-
- For details about these commands, please refer to the official document of `setuptools
- `_, or just Google "how to install Python
- package from source". XGBoost Python package follows the general convention.
- Setuptools is usually available with your Python distribution, if not you can install it
- via system command. For example on Debian or Ubuntu:
-
- .. code-block:: bash
-
- sudo apt-get install python-setuptools
-
-
- For cleaning up the directory after running above commands, ``python setup.py clean`` is
- not sufficient. After copying out the build result, simply running ``git clean -xdf``
- under ``python-package`` is an efficient way to remove generated cache files. If you
- find weird behaviors in Python build or running linter, it might be caused by those
- cached files.
-
- For using develop command (editable installation), see next section.
-
- .. code-block::
-
- python setup.py develop # Create a editable installation.
- pip install -e . # Same as above, but carried out by pip.
-
-
-2. Build C++ core with CMake first
-
- This is mostly for C++ developers who don't want to go through the hooks in Python
- setuptools. You can build C++ library directly using CMake as described in above
- sections. After compilation, a shared object (or called dynamic linked library, jargon
- depending on your platform) will appear in XGBoost's source tree under ``lib/``
- directory. On Linux distributions it's ``lib/libxgboost.so``. From there all Python
- setuptools commands will reuse that shared object instead of compiling it again. This
- is especially convenient if you are using the editable installation, where the installed
- package is simply a link to the source tree. We can perform rapid testing during
- development. Here is a simple bash script does that:
-
- .. code-block:: bash
-
- # Under xgboost source tree.
+ # Under xgboost source directory
mkdir build
cd build
- cmake ..
- make -j$(nproc)
+ # Build shared library libxgboost.so
+ cmake .. -GNinja
+ ninja
+ # Install as editable installation
cd ../python-package
- pip install -e . # or equivalently python setup.py develop
+ pip install -e .
-3. Use ``libxgboost.so`` on system path.
+4. Use ``libxgboost.so`` on system path.
- This is for distributing xgboost in a language independent manner, where
- ``libxgboost.so`` is separately packaged with Python package. Assuming `libxgboost.so`
- is already presented in system library path, which can be queried via:
+ This option is useful for package managers that wish to separately package
+ ``libxgboost.so`` and the XGBoost Python package. For example, Conda
+ publishes ``libxgboost`` (for the shared library) and ``py-xgboost``
+ (for the Python package).
+
+ To use this option, first make sure that ``libxgboost.so`` exists in the system library path:
.. code-block:: python
import sys
- import os
- os.path.join(sys.prefix, 'lib')
+ import pathlib
+ libpath = pathlib.Path(sys.prefix).joinpath("lib", "libxgboost.so")
+ assert libpath.exists()
- Then one only needs to provide an user option when installing Python package to reuse the
- shared object in system path:
+ Then pass the ``use_system_libxgboost=True`` option to ``pip install``:
.. code-block:: bash
- cd xgboost/python-package
- python setup.py install --use-system-libxgboost
+ cd python-package
+ pip install . --config-settings use_system_libxgboost=True
+.. note::
+
+ See :doc:`contrib/python_packaging` for instructions on packaging
+ and distributing XGBoost as Python distributions.
+
.. _python_mingw:
Building Python Package for Windows with MinGW-w64 (Advanced)
@@ -297,7 +291,7 @@ So you may want to build XGBoost with GCC own your own risk. This presents some
2. ``-O3`` is OK.
3. ``-mtune=native`` is also OK.
4. Don't use ``-march=native`` gcc flag. Using it causes the Python interpreter to crash if the DLL was actually used.
-5. You may need to provide the lib with the runtime libs. If ``mingw32/bin`` is not in ``PATH``, build a wheel (``python setup.py bdist_wheel``), open it with an archiver and put the needed dlls to the directory where ``xgboost.dll`` is situated. Then you can install the wheel with ``pip``.
+5. You may need to provide the lib with the runtime libs. If ``mingw32/bin`` is not in ``PATH``, build a wheel (``pip wheel``), open it with an archiver and put the needed DLLs in the directory where ``xgboost.dll`` is situated. Then you can install the wheel with ``pip``.
******************************
Building R Package From Source
diff --git a/doc/contrib/ci.rst b/doc/contrib/ci.rst
index 6073e646a..76e06de35 100644
--- a/doc/contrib/ci.rst
+++ b/doc/contrib/ci.rst
@@ -35,8 +35,9 @@ calls ``cibuildwheel`` to build the wheel. The ``cibuildwheel`` is a library tha
suitable Python environment for each OS and processor target. Since we don't have Apple Silion
machine in GitHub Actions, cross-compilation is needed; ``cibuildwheel`` takes care of the complex
task of cross-compiling a Python wheel. (Note that ``cibuildwheel`` will call
-``setup.py bdist_wheel``. Since XGBoost has a native library component, ``setup.py`` contains
-a glue code to call CMake and a C++ compiler to build the native library on the fly.)
+``pip wheel``. Since XGBoost has a native library component, we created a customized build
+backend that hooks into ``pip``. The customized backend contains the glue code to compile the native
+library on the fly.)
*********************************************************
Reproduce CI testing environments using Docker containers
diff --git a/doc/contrib/index.rst b/doc/contrib/index.rst
index c9c5f93a2..6a36cb108 100644
--- a/doc/contrib/index.rst
+++ b/doc/contrib/index.rst
@@ -23,6 +23,7 @@ Here are guidelines for contributing to various aspect of the XGBoost project:
Community Guideline
donate
coding_guide
+ python_packaging
unit_tests
Docs and Examples
git_guide
diff --git a/doc/contrib/python_packaging.rst b/doc/contrib/python_packaging.rst
new file mode 100644
index 000000000..5cf085685
--- /dev/null
+++ b/doc/contrib/python_packaging.rst
@@ -0,0 +1,83 @@
+###########################################
+Notes on packaging XGBoost's Python package
+###########################################
+
+
+.. contents:: Contents
+ :local:
+
+.. _packaging_python_xgboost:
+
+***************************************************
+How to build binary wheels and source distributions
+***************************************************
+
+Wheels and source distributions (sdist for short) are the two main
+mechanisms for packaging and distributing Python packages.
+
+* A **source distribution** (sdist) is a tarball (``.tar.gz`` extension) that
+ contains the source code.
+* A **wheel** is a ZIP-compressed archive (with ``.whl`` extension)
+ representing a *built* distribution. Unlike an sdist, a wheel can contain
+ compiled components. The compiled components are compiled prior to distribution,
+ making it more convenient for end-users to install a wheel. Wheels containing
+ compiled components are referred to as **binary wheels**.
+
+See `Python Packaging User Guide `_
+to learn more about how Python packages in general are packaged and
+distributed.
+
+For the remainder of this document, we will focus on packaging and
+distributing XGBoost.
+
+Building sdists
+===============
+
+In the case of XGBoost, an sdist contains both the Python code and
+the C++ code, so that the core part of XGBoost can be compiled into the
+shared library ``libxgboost.so`` [#shared_lib_name]_.
+
+You can obtain an sdist as follows:
+
+.. code-block:: console
+
+ $ python -m build --sdist .
+
+(You'll need to install the ``build`` package first:
+``pip install build`` or ``conda install python-build``.)
+
+Running ``pip install`` with an sdist will launch CMake and a C++ compiler
+to compile the bundled C++ code into ``libxgboost.so``:
+
+.. code-block:: console
+
+ $ pip install -v xgboost-2.0.0.tar.gz # Add -v to show build progress
+
+Building binary wheels
+======================
+
+You can also build a wheel as follows:
+
+.. code-block:: console
+
+ $ pip wheel --no-deps -v .
+
+Notably, the resulting wheel contains a copy of the shared library
+``libxgboost.so`` [#shared_lib_name]_. The wheel is a **binary wheel**,
+since it contains a compiled binary.
+
+
+Running ``pip install`` with the binary wheel will extract the content of
+the wheel into the current Python environment. Since the wheel already
+contains a pre-built copy of ``libxgboost.so``, it does not have to be
+built at the time of install. So ``pip install`` with the binary wheel
+completes quickly:
+
+.. code-block:: console
+
+ $ pip install xgboost-2.0.0-py3-none-linux_x86_64.whl # Completes quickly
+
+.. rubric:: Footnotes
+
+.. [#shared_lib_name] The name of the shared library file will differ
+ depending on the operating system in use. See :ref:`build_shared_lib`.
diff --git a/doc/install.rst b/doc/install.rst
index 03daf465f..0e155f647 100644
--- a/doc/install.rst
+++ b/doc/install.rst
@@ -16,15 +16,28 @@ Stable Release
Python
------
-Pre-built binary are uploaded to PyPI (Python Package Index) for each release. Supported platforms are Linux (x86_64, aarch64), Windows (x86_64) and MacOS (x86_64, Apple Silicon).
+Pre-built binary wheels are uploaded to PyPI (Python Package Index) for each release. Supported platforms are Linux (x86_64, aarch64), Windows (x86_64) and MacOS (x86_64, Apple Silicon).
.. code-block:: bash
+ # Pip 21.3+ is required
pip install xgboost
You might need to run the command with ``--user`` flag or use ``virtualenv`` if you run
-into permission errors. Python pre-built binary capability for each platform:
+into permission errors.
+
+.. note:: Windows users need to install Visual C++ Redistributable
+
+ XGBoost requires DLLs from `Visual C++ Redistributable
+ `_
+ in order to function, so make sure to install it. Exception: If
+ you have Visual Studio installed, you already have access to the
+ necessary libraries and thus don't need to install Visual C++
+ Redistributable.
+
+
+Capabilities of binary wheels for each platform:
.. |tick| unicode:: U+2714
.. |cross| unicode:: U+2718
diff --git a/doc/jvm/index.rst b/doc/jvm/index.rst
index 6721908f9..2b476781b 100644
--- a/doc/jvm/index.rst
+++ b/doc/jvm/index.rst
@@ -41,3 +41,7 @@ Contents
XGBoost4J Scala API
XGBoost4J-Spark Scala API
XGBoost4J-Flink Scala API
+
+.. note::
+
+ Please note that the Flink interface is still under construction.
diff --git a/doc/model.schema b/doc/model.schema
index b9e2da305..103d9d9e4 100644
--- a/doc/model.schema
+++ b/doc/model.schema
@@ -219,6 +219,16 @@
"num_pairsample": { "type": "string" },
"fix_list_weight": { "type": "string" }
}
+ },
+ "lambdarank_param": {
+ "type": "object",
+ "properties": {
+ "lambdarank_num_pair_per_sample": { "type": "string" },
+ "lambdarank_pair_method": { "type": "string" },
+ "lambdarank_unbiased": {"type": "string" },
+ "lambdarank_bias_norm": {"type": "string" },
+ "ndcg_exp_gain": {"type": "string"}
+ }
}
},
"type": "object",
@@ -477,22 +487,22 @@
"type": "object",
"properties": {
"name": { "const": "rank:pairwise" },
- "lambda_rank_param": { "$ref": "#/definitions/lambda_rank_param"}
+ "lambda_rank_param": { "$ref": "#/definitions/lambdarank_param"}
},
"required": [
"name",
- "lambda_rank_param"
+ "lambdarank_param"
]
},
{
"type": "object",
"properties": {
"name": { "const": "rank:ndcg" },
- "lambda_rank_param": { "$ref": "#/definitions/lambda_rank_param"}
+ "lambda_rank_param": { "$ref": "#/definitions/lambdarank_param"}
},
"required": [
"name",
- "lambda_rank_param"
+ "lambdarank_param"
]
},
{
diff --git a/doc/parameter.rst b/doc/parameter.rst
index c070e7018..8c7cadcdc 100644
--- a/doc/parameter.rst
+++ b/doc/parameter.rst
@@ -233,7 +233,7 @@ Parameters for Tree Booster
.. note:: This parameter is working-in-progress.
- The strategy used for training multi-target models, including multi-target regression
- and multi-class classification. See :doc:`/tutorials/multioutput` for more information.
+ and multi-class classification. See :doc:`/tutorials/multioutput` for more information.
- ``one_output_per_tree``: One model for each target.
- ``multi_output_tree``: Use multi-target trees.
@@ -380,9 +380,9 @@ Specify the learning task and the corresponding learning objective. The objectiv
See :doc:`/tutorials/aft_survival_analysis` for details.
- ``multi:softmax``: set XGBoost to do multiclass classification using the softmax objective, you also need to set num_class(number of classes)
- ``multi:softprob``: same as softmax, but output a vector of ``ndata * nclass``, which can be further reshaped to ``ndata * nclass`` matrix. The result contains predicted probability of each data point belonging to each class.
- - ``rank:pairwise``: Use LambdaMART to perform pairwise ranking where the pairwise loss is minimized
- - ``rank:ndcg``: Use LambdaMART to perform list-wise ranking where `Normalized Discounted Cumulative Gain (NDCG) `_ is maximized
- - ``rank:map``: Use LambdaMART to perform list-wise ranking where `Mean Average Precision (MAP) `_ is maximized
+ - ``rank:ndcg``: Use LambdaMART to perform pair-wise ranking where `Normalized Discounted Cumulative Gain (NDCG) `_ is maximized. This objective supports position debiasing for click data.
+ - ``rank:map``: Use LambdaMART to perform pair-wise ranking where `Mean Average Precision (MAP) `_ is maximized
+ - ``rank:pairwise``: Use LambdaRank to perform pair-wise ranking using the `ranknet` objective.
- ``reg:gamma``: gamma regression with log-link. Output is a mean of gamma distribution. It might be useful, e.g., for modeling insurance claims severity, or for any outcome that might be `gamma-distributed `_.
- ``reg:tweedie``: Tweedie regression with log-link. It might be useful, e.g., for modeling total loss in insurance, or for any outcome that might be `Tweedie-distributed `_.
@@ -395,8 +395,9 @@ Specify the learning task and the corresponding learning objective. The objectiv
* ``eval_metric`` [default according to objective]
- - Evaluation metrics for validation data, a default metric will be assigned according to objective (rmse for regression, and logloss for classification, mean average precision for ranking)
- - User can add multiple evaluation metrics. Python users: remember to pass the metrics in as list of parameters pairs instead of map, so that latter ``eval_metric`` won't override previous one
+ - Evaluation metrics for validation data. A default metric will be assigned according to the objective (rmse for regression, logloss for classification, `mean average precision` for ``rank:map``, etc.)
+ - User can add multiple evaluation metrics. Python users: remember to pass the metrics in as a list of parameter pairs instead of a map, so that the latter ``eval_metric`` won't override previous ones
+
- The choices are listed below:
- ``rmse``: `root mean square error `_
@@ -480,6 +481,36 @@ Parameter for using AFT Survival Loss (``survival:aft``) and Negative Log Likeli
* ``aft_loss_distribution``: Probability Density Function, ``normal``, ``logistic``, or ``extreme``.
+.. _ltr-param:
+
+Parameters for learning to rank (``rank:ndcg``, ``rank:map``, ``rank:pairwise``)
+================================================================================
+
+These are parameters specific to the learning-to-rank task. See :doc:`Learning to Rank ` for an in-depth explanation.
+
+* ``lambdarank_pair_method`` [default = ``mean``]
+
+ How to construct pairs for pair-wise learning.
+
+ - ``mean``: Sample ``lambdarank_num_pair_per_sample`` pairs for each document in the query list.
+ - ``topk``: Focus on top-``lambdarank_num_pair_per_sample`` documents. Construct :math:`|query|` pairs for each document at the top-``lambdarank_num_pair_per_sample`` ranked by the model.
+
+* ``lambdarank_num_pair_per_sample`` [range = :math:`[1, \infty]`]
+
+ It specifies the number of pairs sampled for each document when the pair method is ``mean``, or the truncation level for queries when the pair method is ``topk``. For example, to train with ``ndcg@6``, set ``lambdarank_num_pair_per_sample`` to :math:`6` and ``lambdarank_pair_method`` to ``topk``.
+
+* ``lambdarank_unbiased`` [default = ``false``]
+
+ Specify whether we need to debias the input click data.
+
+* ``lambdarank_bias_norm`` [default = 2.0]
+
+ :math:`L_p` normalization for position debiasing, default is :math:`L_2`. Only relevant when ``lambdarank_unbiased`` is set to true.
+
+* ``ndcg_exp_gain`` [default = ``true``]
+
+ Whether we should use the exponential gain function for ``NDCG``. There are two forms of gain function for ``NDCG``: one uses the relevance value directly, while the other uses :math:`2^{rel} - 1` to emphasize retrieving relevant documents. When ``ndcg_exp_gain`` is true (the default), the relevance degree cannot be greater than 31.
+
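+For illustration, here is a minimal sketch (not part of the official examples) of how these
+parameters can be passed through the Python package's scikit-learn interface
+(``xgboost.XGBRanker``), using synthetic data:
+
+.. code-block:: python
+
+    import numpy as np
+    import xgboost as xgb
+
+    # Synthetic learning-to-rank data: 20 queries with 10 documents each.
+    rng = np.random.default_rng(0)
+    X = rng.normal(size=(200, 5))
+    y = rng.integers(0, 4, size=200)    # graded relevance labels
+    qid = np.repeat(np.arange(20), 10)  # query id for each row, sorted by query
+
+    ranker = xgb.XGBRanker(
+        objective="rank:ndcg",
+        lambdarank_pair_method="topk",
+        lambdarank_num_pair_per_sample=6,  # focus training on the top 6 documents (ndcg@6)
+        ndcg_exp_gain=True,
+        n_estimators=10,
+    )
+    ranker.fit(X, y, qid=qid)
+    scores = ranker.predict(X)
+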
***********************
Command Line Parameters
***********************
diff --git a/doc/tutorials/dask.rst b/doc/tutorials/dask.rst
index c33a90c81..888683975 100644
--- a/doc/tutorials/dask.rst
+++ b/doc/tutorials/dask.rst
@@ -23,7 +23,7 @@ Requirements
Dask can be installed using either pip or conda (see the dask `installation
documentation `_ for more information). For
-accelerating XGBoost with GPUs, `dask-cuda `_ is
+accelerating XGBoost with GPUs, `dask-cuda `__ is
recommended for creating GPU clusters.
diff --git a/doc/tutorials/external_memory.rst b/doc/tutorials/external_memory.rst
index 3b96cfe92..006d63b43 100644
--- a/doc/tutorials/external_memory.rst
+++ b/doc/tutorials/external_memory.rst
@@ -77,7 +77,7 @@ The external memory version takes in the following `URI `_ for a description of the CSV format.). Please be careful that, XGBoost does **not** understand file extensions, nor try to guess the file format, as there is no universal agreement upon file extension of LIBSVM or CSV. Instead it employs `URI `_ format for specifying the precise input file type. For example if you provide a `csv` file ``./data.train.csv`` as input, XGBoost will blindly use the default LIBSVM parser to digest it and generate a parser error. Instead, users need to provide an URI in the form of ``train.csv?format=csv``. For external memory input, the URI should of a form similar to ``train.csv?format=csv#dtrain.cache``. See :ref:`python_data_interface` and :doc:`/tutorials/external_memory` also.
+
+XGBoost currently supports two text formats for ingesting data: LIBSVM and CSV. The rest of this document will describe the LIBSVM format. (See `this Wikipedia article `_ for a description of the CSV format.) Please be careful that XGBoost does **not** understand file extensions, nor does it try to guess the file format, as there is no universal agreement upon the file extension for LIBSVM or CSV. Instead it employs the `URI `_ format for specifying the precise input file type. For example, if you provide a `csv` file ``./data.train.csv`` as input, XGBoost will blindly use the default LIBSVM parser to digest it and generate a parser error. Instead, users need to provide a URI in the form of ``train.csv?format=csv`` or ``train.csv?format=libsvm``. For external memory input, the URI should be of a form similar to ``train.csv?format=csv#dtrain.cache``. See :ref:`python_data_interface` and :doc:`/tutorials/external_memory` as well.
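+
+As a minimal illustration from the Python package (the file names here are hypothetical),
+the same URI convention applies when constructing a ``DMatrix`` from a text file:
+
+.. code-block:: python
+
+    import xgboost as xgb
+
+    # The ``format`` parameter must be given explicitly for text inputs.
+    dtrain_csv = xgb.DMatrix("train.csv?format=csv")
+    dtrain_svm = xgb.DMatrix("train.txt?format=libsvm")
+
+    # External memory: append a cache file prefix after '#'.
+    dtrain_ext = xgb.DMatrix("train.txt?format=libsvm#dtrain.cache")
+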
For training or predicting, XGBoost takes an instance file with the format as below:
diff --git a/doc/tutorials/spark_estimator.rst b/doc/tutorials/spark_estimator.rst
index 02ddb60ea..fb69b70e1 100644
--- a/doc/tutorials/spark_estimator.rst
+++ b/doc/tutorials/spark_estimator.rst
@@ -108,8 +108,8 @@ virtualenv and pip:
python -m venv xgboost_env
source xgboost_env/bin/activate
pip install pyarrow pandas venv-pack xgboost
- # https://rapids.ai/pip.html#install
- pip install cudf-cu11 --extra-index-url=https://pypi.ngc.nvidia.com
+ # https://docs.rapids.ai/install#pip-install
+ pip install cudf-cu11 --extra-index-url=https://pypi.nvidia.com
venv-pack -o xgboost_env.tar.gz
With Conda:
@@ -241,7 +241,7 @@ additional spark configurations and dependencies:
--master spark://:7077 \
--conf spark.executor.resource.gpu.amount=1 \
--conf spark.task.resource.gpu.amount=1 \
- --packages com.nvidia:rapids-4-spark_2.12:22.08.0 \
+ --packages com.nvidia:rapids-4-spark_2.12:23.04.0 \
--conf spark.plugins=com.nvidia.spark.SQLPlugin \
--conf spark.sql.execution.arrow.maxRecordsPerBatch=1000000 \
--archives xgboost_env.tar.gz#environment \
diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h
index 2233336e9..4b9d37335 100644
--- a/include/xgboost/c_api.h
+++ b/include/xgboost/c_api.h
@@ -38,7 +38,7 @@ typedef uint64_t bst_ulong; // NOLINT(*)
*/
/**
- * @defgroup Library
+ * @defgroup Library Library
*
* These functions are used to obtain general information about XGBoost including version,
* build info and current global configuration.
@@ -112,7 +112,7 @@ XGB_DLL int XGBGetGlobalConfig(char const **out_config);
/**@}*/
/**
- * @defgroup DMatrix
+ * @defgroup DMatrix DMatrix
*
* @brief DMatrix is the baisc data storage for XGBoost used by all XGBoost algorithms
* including both training, prediction and explanation. There are a few variants of
@@ -138,7 +138,11 @@ XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle
/*!
* \brief load a data matrix
* \param config JSON encoded parameters for DMatrix construction. Accepted fields are:
- * - uri: The URI of the input file.
+
+ * - uri: The URI of the input file. The URI parameter `format` is required when loading text data.
+ * \verbatim embed:rst:leading-asterisk
+ * See :doc:`/tutorials/input_format` for more info.
+ * \endverbatim
* - silent (optional): Whether to print message during loading. Default to true.
* - data_split_mode (optional): Whether to split by row or column. In distributed mode, the
* file is split accordingly; otherwise this is only an indicator on how the file was split
@@ -200,7 +204,7 @@ XGB_DLL int XGDMatrixCreateFromDense(char const *data, char const *config, DMatr
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGDMatrixCreateFromCSC(char const *indptr, char const *indices, char const *data,
- bst_ulong nrow, char const *c_json_config, DMatrixHandle *out);
+ bst_ulong nrow, char const *config, DMatrixHandle *out);
/*!
* \brief create a matrix content from CSC format
@@ -281,7 +285,7 @@ XGB_DLL int XGDMatrixCreateFromCudaArrayInterface(char const *data, char const *
DMatrixHandle *out);
/**
- * @defgroup Streaming
+ * @defgroup Streaming Streaming
* @ingroup DMatrix
*
* @brief Quantile DMatrix and external memory DMatrix can be created from batches of
@@ -431,7 +435,7 @@ XGB_EXTERN_C typedef void DataIterResetCallback(DataIterHandle handle); // NOLIN
* - Step 0: Define a data iterator with 2 methods `reset`, and `next`.
* - Step 1: Create a DMatrix proxy by \ref XGProxyDMatrixCreate and hold the handle.
* - Step 2: Pass the iterator handle, proxy handle and 2 methods into
- * `XGDMatrixCreateFromCallback`, along with other parameters encoded as a JSON object.
+ * \ref XGDMatrixCreateFromCallback, along with other parameters encoded as a JSON object.
* - Step 3: Call appropriate data setters in `next` functions.
*
* \param iter A handle to external data iterator.
@@ -830,7 +834,7 @@ XGB_DLL int XGDMatrixGetDataAsCSR(DMatrixHandle const handle, char const *config
/** @} */ // End of DMatrix
/**
- * @defgroup Booster
+ * @defgroup Booster Booster
*
* @brief The `Booster` class is the gradient-boosted model for XGBoost.
* @{
@@ -953,7 +957,7 @@ XGB_DLL int XGBoosterEvalOneIter(BoosterHandle handle, int iter, DMatrixHandle d
*/
/**
- * @defgroup Prediction
+ * @defgroup Prediction Prediction
* @ingroup Booster
*
* @brief These functions are used for running prediction and explanation algorithms.
@@ -1155,7 +1159,7 @@ XGB_DLL int XGBoosterPredictFromCudaColumnar(BoosterHandle handle, char const *v
/**
- * @defgroup Serialization
+ * @defgroup Serialization Serialization
* @ingroup Booster
*
* @brief There are multiple ways to serialize a Booster object depending on the use case.
@@ -1490,7 +1494,7 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, const char *config,
/**@}*/ // End of Booster
/**
- * @defgroup Collective
+ * @defgroup Collective Collective
*
* @brief Experimental support for exposing internal communicator in XGBoost.
*
diff --git a/include/xgboost/context.h b/include/xgboost/context.h
index aaa1e3eb8..f1cd391df 100644
--- a/include/xgboost/context.h
+++ b/include/xgboost/context.h
@@ -50,7 +50,19 @@ struct Context : public XGBoostParameter {
bool IsCPU() const { return gpu_id == kCpuId; }
bool IsCUDA() const { return !IsCPU(); }
+
CUDAContext const* CUDACtx() const;
+ // Make a CUDA context based on the current context.
+ Context MakeCUDA(std::int32_t device = 0) const {
+ Context ctx = *this;
+ ctx.gpu_id = device;
+ return ctx;
+ }
+ Context MakeCPU() const {
+ Context ctx = *this;
+ ctx.gpu_id = kCpuId;
+ return ctx;
+ }
// declare parameters
DMLC_DECLARE_PARAMETER(Context) {
diff --git a/include/xgboost/data.h b/include/xgboost/data.h
index 4af306859..6305abff8 100644
--- a/include/xgboost/data.h
+++ b/include/xgboost/data.h
@@ -1,5 +1,5 @@
-/*!
- * Copyright (c) 2015-2022 by XGBoost Contributors
+/**
+ * Copyright 2015-2023 by XGBoost Contributors
* \file data.h
* \brief The input data structure of xgboost.
* \author Tianqi Chen
@@ -196,6 +196,14 @@ class MetaInfo {
*/
bool IsVerticalFederated() const;
+ /*!
+ * \brief A convenient method to check if the MetaInfo should contain labels.
+ *
+ * Normally we assume labels are available everywhere. The only exception is in vertical federated
+ * learning where labels are only available on worker 0.
+ */
+ bool ShouldHaveLabels() const;
+
private:
void SetInfoFromHost(Context const& ctx, StringView key, Json arr);
void SetInfoFromCUDA(Context const& ctx, StringView key, Json arr);
@@ -230,44 +238,72 @@ struct Entry {
}
};
-/*!
- * \brief Parameters for constructing batches.
+/**
+ * \brief Parameters for constructing histogram index batches.
*/
struct BatchParam {
- /*! \brief The GPU device to use. */
- int gpu_id {-1};
- /*! \brief Maximum number of bins per feature for histograms. */
+ /**
+ * \brief Maximum number of bins per feature for histograms.
+ */
bst_bin_t max_bin{0};
- /*! \brief Hessian, used for sketching with future approx implementation. */
+ /**
+ * \brief Hessian, used for sketching with future approx implementation.
+ */
common::Span<float> hess;
- /*! \brief Whether should DMatrix regenerate the batch. Only used for GHistIndex. */
- bool regen {false};
- /*! \brief Parameter used to generate column matrix for hist. */
+ /**
+ * \brief Whether we should force DMatrix to regenerate the batch. Only used for
+ * GHistIndex.
+ */
+ bool regen{false};
+ /**
+ * \brief Forbid regenerating the gradient index. Used for internal validation.
+ */
+ bool forbid_regen{false};
+ /**
+ * \brief Parameter used to generate column matrix for hist.
+ */
double sparse_thresh{std::numeric_limits<double>::quiet_NaN()};
+ /**
+ * \brief Exact or others that don't need histogram.
+ */
BatchParam() = default;
- // GPU Hist
- BatchParam(int32_t device, bst_bin_t max_bin)
- : gpu_id{device}, max_bin{max_bin} {}
- // Hist
+ /**
+ * \brief Used by the hist tree method.
+ */
BatchParam(bst_bin_t max_bin, double sparse_thresh)
: max_bin{max_bin}, sparse_thresh{sparse_thresh} {}
- // Approx
/**
- * \brief Get batch with sketch weighted by hessian. The batch will be regenerated if
- * the span is changed, so caller should keep the span for each iteration.
+ * \brief Used by the approx tree method.
+ *
+ * Get batch with sketch weighted by hessian. The batch will be regenerated if the
+ * span is changed, so caller should keep the span for each iteration.
*/
BatchParam(bst_bin_t max_bin, common::Span<float> hessian, bool regenerate)
: max_bin{max_bin}, hess{hessian}, regen{regenerate} {}
- bool operator!=(BatchParam const& other) const {
- if (hess.empty() && other.hess.empty()) {
- return gpu_id != other.gpu_id || max_bin != other.max_bin;
- }
- return gpu_id != other.gpu_id || max_bin != other.max_bin || hess.data() != other.hess.data();
+ bool ParamNotEqual(BatchParam const& other) const {
+ // Check non-floating parameters.
+ bool cond = max_bin != other.max_bin;
+ // Check sparse thresh.
+ bool l_nan = std::isnan(sparse_thresh);
+ bool r_nan = std::isnan(other.sparse_thresh);
+ bool st_chg = (l_nan != r_nan) || (!l_nan && !r_nan && (sparse_thresh != other.sparse_thresh));
+ cond |= st_chg;
+
+ return cond;
}
- bool operator==(BatchParam const& other) const {
- return !(*this != other);
+ bool Initialized() const { return max_bin != 0; }
+ /**
+ * \brief Make a copy of self for DMatrix to describe how its existing index was generated.
+ */
+ BatchParam MakeCache() const {
+ auto p = *this;
+ // These parameters have nothing to do with how the gradient index was generated in the
+ // first place.
+ p.regen = false;
+ p.forbid_regen = false;
+ return p;
}
};
@@ -427,7 +463,7 @@ class EllpackPage {
* This is used in the in-memory case. The ELLPACK page is constructed from an existing DMatrix
* in CSR format.
*/
- explicit EllpackPage(DMatrix* dmat, const BatchParam& param);
+ explicit EllpackPage(Context const* ctx, DMatrix* dmat, const BatchParam& param);
/*! \brief Destructor. */
~EllpackPage();
@@ -543,7 +579,9 @@ class DMatrix {
template
BatchSet GetBatches();
template
- BatchSet GetBatches(const BatchParam& param);
+ BatchSet GetBatches(Context const* ctx);
+ template
+ BatchSet GetBatches(Context const* ctx, const BatchParam& param);
template
bool PageExists() const;
@@ -558,21 +596,17 @@ class DMatrix {
return Info().num_nonzero_ == Info().num_row_ * Info().num_col_;
}
- /*!
+ /**
* \brief Load DMatrix from URI.
+ *
* \param uri The URI of input.
* \param silent Whether print information during loading.
* \param data_split_mode In distributed mode, split the input according this mode; otherwise,
* it's just an indicator on how the input was split beforehand.
- * \param file_format The format type of the file, used for dmlc::Parser::Create.
- * By default "auto" will be able to load in both local binary file.
- * \param page_size Page size for external memory.
* \return The created DMatrix.
*/
- static DMatrix* Load(const std::string& uri,
- bool silent = true,
- DataSplitMode data_split_mode = DataSplitMode::kRow,
- const std::string& file_format = "auto");
+ static DMatrix* Load(const std::string& uri, bool silent = true,
+ DataSplitMode data_split_mode = DataSplitMode::kRow);
/**
* \brief Creates a new DMatrix from an external data adapter.
@@ -654,18 +688,19 @@ class DMatrix {
protected:
virtual BatchSet GetRowBatches() = 0;
- virtual BatchSet GetColumnBatches() = 0;
- virtual BatchSet GetSortedColumnBatches() = 0;
- virtual BatchSet GetEllpackBatches(const BatchParam& param) = 0;
- virtual BatchSet GetGradientIndex(const BatchParam& param) = 0;
- virtual BatchSet GetExtBatches(BatchParam const& param) = 0;
+ virtual BatchSet GetColumnBatches(Context const* ctx) = 0;
+ virtual BatchSet GetSortedColumnBatches(Context const* ctx) = 0;
+ virtual BatchSet GetEllpackBatches(Context const* ctx, BatchParam const& param) = 0;
+ virtual BatchSet GetGradientIndex(Context const* ctx,
+ BatchParam const& param) = 0;
+ virtual BatchSet GetExtBatches(Context const* ctx, BatchParam const& param) = 0;
virtual bool EllpackExists() const = 0;
virtual bool GHistIndexExists() const = 0;
virtual bool SparsePageExists() const = 0;
};
-template<>
+template <>
inline BatchSet DMatrix::GetBatches() {
return GetRowBatches();
}
@@ -680,34 +715,39 @@ inline bool DMatrix::PageExists() const {
return this->GHistIndexExists();
}
-template<>
+template <>
inline bool DMatrix::PageExists() const {
return this->SparsePageExists();
}
-template<>
-inline BatchSet DMatrix::GetBatches() {
- return GetColumnBatches();
-}
-
-template<>
-inline BatchSet DMatrix::GetBatches() {
- return GetSortedColumnBatches();
-}
-
-template<>
-inline BatchSet DMatrix::GetBatches(const BatchParam& param) {
- return GetEllpackBatches(param);
+template <>
+inline BatchSet DMatrix::GetBatches(Context const*) {
+ return GetRowBatches();
}
template <>
-inline BatchSet DMatrix::GetBatches(const BatchParam& param) {
- return GetGradientIndex(param);
+inline BatchSet DMatrix::GetBatches(Context const* ctx) {
+ return GetColumnBatches(ctx);
}
template <>
-inline BatchSet DMatrix::GetBatches() {
- return GetExtBatches(BatchParam{});
+inline BatchSet DMatrix::GetBatches(Context const* ctx) {
+ return GetSortedColumnBatches(ctx);
+}
+
+template <>
+inline BatchSet DMatrix::GetBatches(Context const* ctx, BatchParam const& param) {
+ return GetEllpackBatches(ctx, param);
+}
+
+template <>
+inline BatchSet DMatrix::GetBatches(Context const* ctx, BatchParam const& param) {
+ return GetGradientIndex(ctx, param);
+}
+
+template <>
+inline BatchSet DMatrix::GetBatches(Context const* ctx, BatchParam const& param) {
+ return GetExtBatches(ctx, param);
}
} // namespace xgboost
diff --git a/include/xgboost/tree_model.h b/include/xgboost/tree_model.h
index 61dd94302..393dda59c 100644
--- a/include/xgboost/tree_model.h
+++ b/include/xgboost/tree_model.h
@@ -567,7 +567,7 @@ class RegTree : public Model {
* \brief drop the trace after fill, must be called after fill.
* \param inst The sparse instance to drop.
*/
- void Drop(const SparsePage::Inst& inst);
+ void Drop();
/*!
* \brief returns the size of the feature vector
* \return the size of the feature vector
@@ -807,13 +807,10 @@ inline void RegTree::FVec::Fill(const SparsePage::Inst& inst) {
has_missing_ = data_.size() != feature_count;
}
-inline void RegTree::FVec::Drop(const SparsePage::Inst& inst) {
- for (auto const& entry : inst) {
- if (entry.index >= data_.size()) {
- continue;
- }
- data_[entry.index].flag = -1;
- }
+inline void RegTree::FVec::Drop() {
+ Entry e{};
+ e.flag = -1;
+ std::fill_n(data_.data(), data_.size(), e);
has_missing_ = true;
}
diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml
index facb955ce..4903b8f38 100644
--- a/jvm-packages/pom.xml
+++ b/jvm-packages/pom.xml
@@ -33,16 +33,16 @@
UTF-8
1.8
1.8
- 1.8.3
- 3.1.1
- 2.12.8
+ 1.17.0
+ 3.4.0
+ 2.12.17
2.12
3.3.5
5
OFF
OFF
- 22.12.0
- 22.12.0
+ 23.04.0
+ 23.04.0
cuda11
@@ -374,7 +374,7 @@
org.apache.maven.plugins
maven-checkstyle-plugin
- 3.2.1
+ 3.2.2
checkstyle.xml
true
@@ -450,7 +450,7 @@
maven-project-info-reports-plugin
- 3.4.2
+ 3.4.3
net.alchim31.maven
@@ -469,7 +469,7 @@
com.esotericsoftware
kryo
- 5.4.0
+ 5.5.0
org.scala-lang
@@ -477,11 +477,6 @@
${scala.version}
provided
-
- org.scala-lang
- scala-reflect
- ${scala.version}
-
org.scala-lang
scala-library
@@ -495,13 +490,13 @@
org.scalatest
scalatest_${scala.binary.version}
- 3.0.8
+ 3.2.15
test
org.scalactic
scalactic_${scala.binary.version}
- 3.0.8
+ 3.2.15
test
diff --git a/jvm-packages/xgboost4j-example/pom.xml b/jvm-packages/xgboost4j-example/pom.xml
index d08e4f409..40c9c72a4 100644
--- a/jvm-packages/xgboost4j-example/pom.xml
+++ b/jvm-packages/xgboost4j-example/pom.xml
@@ -26,7 +26,7 @@
ml.dmlc
xgboost4j-spark_${scala.binary.version}
- 2.0.0-SNAPSHOT
+ ${project.version}
org.apache.spark
@@ -37,12 +37,7 @@
ml.dmlc
xgboost4j-flink_${scala.binary.version}
- 2.0.0-SNAPSHOT
-
-
- org.apache.commons
- commons-lang3
- 3.12.0
+ ${project.version}
diff --git a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/BasicWalkThrough.java b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/BasicWalkThrough.java
index 7e4fe6806..8a74b74da 100644
--- a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/BasicWalkThrough.java
+++ b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/BasicWalkThrough.java
@@ -1,5 +1,5 @@
/*
- Copyright (c) 2014-2021 by Contributors
+ Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -62,8 +62,8 @@ public class BasicWalkThrough {
public static void main(String[] args) throws IOException, XGBoostError {
// load file from text file, also binary buffer generated by xgboost4j
- DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
- DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
+ DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
+ DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");
HashMap<String, Object> params = new HashMap<String, Object>();
params.put("eta", 1.0);
@@ -112,7 +112,8 @@ public class BasicWalkThrough {
System.out.println("start build dmatrix from csr sparse data ...");
//build dmatrix from CSR Sparse Matrix
- DataLoader.CSRSparseData spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train");
+ DataLoader.CSRSparseData spData =
+ DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train?format=libsvm");
DMatrix trainMat2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data,
DMatrix.SparseType.CSR, 127);
diff --git a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/BoostFromPrediction.java b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/BoostFromPrediction.java
index 7eb9e99f0..fe5db0465 100644
--- a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/BoostFromPrediction.java
+++ b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/BoostFromPrediction.java
@@ -32,8 +32,8 @@ public class BoostFromPrediction {
System.out.println("start running example to start from a initial prediction");
// load file from text file, also binary buffer generated by xgboost4j
- DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
- DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
+ DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
+ DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");
//specify parameters
HashMap<String, Object> params = new HashMap<String, Object>();
diff --git a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/CrossValidation.java b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/CrossValidation.java
index dbe5f368c..3577be226 100644
--- a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/CrossValidation.java
+++ b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/CrossValidation.java
@@ -30,7 +30,7 @@ import ml.dmlc.xgboost4j.java.XGBoostError;
public class CrossValidation {
public static void main(String[] args) throws IOException, XGBoostError {
//load train mat
- DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
+ DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
//set params
HashMap<String, Object> params = new HashMap<String, Object>();
diff --git a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/CustomObjective.java b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/CustomObjective.java
index 6d529974c..c631dc01a 100644
--- a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/CustomObjective.java
+++ b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/CustomObjective.java
@@ -139,9 +139,9 @@ public class CustomObjective {
public static void main(String[] args) throws XGBoostError {
//load train mat (svmlight format)
- DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
+ DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
//load valid mat (svmlight format)
- DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
+ DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");
HashMap<String, Object> params = new HashMap<String, Object>();
params.put("eta", 1.0);
diff --git a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/EarlyStopping.java b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/EarlyStopping.java
index 61e752f85..9e52c12fd 100644
--- a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/EarlyStopping.java
+++ b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/EarlyStopping.java
@@ -29,9 +29,9 @@ import ml.dmlc.xgboost4j.java.example.util.DataLoader;
public class EarlyStopping {
public static void main(String[] args) throws IOException, XGBoostError {
DataLoader.CSRSparseData trainCSR =
- DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train");
+ DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train?format=libsvm");
DataLoader.CSRSparseData testCSR =
- DataLoader.loadSVMFile("../../demo/data/agaricus.txt.test");
+ DataLoader.loadSVMFile("../../demo/data/agaricus.txt.test?format=libsvm");
Map<String, Object> paramMap = new HashMap<String, Object>() {
{
diff --git a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/ExternalMemory.java b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/ExternalMemory.java
index 349098ae1..70b2b85b5 100644
--- a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/ExternalMemory.java
+++ b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/ExternalMemory.java
@@ -32,8 +32,8 @@ public class ExternalMemory {
//this is the only difference, add a # followed by a cache prefix name
//several cache file with the prefix will be generated
//currently only support convert from libsvm file
- DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train#dtrain.cache");
- DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test#dtest.cache");
+ DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm#dtrain.cache");
+ DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm#dtest.cache");
//specify parameters
HashMap<String, Object> params = new HashMap<String, Object>();
diff --git a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/GeneralizedLinearModel.java b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/GeneralizedLinearModel.java
index 422cdea6a..09cc91c7f 100644
--- a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/GeneralizedLinearModel.java
+++ b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/GeneralizedLinearModel.java
@@ -32,8 +32,8 @@ import ml.dmlc.xgboost4j.java.example.util.CustomEval;
public class GeneralizedLinearModel {
public static void main(String[] args) throws XGBoostError {
// load file from text file, also binary buffer generated by xgboost4j
- DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
- DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
+ DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
+ DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");
//specify parameters
//change booster to gblinear, so that we are fitting a linear model
diff --git a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/PredictFirstNtree.java b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/PredictFirstNtree.java
index c98534a93..9038502bd 100644
--- a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/PredictFirstNtree.java
+++ b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/PredictFirstNtree.java
@@ -31,8 +31,8 @@ import ml.dmlc.xgboost4j.java.example.util.CustomEval;
public class PredictFirstNtree {
public static void main(String[] args) throws XGBoostError {
// load file from text file, also binary buffer generated by xgboost4j
- DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
- DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
+ DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
+ DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");
//specify parameters
HashMap<String, Object> params = new HashMap<String, Object>();
diff --git a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/PredictLeafIndices.java b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/PredictLeafIndices.java
index 0fcfb39de..7b1dfcb28 100644
--- a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/PredictLeafIndices.java
+++ b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/PredictLeafIndices.java
@@ -31,8 +31,8 @@ import ml.dmlc.xgboost4j.java.XGBoostError;
public class PredictLeafIndices {
public static void main(String[] args) throws XGBoostError {
// load file from text file, also binary buffer generated by xgboost4j
- DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
- DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
+ DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
+ DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");
//specify parameters
HashMap<String, Object> params = new HashMap<String, Object>();
diff --git a/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/flink/DistTrainWithFlinkExample.java b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/flink/DistTrainWithFlinkExample.java
new file mode 100644
index 000000000..94e5cdab5
--- /dev/null
+++ b/jvm-packages/xgboost4j-example/src/main/java/ml/dmlc/xgboost4j/java/example/flink/DistTrainWithFlinkExample.java
@@ -0,0 +1,107 @@
+/*
+ Copyright (c) 2014-2021 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+package ml.dmlc.xgboost4j.java.example.flink;
+
+import java.nio.file.Path;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+
+import org.apache.flink.api.common.typeinfo.TypeHint;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.DataSet;
+import org.apache.flink.api.java.ExecutionEnvironment;
+import org.apache.flink.api.java.operators.MapOperator;
+import org.apache.flink.api.java.tuple.Tuple13;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.utils.DataSetUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+
+import ml.dmlc.xgboost4j.java.flink.XGBoost;
+import ml.dmlc.xgboost4j.java.flink.XGBoostModel;
+
+
+public class DistTrainWithFlinkExample {
+
+ static Tuple2> runPrediction(
+ ExecutionEnvironment env,
+ java.nio.file.Path trainPath,
+ int percentage) throws Exception {
+ // reading data
+ final DataSet>> data =
+ DataSetUtils.zipWithIndex(parseCsv(env, trainPath));
+ final long size = data.count();
+ final long trainCount = Math.round(size * 0.01 * percentage);
+ final DataSet> trainData =
+ data
+ .filter(item -> item.f0 < trainCount)
+ .map(t -> t.f1)
+ .returns(TypeInformation.of(new TypeHint>(){}));
+ final DataSet testData =
+ data
+ .filter(tuple -> tuple.f0 >= trainCount)
+ .map(t -> t.f1.f0)
+ .returns(TypeInformation.of(new TypeHint(){}));
+
+ // define parameters
+ HashMap paramMap = new HashMap(3);
+ paramMap.put("eta", 0.1);
+ paramMap.put("max_depth", 2);
+ paramMap.put("objective", "binary:logistic");
+
+ // number of iterations
+ final int round = 2;
+ // train the model
+ XGBoostModel model = XGBoost.train(trainData, paramMap, round);
+ DataSet predTest = model.predict(testData);
+ return new Tuple2>(model, predTest);
+ }
+
+ private static MapOperator,
+ Tuple2> parseCsv(ExecutionEnvironment env, Path trainPath) {
+ return env.readCsvFile(trainPath.toString())
+ .ignoreFirstLine()
+ .types(Double.class, String.class, Double.class, Double.class, Double.class,
+ Integer.class, Integer.class, Integer.class, Integer.class, Integer.class,
+ Integer.class, Integer.class, Integer.class)
+ .map(DistTrainWithFlinkExample::mapFunction);
+ }
+
+ private static Tuple2 mapFunction(Tuple13 tuple) {
+ final DenseVector dense = Vectors.dense(tuple.f2, tuple.f3, tuple.f4, tuple.f5, tuple.f6,
+ tuple.f7, tuple.f8, tuple.f9, tuple.f10, tuple.f11, tuple.f12);
+ if (tuple.f1.contains("inf")) {
+ return new Tuple2(dense, 1.0);
+ } else {
+ return new Tuple2(dense, 0.0);
+ }
+ }
+
+ public static void main(String[] args) throws Exception {
+ final java.nio.file.Path parentPath = java.nio.file.Paths.get(Arrays.stream(args)
+ .findFirst().orElse("."));
+ final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+ Tuple2> tuple2 = runPrediction(
+ env, parentPath.resolve("veterans_lung_cancer.csv"), 70
+ );
+ List list = tuple2.f1.collect();
+ System.out.println(list.size());
+ }
+}
diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/BasicWalkThrough.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/BasicWalkThrough.scala
index e8481b047..1893288b4 100644
--- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/BasicWalkThrough.scala
+++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/BasicWalkThrough.scala
@@ -1,5 +1,5 @@
/*
- Copyright (c) 2014 by Contributors
+ Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -36,8 +36,8 @@ object BasicWalkThrough {
}
def main(args: Array[String]): Unit = {
- val trainMax = new DMatrix("../../demo/data/agaricus.txt.train")
- val testMax = new DMatrix("../../demo/data/agaricus.txt.test")
+ val trainMax = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
+ val testMax = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val params = new mutable.HashMap[String, Any]()
params += "eta" -> 1.0
@@ -76,7 +76,7 @@ object BasicWalkThrough {
// build dmatrix from CSR Sparse Matrix
println("start build dmatrix from csr sparse data ...")
- val spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train")
+ val spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train?format=libsvm")
val trainMax2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data,
JDMatrix.SparseType.CSR)
trainMax2.setLabel(spData.labels)
diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/BoostFromPrediction.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/BoostFromPrediction.scala
index b894532fa..09b72fc50 100644
--- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/BoostFromPrediction.scala
+++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/BoostFromPrediction.scala
@@ -24,8 +24,8 @@ object BoostFromPrediction {
def main(args: Array[String]): Unit = {
println("start running example to start from a initial prediction")
- val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
- val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
+ val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
+ val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val params = new mutable.HashMap[String, Any]()
params += "eta" -> 1.0
diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CrossValidation.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CrossValidation.scala
index 62f8b461a..6083209ec 100644
--- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CrossValidation.scala
+++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CrossValidation.scala
@@ -21,7 +21,7 @@ import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
object CrossValidation {
def main(args: Array[String]): Unit = {
- val trainMat: DMatrix = new DMatrix("../../demo/data/agaricus.txt.train")
+ val trainMat: DMatrix = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
// set params
val params = new mutable.HashMap[String, Any]
diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CustomObjective.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CustomObjective.scala
index fe88423e7..8cc49c90d 100644
--- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CustomObjective.scala
+++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/CustomObjective.scala
@@ -138,8 +138,8 @@ object CustomObjective {
}
def main(args: Array[String]): Unit = {
- val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
- val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
+ val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
+ val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val params = new mutable.HashMap[String, Any]()
params += "eta" -> 1.0
params += "max_depth" -> 2
diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/ExternalMemory.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/ExternalMemory.scala
index 447c98295..c7f3d8bbb 100644
--- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/ExternalMemory.scala
+++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/ExternalMemory.scala
@@ -25,8 +25,8 @@ object ExternalMemory {
// this is the only difference, add a # followed by a cache prefix name
// several cache file with the prefix will be generated
// currently only support convert from libsvm file
- val trainMat = new DMatrix("../../demo/data/agaricus.txt.train#dtrain.cache")
- val testMat = new DMatrix("../../demo/data/agaricus.txt.test#dtest.cache")
+ val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm#dtrain.cache")
+ val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm#dtest.cache")
val params = new mutable.HashMap[String, Any]()
params += "eta" -> 1.0
diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/GeneralizedLinearModel.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/GeneralizedLinearModel.scala
index 27ed98eca..e370010b6 100644
--- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/GeneralizedLinearModel.scala
+++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/GeneralizedLinearModel.scala
@@ -27,8 +27,8 @@ import ml.dmlc.xgboost4j.scala.example.util.CustomEval
*/
object GeneralizedLinearModel {
def main(args: Array[String]): Unit = {
- val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
- val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
+ val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
+ val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
// specify parameters
// change booster to gblinear, so that we are fitting a linear model
diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/PredictFirstNTree.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/PredictFirstNTree.scala
index 5395e3638..40a5ffc44 100644
--- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/PredictFirstNTree.scala
+++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/PredictFirstNTree.scala
@@ -23,8 +23,8 @@ import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
object PredictFirstNTree {
def main(args: Array[String]): Unit = {
- val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
- val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
+ val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
+ val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val params = new mutable.HashMap[String, Any]()
params += "eta" -> 1.0
diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/PredictLeafIndices.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/PredictLeafIndices.scala
index f40a8aac6..7ae2e6520 100644
--- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/PredictLeafIndices.scala
+++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/PredictLeafIndices.scala
@@ -25,8 +25,8 @@ import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
object PredictLeafIndices {
def main(args: Array[String]): Unit = {
- val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
- val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
+ val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
+ val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val params = new mutable.HashMap[String, Any]()
params += "eta" -> 1.0
diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/flink/DistTrainWithFlink.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/flink/DistTrainWithFlink.scala
index 74b24ac35..cb859f62d 100644
--- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/flink/DistTrainWithFlink.scala
+++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/flink/DistTrainWithFlink.scala
@@ -1,5 +1,5 @@
/*
- Copyright (c) 2014 by Contributors
+ Copyright (c) 2014 - 2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -15,27 +15,84 @@
*/
package ml.dmlc.xgboost4j.scala.example.flink
-import ml.dmlc.xgboost4j.scala.flink.XGBoost
-import org.apache.flink.api.scala.{ExecutionEnvironment, _}
-import org.apache.flink.ml.MLUtils
+import java.lang.{Double => JDouble, Long => JLong}
+import java.nio.file.{Path, Paths}
+import org.apache.flink.api.java.tuple.{Tuple13, Tuple2}
+import org.apache.flink.api.java.{DataSet, ExecutionEnvironment}
+import org.apache.flink.ml.linalg.{Vector, Vectors}
+import ml.dmlc.xgboost4j.java.flink.{XGBoost, XGBoostModel}
+import org.apache.flink.api.common.typeinfo.{TypeHint, TypeInformation}
+import org.apache.flink.api.java.utils.DataSetUtils
+
object DistTrainWithFlink {
- def main(args: Array[String]) {
- val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
- // read trainining data
- val trainData =
- MLUtils.readLibSVM(env, "/path/to/data/agaricus.txt.train")
- val testData = MLUtils.readLibSVM(env, "/path/to/data/agaricus.txt.test")
- // define parameters
- val paramMap = List(
- "eta" -> 0.1,
- "max_depth" -> 2,
- "objective" -> "binary:logistic").toMap
+ import scala.jdk.CollectionConverters._
+ private val rowTypeHint = TypeInformation.of(new TypeHint[Tuple2[Vector, JDouble]]{})
+ private val testDataTypeHint = TypeInformation.of(classOf[Vector])
+
+ private[flink] def parseCsv(trainPath: Path)(implicit env: ExecutionEnvironment):
+ DataSet[Tuple2[JLong, Tuple2[Vector, JDouble]]] = {
+ DataSetUtils.zipWithIndex(
+ env
+ .readCsvFile(trainPath.toString)
+ .ignoreFirstLine
+ .types(
+ classOf[Double], classOf[String], classOf[Double], classOf[Double], classOf[Double],
+ classOf[Integer], classOf[Integer], classOf[Integer], classOf[Integer],
+ classOf[Integer], classOf[Integer], classOf[Integer], classOf[Integer]
+ )
+ .map((row: Tuple13[Double, String, Double, Double, Double,
+ Integer, Integer, Integer, Integer, Integer, Integer, Integer, Integer]) => {
+ val dense = Vectors.dense(row.f2, row.f3, row.f4,
+ row.f5.toDouble, row.f6.toDouble, row.f7.toDouble, row.f8.toDouble,
+ row.f9.toDouble, row.f10.toDouble, row.f11.toDouble, row.f12.toDouble)
+ val label = if (row.f1.contains("inf")) {
+ JDouble.valueOf(1.0)
+ } else {
+ JDouble.valueOf(0.0)
+ }
+ new Tuple2[Vector, JDouble](dense, label)
+ })
+ .returns(rowTypeHint)
+ )
+ }
+
+ private[flink] def runPrediction(trainPath: Path, percentage: Int)
+ (implicit env: ExecutionEnvironment):
+ (XGBoostModel, DataSet[Array[Float]]) = {
+ // read training data
+ val data: DataSet[Tuple2[JLong, Tuple2[Vector, JDouble]]] = parseCsv(trainPath)
+ val trainSize = Math.round(0.01 * percentage * data.count())
+ val trainData: DataSet[Tuple2[Vector, JDouble]] =
+ data.filter(d => d.f0 < trainSize).map(_.f1).returns(rowTypeHint)
+
+
+ val testData: DataSet[Vector] =
+ data
+ .filter(d => d.f0 >= trainSize)
+ .map(_.f1.f0)
+ .returns(testDataTypeHint)
+
+ val paramMap = mapAsJavaMap(Map(
+ ("eta", "0.1".asInstanceOf[AnyRef]),
+ ("max_depth", "2"),
+ ("objective", "binary:logistic"),
+ ("verbosity", "1")
+ ))
+
// number of iterations
val round = 2
// train the model
val model = XGBoost.train(trainData, paramMap, round)
- val predTest = model.predict(testData.map{x => x.vector})
- model.saveModelAsHadoopFile("file:///path/to/xgboost.model")
+ val result = model.predict(testData).map(prediction => prediction.map(Float.unbox))
+ (model, result)
+ }
+
+ def main(args: Array[String]): Unit = {
+ implicit val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+ val parentPath = Paths.get(args.headOption.getOrElse("."))
+ val (_, predTest) = runPrediction(parentPath.resolve("veterans_lung_cancer.csv"), 70)
+ val list = predTest.collect().asScala
+ println(list.length)
}
}
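A hedged sketch of driving the reworked example the same way the new smoke tests do. runPrediction is private to the ml.dmlc.xgboost4j.scala.example.flink package, so the caller below sits in that package; the object name and the model output path are illustrative only:

package ml.dmlc.xgboost4j.scala.example.flink

import java.nio.file.Paths
import org.apache.flink.api.java.ExecutionEnvironment

object RunDistTrainWithFlinkSketch {
  def main(args: Array[String]): Unit = {
    // Single-slot local environment, matching the smoke tests below.
    implicit val env: ExecutionEnvironment = ExecutionEnvironment.createLocalEnvironment(1)
    val csv = Paths.get("../../demo/data/veterans_lung_cancer.csv")
    val (model, predictions) = DistTrainWithFlink.runPrediction(csv, percentage = 70)
    // Illustrative output location; any Hadoop-visible URI should work.
    model.saveModelAsHadoopFile("file:///tmp/xgboost-flink.model", "json")
    println(s"predicted ${predictions.collect().size()} rows")
  }
}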
diff --git a/jvm-packages/xgboost4j-example/src/test/scala/ml/dmlc/xgboost4j/java/example/flink/DistTrainWithFlinkExampleTest.scala b/jvm-packages/xgboost4j-example/src/test/scala/ml/dmlc/xgboost4j/java/example/flink/DistTrainWithFlinkExampleTest.scala
new file mode 100644
index 000000000..b9929639f
--- /dev/null
+++ b/jvm-packages/xgboost4j-example/src/test/scala/ml/dmlc/xgboost4j/java/example/flink/DistTrainWithFlinkExampleTest.scala
@@ -0,0 +1,36 @@
+/*
+ Copyright (c) 2014-2023 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+package ml.dmlc.xgboost4j.java.example.flink
+
+import org.apache.flink.api.java.ExecutionEnvironment
+import org.scalatest.Inspectors._
+import org.scalatest.funsuite.AnyFunSuite
+import org.scalatest.matchers.should.Matchers._
+
+import java.nio.file.Paths
+
+class DistTrainWithFlinkExampleTest extends AnyFunSuite {
+ private val parentPath = Paths.get("../../").resolve("demo").resolve("data")
+ private val data = parentPath.resolve("veterans_lung_cancer.csv")
+
+ test("Smoke test for scala flink example") {
+ val env = ExecutionEnvironment.createLocalEnvironment(1)
+ val tuple2 = DistTrainWithFlinkExample.runPrediction(env, data, 70)
+ val results = tuple2.f1.collect()
+ results should have size 41
+ forEvery(results)(item => item should have size 1)
+ }
+}
diff --git a/jvm-packages/xgboost4j-example/src/test/scala/ml/dmlc/xgboost4j/scala/example/flink/DistTrainWithFlinkSuite.scala b/jvm-packages/xgboost4j-example/src/test/scala/ml/dmlc/xgboost4j/scala/example/flink/DistTrainWithFlinkSuite.scala
new file mode 100644
index 000000000..d9e98d81c
--- /dev/null
+++ b/jvm-packages/xgboost4j-example/src/test/scala/ml/dmlc/xgboost4j/scala/example/flink/DistTrainWithFlinkSuite.scala
@@ -0,0 +1,37 @@
+/*
+ Copyright (c) 2014-2023 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+package ml.dmlc.xgboost4j.scala.example.flink
+
+import org.apache.flink.api.java.ExecutionEnvironment
+import org.scalatest.Inspectors._
+import org.scalatest.funsuite.AnyFunSuite
+import org.scalatest.matchers.should.Matchers._
+
+import java.nio.file.Paths
+import scala.jdk.CollectionConverters._
+
+class DistTrainWithFlinkSuite extends AnyFunSuite {
+ private val parentPath = Paths.get("../../").resolve("demo").resolve("data")
+ private val data = parentPath.resolve("veterans_lung_cancer.csv")
+
+ test("Smoke test for scala flink example") {
+ implicit val env: ExecutionEnvironment = ExecutionEnvironment.createLocalEnvironment(1)
+ val (_, result) = DistTrainWithFlink.runPrediction(data, 70)
+ val results = result.collect().asScala
+ results should have size 41
+ forEvery(results)(item => item should have size 1)
+ }
+}
diff --git a/jvm-packages/xgboost4j-flink/pom.xml b/jvm-packages/xgboost4j-flink/pom.xml
index b8b757eae..a9a80e29a 100644
--- a/jvm-packages/xgboost4j-flink/pom.xml
+++ b/jvm-packages/xgboost4j-flink/pom.xml
@@ -8,8 +8,11 @@
<artifactId>xgboost-jvm_2.12</artifactId>
<version>2.0.0-SNAPSHOT</version>
</parent>
- <artifactId>xgboost4j-flink_2.12</artifactId>
+ <artifactId>xgboost4j-flink_${scala.binary.version}</artifactId>
<version>2.0.0-SNAPSHOT</version>
+ <properties>
+ <flink-ml.version>2.2.0</flink-ml.version>
+ </properties>
@@ -26,32 +29,22 @@
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j_${scala.binary.version}</artifactId>
- <version>2.0.0-SNAPSHOT</version>
- </dependency>
- <dependency>
- <groupId>org.apache.commons</groupId>
- <artifactId>commons-lang3</artifactId>
- <version>3.12.0</version>
+ <version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
- <artifactId>flink-scala_${scala.binary.version}</artifactId>
+ <artifactId>flink-clients</artifactId>
<version>${flink.version}</version>
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
- <artifactId>flink-clients_${scala.binary.version}</artifactId>
- <version>${flink.version}</version>
- </dependency>
- <dependency>
- <groupId>org.apache.flink</groupId>
- <artifactId>flink-ml_${scala.binary.version}</artifactId>
- <version>${flink.version}</version>
+ <artifactId>flink-ml-servable-core</artifactId>
+ <version>${flink-ml.version}</version>
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-common</artifactId>
- <version>3.3.5</version>
+ <version>${hadoop.version}</version>
diff --git a/jvm-packages/xgboost4j-flink/src/main/java/ml/dmlc/xgboost4j/java/flink/XGBoost.java b/jvm-packages/xgboost4j-flink/src/main/java/ml/dmlc/xgboost4j/java/flink/XGBoost.java
new file mode 100644
index 000000000..7a5e3ac68
--- /dev/null
+++ b/jvm-packages/xgboost4j-flink/src/main/java/ml/dmlc/xgboost4j/java/flink/XGBoost.java
@@ -0,0 +1,187 @@
+/*
+ Copyright (c) 2014-2023 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.java.flink;
+
+
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.Optional;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+import java.util.stream.StreamSupport;
+
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.java.DataSet;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.util.Collector;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import ml.dmlc.xgboost4j.LabeledPoint;
+import ml.dmlc.xgboost4j.java.Booster;
+import ml.dmlc.xgboost4j.java.Communicator;
+import ml.dmlc.xgboost4j.java.DMatrix;
+import ml.dmlc.xgboost4j.java.RabitTracker;
+import ml.dmlc.xgboost4j.java.XGBoostError;
+
+
+public class XGBoost {
+ private static final Logger logger = LoggerFactory.getLogger(XGBoost.class);
+
+ private static class MapFunction
+ extends RichMapPartitionFunction<Tuple2<Vector, Double>, XGBoostModel> {
+
+ private final Map<String, Object> params;
+ private final int round;
+ private final Map<String, String> workerEnvs;
+
+ public MapFunction(Map<String, Object> params, int round, Map<String, String> workerEnvs) {
+ this.params = params;
+ this.round = round;
+ this.workerEnvs = workerEnvs;
+ }
+
+ public void mapPartition(java.lang.Iterable<Tuple2<Vector, Double>> it,
+ Collector<XGBoostModel> collector) throws XGBoostError {
+ workerEnvs.put(
+ "DMLC_TASK_ID",
+ String.valueOf(this.getRuntimeContext().getIndexOfThisSubtask())
+ );
+
+ if (logger.isInfoEnabled()) {
+ logger.info("start with env: {}", workerEnvs.entrySet().stream()
+ .map(e -> String.format("\"%s\": \"%s\"", e.getKey(), e.getValue()))
+ .collect(Collectors.joining(", "))
+ );
+ }
+
+ final Iterator<LabeledPoint> dataIter =
+ StreamSupport
+ .stream(it.spliterator(), false)
+ .map(VectorToPointMapper.INSTANCE)
+ .iterator();
+
+ if (dataIter.hasNext()) {
+ final DMatrix trainMat = new DMatrix(dataIter, null);
+ int numEarlyStoppingRounds =
+ Optional.ofNullable(params.get("numEarlyStoppingRounds"))
+ .map(x -> Integer.parseInt(x.toString()))
+ .orElse(0);
+
+ final Booster booster = trainBooster(trainMat, numEarlyStoppingRounds);
+ collector.collect(new XGBoostModel(booster));
+ } else {
+ logger.warn("Nothing to train with.");
+ }
+ }
+
+ private Booster trainBooster(DMatrix trainMat,
+ int numEarlyStoppingRounds) throws XGBoostError {
+ Booster booster;
+ final Map<String, DMatrix> watches =
+ new HashMap<String, DMatrix>() {{ put("train", trainMat); }};
+ try {
+ Communicator.init(workerEnvs);
+ booster = ml.dmlc.xgboost4j.java.XGBoost
+ .train(
+ trainMat,
+ params,
+ round,
+ watches,
+ null,
+ null,
+ null,
+ numEarlyStoppingRounds);
+ } catch (XGBoostError xgbException) {
+ final String identifier = String.valueOf(this.getRuntimeContext().getIndexOfThisSubtask());
+ logger.warn(
+ String.format("XGBooster worker %s has failed due to", identifier),
+ xgbException
+ );
+ throw xgbException;
+ } finally {
+ Communicator.shutdown();
+ }
+ return booster;
+ }
+
+ private static class VectorToPointMapper
+ implements Function<Tuple2<Vector, Double>, LabeledPoint> {
+ public static VectorToPointMapper INSTANCE = new VectorToPointMapper();
+ @Override
+ public LabeledPoint apply(Tuple2<Vector, Double> tuple) {
+ final SparseVector vector = tuple.f0.toSparse();
+ final double[] values = vector.values;
+ final int size = values.length;
+ final float[] array = new float[size];
+ for (int i = 0; i < size; i++) {
+ array[i] = (float) values[i];
+ }
+ return new LabeledPoint(
+ tuple.f1.floatValue(),
+ vector.size(),
+ vector.indices,
+ array);
+ }
+ }
+ }
+
+ /**
+ * Load XGBoost model from path, using Hadoop Filesystem API.
+ *
+ * @param modelPath The path that is accessible by hadoop filesystem API.
+ * @return The loaded model
+ */
+ public static XGBoostModel loadModelFromHadoopFile(final String modelPath) throws Exception {
+ final FileSystem fileSystem = FileSystem.get(new Configuration());
+ final Path f = new Path(modelPath);
+
+ try (FSDataInputStream opened = fileSystem.open(f)) {
+ return new XGBoostModel(ml.dmlc.xgboost4j.java.XGBoost.loadModel(opened));
+ }
+ }
+
+ /**
+ * Train an XGBoost model with Flink.
+ *
+ * @param dtrain The training data.
+ * @param params XGBoost parameters.
+ * @param numBoostRound Number of rounds to train.
+ */
+ public static XGBoostModel train(DataSet<Tuple2<Vector, Double>> dtrain,
+ Map<String, Object> params,
+ int numBoostRound) throws Exception {
+ final RabitTracker tracker =
+ new RabitTracker(dtrain.getExecutionEnvironment().getParallelism());
+ if (tracker.start(0L)) {
+ return dtrain
+ .mapPartition(new MapFunction(params, numBoostRound, tracker.getWorkerEnvs()))
+ .reduce((x, y) -> x)
+ .collect()
+ .get(0);
+ } else {
+ throw new Error("Tracker cannot be started");
+ }
+ }
+}
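For orientation, an illustrative Scala sketch of calling this new Java entry point with a tiny in-memory dataset; the object name, feature values, and parameters are made up, and larger jobs may need explicit TypeInformation for the Tuple2 elements:

import java.lang.{Double => JDouble}
import org.apache.flink.api.java.{DataSet, ExecutionEnvironment}
import org.apache.flink.api.java.tuple.Tuple2
import org.apache.flink.ml.linalg.{Vector, Vectors}
import ml.dmlc.xgboost4j.java.flink.XGBoost
import scala.jdk.CollectionConverters._

object FlinkTrainSketch {
  def main(args: Array[String]): Unit = {
    val env = ExecutionEnvironment.getExecutionEnvironment
    // Keep one slot so the Rabit tracker expects a single worker.
    env.setParallelism(1)
    // Each element is (features, label), the shape XGBoost.train expects.
    val rows: DataSet[Tuple2[Vector, JDouble]] = env.fromElements(
      new Tuple2[Vector, JDouble](Vectors.dense(1.0, 0.0), JDouble.valueOf(1.0)),
      new Tuple2[Vector, JDouble](Vectors.dense(0.0, 1.0), JDouble.valueOf(0.0))
    )
    val params = Map[String, AnyRef]("eta" -> "0.3", "objective" -> "binary:logistic").asJava
    val model = XGBoost.train(rows, params, 2)
    println(model.toByteArray("json").length)
  }
}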
diff --git a/jvm-packages/xgboost4j-flink/src/main/java/ml/dmlc/xgboost4j/java/flink/XGBoostModel.java b/jvm-packages/xgboost4j-flink/src/main/java/ml/dmlc/xgboost4j/java/flink/XGBoostModel.java
new file mode 100644
index 000000000..03de50482
--- /dev/null
+++ b/jvm-packages/xgboost4j-flink/src/main/java/ml/dmlc/xgboost4j/java/flink/XGBoostModel.java
@@ -0,0 +1,136 @@
+/*
+ Copyright (c) 2014-2023 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.java.flink;
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.stream.StreamSupport;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.java.DataSet;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.util.Collector;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+
+import ml.dmlc.xgboost4j.LabeledPoint;
+import ml.dmlc.xgboost4j.java.Booster;
+import ml.dmlc.xgboost4j.java.DMatrix;
+import ml.dmlc.xgboost4j.java.XGBoostError;
+
+
+public class XGBoostModel implements Serializable {
+ private static final org.slf4j.Logger logger =
+ org.slf4j.LoggerFactory.getLogger(XGBoostModel.class);
+
+ private final Booster booster;
+ private final PredictorFunction predictorFunction;
+
+
+ public XGBoostModel(Booster booster) {
+ this.booster = booster;
+ this.predictorFunction = new PredictorFunction(booster);
+ }
+
+ /**
+ * Save the model as a Hadoop filesystem file.
+ *
+ * @param modelPath The model path as in Hadoop path.
+ */
+ public void saveModelAsHadoopFile(String modelPath) throws IOException, XGBoostError {
+ booster.saveModel(FileSystem.get(new Configuration()).create(new Path(modelPath)));
+ }
+
+ public byte[] toByteArray(String format) throws XGBoostError {
+ return booster.toByteArray(format);
+ }
+
+ /**
+ * Save the model as a Hadoop filesystem file.
+ *
+ * @param modelPath The model path as in Hadoop path.
+ * @param format The model format (ubj, json, deprecated)
+ * @throws XGBoostError internal error
+ * @throws IOException save error
+ */
+ public void saveModelAsHadoopFile(String modelPath, String format)
+ throws IOException, XGBoostError {
+ booster.saveModel(FileSystem.get(new Configuration()).create(new Path(modelPath)), format);
+ }
+
+ /**
+ * predict with the given DMatrix
+ *
+ * @param testSet the local test set represented as DMatrix
+ * @return prediction result
+ */
+ public float[][] predict(DMatrix testSet) throws XGBoostError {
+ return booster.predict(testSet, true, 0);
+ }
+
+ /**
+ * Predict given vector dataset.
+ *
+ * @param data The dataset to be predicted.
+ * @return The prediction result.
+ */
+ public DataSet<Float[]> predict(DataSet<Vector> data) {
+ return data.mapPartition(predictorFunction);
+ }
+
+
+ private static class PredictorFunction implements MapPartitionFunction<Vector, Float[]> {
+
+ private final Booster booster;
+
+ public PredictorFunction(Booster booster) {
+ this.booster = booster;
+ }
+
+ @Override
+ public void mapPartition(Iterable<Vector> it, Collector<Float[]> out) throws Exception {
+ final Iterator<LabeledPoint> dataIter =
+ StreamSupport.stream(it.spliterator(), false)
+ .map(Vector::toSparse)
+ .map(PredictorFunction::fromVector)
+ .iterator();
+
+ if (dataIter.hasNext()) {
+ final DMatrix data = new DMatrix(dataIter, null);
+ float[][] predictions = booster.predict(data, true, 2);
+ Arrays.stream(predictions).map(ArrayUtils::toObject).forEach(out::collect);
+ } else {
+ logger.debug("Empty partition");
+ }
+ }
+
+ private static LabeledPoint fromVector(SparseVector vector) {
+ final int[] index = vector.indices;
+ final double[] value = vector.values;
+ int size = value.length;
+ final float[] values = new float[size];
+ for (int i = 0; i < size; i++) {
+ values[i] = (float) value[i];
+ }
+ return new LabeledPoint(0.0f, vector.size(), index, values);
+ }
+ }
+}
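A short Scala sketch of the scoring path this class exposes, assuming a model was previously saved at the illustrative location below:

import org.apache.flink.api.java.{DataSet, ExecutionEnvironment}
import org.apache.flink.ml.linalg.{Vector, Vectors}
import ml.dmlc.xgboost4j.java.flink.XGBoost
import scala.jdk.CollectionConverters._

object FlinkPredictSketch {
  def main(args: Array[String]): Unit = {
    val env = ExecutionEnvironment.getExecutionEnvironment
    // Illustrative model location; loading goes through the Hadoop filesystem API.
    val model = XGBoost.loadModelFromHadoopFile("file:///tmp/xgboost-flink.model")
    val vectors: DataSet[Vector] = env.fromElements[Vector](Vectors.dense(1.0, 0.0), Vectors.dense(0.0, 1.0))
    // predict() emits one Float array per input row (margin output, as in PredictorFunction above).
    model.predict(vectors).collect().asScala.foreach(row => println(row.mkString(",")))
  }
}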
diff --git a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala
deleted file mode 100644
index 6878f1865..000000000
--- a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala
+++ /dev/null
@@ -1,99 +0,0 @@
-/*
- Copyright (c) 2014 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package ml.dmlc.xgboost4j.scala.flink
-
-import scala.collection.JavaConverters.asScalaIteratorConverter
-
-import ml.dmlc.xgboost4j.LabeledPoint
-import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker}
-import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => XGBoostScala}
-
-import org.apache.commons.logging.LogFactory
-import org.apache.flink.api.common.functions.RichMapPartitionFunction
-import org.apache.flink.api.scala.{DataSet, _}
-import org.apache.flink.ml.common.LabeledVector
-import org.apache.flink.util.Collector
-import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileSystem, Path}
-
-object XGBoost {
- /**
- * Helper map function to start the job.
- *
- * @param workerEnvs
- */
- private class MapFunction(paramMap: Map[String, Any],
- round: Int,
- workerEnvs: java.util.Map[String, String])
- extends RichMapPartitionFunction[LabeledVector, XGBoostModel] {
- val logger = LogFactory.getLog(this.getClass)
-
- def mapPartition(it: java.lang.Iterable[LabeledVector],
- collector: Collector[XGBoostModel]): Unit = {
- workerEnvs.put("DMLC_TASK_ID", String.valueOf(this.getRuntimeContext.getIndexOfThisSubtask))
- logger.info("start with env" + workerEnvs.toString)
- Communicator.init(workerEnvs)
- val mapper = (x: LabeledVector) => {
- val (index, value) = x.vector.toSeq.unzip
- LabeledPoint(x.label.toFloat, x.vector.size, index.toArray, value.map(_.toFloat).toArray)
- }
- val dataIter = for (x <- it.iterator().asScala) yield mapper(x)
- val trainMat = new DMatrix(dataIter, null)
- val watches = List("train" -> trainMat).toMap
- val round = 2
- val numEarlyStoppingRounds = paramMap.get("numEarlyStoppingRounds")
- .map(_.toString.toInt).getOrElse(0)
- val booster = XGBoostScala.train(trainMat, paramMap, round, watches,
- earlyStoppingRound = numEarlyStoppingRounds)
- Communicator.shutdown()
- collector.collect(new XGBoostModel(booster))
- }
- }
-
- val logger = LogFactory.getLog(this.getClass)
-
- /**
- * Load XGBoost model from path, using Hadoop Filesystem API.
- *
- * @param modelPath The path that is accessible by hadoop filesystem API.
- * @return The loaded model
- */
- def loadModelFromHadoopFile(modelPath: String) : XGBoostModel = {
- new XGBoostModel(
- XGBoostScala.loadModel(FileSystem.get(new Configuration).open(new Path(modelPath))))
- }
-
- /**
- * Train a xgboost model with link.
- *
- * @param dtrain The training data.
- * @param params The parameters to XGBoost.
- * @param round Number of rounds to train.
- */
- def train(dtrain: DataSet[LabeledVector], params: Map[String, Any], round: Int):
- XGBoostModel = {
- val tracker = new RabitTracker(dtrain.getExecutionEnvironment.getParallelism)
- if (tracker.start(0L)) {
- dtrain
- .mapPartition(new MapFunction(params, round, tracker.getWorkerEnvs))
- .reduce((x, y) => x).collect().head
- } else {
- throw new Error("Tracker cannot be started")
- null
- }
- }
-}
diff --git a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoostModel.scala b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoostModel.scala
deleted file mode 100644
index 71b376974..000000000
--- a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoostModel.scala
+++ /dev/null
@@ -1,67 +0,0 @@
-/*
- Copyright (c) 2014 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package ml.dmlc.xgboost4j.scala.flink
-
-import ml.dmlc.xgboost4j.LabeledPoint
-import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
-
-import org.apache.flink.api.scala.{DataSet, _}
-import org.apache.flink.ml.math.Vector
-import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileSystem, Path}
-
-class XGBoostModel (booster: Booster) extends Serializable {
- /**
- * Save the model as a Hadoop filesystem file.
- *
- * @param modelPath The model path as in Hadoop path.
- */
- def saveModelAsHadoopFile(modelPath: String): Unit = {
- booster.saveModel(FileSystem
- .get(new Configuration)
- .create(new Path(modelPath)))
- }
-
- /**
- * predict with the given DMatrix
- * @param testSet the local test set represented as DMatrix
- * @return prediction result
- */
- def predict(testSet: DMatrix): Array[Array[Float]] = {
- booster.predict(testSet, true, 0)
- }
-
- /**
- * Predict given vector dataset.
- *
- * @param data The dataset to be predicted.
- * @return The prediction result.
- */
- def predict(data: DataSet[Vector]) : DataSet[Array[Float]] = {
- val predictMap: Iterator[Vector] => Traversable[Array[Float]] =
- (it: Iterator[Vector]) => {
- val mapper = (x: Vector) => {
- val (index, value) = x.toSeq.unzip
- LabeledPoint(0.0f, x.size, index.toArray, value.map(_.toFloat).toArray)
- }
- val dataIter = for (x <- it) yield mapper(x)
- val dmat = new DMatrix(dataIter, null)
- this.booster.predict(dmat)
- }
- data.mapPartition(predictMap)
- }
-}
diff --git a/jvm-packages/xgboost4j-gpu/pom.xml b/jvm-packages/xgboost4j-gpu/pom.xml
index 167635209..1d7a06708 100644
--- a/jvm-packages/xgboost4j-gpu/pom.xml
+++ b/jvm-packages/xgboost4j-gpu/pom.xml
@@ -38,22 +38,10 @@
<version>4.13.2</version>
<scope>test</scope>
</dependency>
- <dependency>
- <groupId>com.typesafe.akka</groupId>
- <artifactId>akka-actor_${scala.binary.version}</artifactId>
- <version>2.6.20</version>
- <scope>compile</scope>
- </dependency>
- <dependency>
- <groupId>com.typesafe.akka</groupId>
- <artifactId>akka-testkit_${scala.binary.version}</artifactId>
- <version>2.6.20</version>
- <scope>test</scope>
- </dependency>
<dependency>
<groupId>org.scalatest</groupId>
<artifactId>scalatest_${scala.binary.version}</artifactId>
- <version>3.0.5</version>
+ <version>3.2.15</version>
<scope>provided</scope>
diff --git a/jvm-packages/xgboost4j-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/QuantileDMatrixSuite.scala b/jvm-packages/xgboost4j-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/QuantileDMatrixSuite.scala
index ba8c5fa9a..28ac2207a 100644
--- a/jvm-packages/xgboost4j-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/QuantileDMatrixSuite.scala
+++ b/jvm-packages/xgboost4j-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/QuantileDMatrixSuite.scala
@@ -19,10 +19,10 @@ package ml.dmlc.xgboost4j.scala
import scala.collection.mutable.ArrayBuffer
import ai.rapids.cudf.Table
-import org.scalatest.FunSuite
+import org.scalatest.funsuite.AnyFunSuite
import ml.dmlc.xgboost4j.gpu.java.CudfColumnBatch
-class QuantileDMatrixSuite extends FunSuite {
+class QuantileDMatrixSuite extends AnyFunSuite {
test("QuantileDMatrix test") {
diff --git a/jvm-packages/xgboost4j-spark-gpu/pom.xml b/jvm-packages/xgboost4j-spark-gpu/pom.xml
index b1932f3cc..bcb7edb2a 100644
--- a/jvm-packages/xgboost4j-spark-gpu/pom.xml
+++ b/jvm-packages/xgboost4j-spark-gpu/pom.xml
@@ -44,13 +44,6 @@
<version>${spark.version}</version>
<scope>provided</scope>
</dependency>
- <dependency>
- <groupId>ai.rapids</groupId>
- <artifactId>cudf</artifactId>
- <version>${cudf.version}</version>
- <classifier>${cudf.classifier}</classifier>
- <scope>provided</scope>
- </dependency>
<dependency>
<groupId>com.nvidia</groupId>
<artifactId>rapids-4-spark_${scala.binary.version}</artifactId>
diff --git a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuTestSuite.scala b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuTestSuite.scala
index 175e00b39..2a355e160 100644
--- a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuTestSuite.scala
+++ b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuTestSuite.scala
@@ -20,14 +20,15 @@ import java.nio.file.{Files, Path}
import java.sql.{Date, Timestamp}
import java.util.{Locale, TimeZone}
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.funsuite.AnyFunSuite
import org.apache.spark.{GpuTestUtils, SparkConf}
import org.apache.spark.internal.Logging
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.sql.{Row, SparkSession}
-trait GpuTestSuite extends FunSuite with TmpFolderSuite {
+trait GpuTestSuite extends AnyFunSuite with TmpFolderSuite {
import SparkSessionHolder.withSparkSession
protected def getResourcePath(resource: String): String = {
@@ -200,7 +201,7 @@ trait GpuTestSuite extends FunSuite with TmpFolderSuite {
}
-trait TmpFolderSuite extends BeforeAndAfterAll { self: FunSuite =>
+trait TmpFolderSuite extends BeforeAndAfterAll { self: AnyFunSuite =>
protected var tempDir: Path = _
override def beforeAll(): Unit = {
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala
index 176a54832..31d58224b 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala
@@ -1,5 +1,5 @@
/*
- Copyright (c) 2021-2022 by Contributors
+ Copyright (c) 2021-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -22,7 +22,6 @@ import java.util.ServiceLoader
import scala.collection.JavaConverters._
import scala.collection.{AbstractIterator, Iterator, mutable}
-import ml.dmlc.xgboost4j.java.Communicator
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.PackedParams
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
@@ -35,7 +34,6 @@ import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.commons.logging.LogFactory
import org.apache.spark.TaskContext
-import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}
@@ -263,12 +261,6 @@ object PreXGBoost extends PreXGBoostProvider {
private var batchCnt = 0
private val batchIterImpl = rowIterator.grouped(inferBatchSize).flatMap { batchRow =>
- if (batchCnt == 0) {
- val rabitEnv = Array(
- "DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
- Communicator.init(rabitEnv.asJava)
- }
-
val features = batchRow.iterator.map(row => row.getAs[Vector](featuresCol))
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
@@ -295,13 +287,8 @@ object PreXGBoost extends PreXGBoostProvider {
override def hasNext: Boolean = batchIterImpl.hasNext
- override def next(): Row = {
- val ret = batchIterImpl.next()
- if (!batchIterImpl.hasNext) {
- Communicator.shutdown()
- }
- ret
- }
+ override def next(): Row = batchIterImpl.next()
+
}
}
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index 281997295..0aeae791a 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -1,5 +1,5 @@
/*
- Copyright (c) 2014-2022 by Contributors
+ Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -23,7 +23,6 @@ import scala.util.Random
import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.java.{Communicator, IRabitTracker, XGBoostError, RabitTracker => PyRabitTracker}
-import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
@@ -44,21 +43,16 @@ import org.apache.spark.sql.SparkSession
* Use a finite, non-zero timeout value to prevent tracker from
* hanging indefinitely (in milliseconds)
* (supported by "scala" implementation only.)
- * @param trackerImpl Choice between "python" or "scala". The former utilizes the Java wrapper of
- * the Python Rabit tracker (in dmlc_core), whereas the latter is implemented
- * in Scala without Python components, and with full support of timeouts.
- * The Scala implementation is currently experimental, use at your own risk.
- *
* @param hostIp The Rabit Tracker host IP address which is only used for python implementation.
* This is only needed if the host IP cannot be automatically guessed.
* @param pythonExec The python executed path for Rabit Tracker,
* which is only used for python implementation.
*/
-case class TrackerConf(workerConnectionTimeout: Long, trackerImpl: String,
+case class TrackerConf(workerConnectionTimeout: Long,
hostIp: String = "", pythonExec: String = "")
object TrackerConf {
- def apply(): TrackerConf = TrackerConf(0L, "python")
+ def apply(): TrackerConf = TrackerConf(0L)
}
private[scala] case class XGBoostExecutionEarlyStoppingParams(numEarlyStoppingRounds: Int,
@@ -349,11 +343,9 @@ object XGBoost extends Serializable {
/** visiable for testing */
private[scala] def getTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = {
- val tracker: IRabitTracker = trackerConf.trackerImpl match {
- case "scala" => new RabitTracker(nWorkers)
- case "python" => new PyRabitTracker(nWorkers, trackerConf.hostIp, trackerConf.pythonExec)
- case _ => new PyRabitTracker(nWorkers)
- }
+ val tracker: IRabitTracker = new PyRabitTracker(
+ nWorkers, trackerConf.hostIp, trackerConf.pythonExec
+ )
tracker
}
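With the Scala Rabit tracker removed, the trackerImpl argument disappears from TrackerConf; a sketch of an adjusted Spark parameter map, where the host IP and worker count are placeholders:

import ml.dmlc.xgboost4j.scala.spark.{TrackerConf, XGBoostClassifier}

object TrackerConfMigrationSketch {
  // Before: TrackerConf(60 * 60 * 1000L, "python", "10.0.0.5")
  // After: the python tracker is the only implementation, so only timeout/hostIp/pythonExec remain.
  val params: Map[String, Any] = Map(
    "eta" -> 0.1,
    "objective" -> "binary:logistic",
    "num_round" -> 5,
    "num_workers" -> 2,
    "tracker_conf" -> TrackerConf(60 * 60 * 1000L, hostIp = "10.0.0.5")
  )
  val classifier = new XGBoostClassifier(params)
}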
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala
index 579e3dd37..5445cd1bf 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala
@@ -22,11 +22,10 @@ import scala.util.Random
import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus
-import ml.dmlc.xgboost4j.scala.rabit.{RabitTracker => ScalaRabitTracker}
import ml.dmlc.xgboost4j.scala.DMatrix
-import org.scalatest.FunSuite
+import org.scalatest.funsuite.AnyFunSuite
-class CommunicatorRobustnessSuite extends FunSuite with PerTest {
+class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
private def getXGBoostExecutionParams(paramMap: Map[String, Any]): XGBoostExecutionParams = {
val classifier = new XGBoostClassifier(paramMap)
@@ -40,7 +39,7 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
val paramMap = Map(
"num_workers" -> numWorkers,
- "tracker_conf" -> TrackerConf(0L, "python", hostIp))
+ "tracker_conf" -> TrackerConf(0L, hostIp))
val xgbExecParams = getXGBoostExecutionParams(paramMap)
val tracker = XGBoost.getTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
tracker match {
@@ -53,7 +52,7 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
val paramMap1 = Map(
"num_workers" -> numWorkers,
- "tracker_conf" -> TrackerConf(0L, "python", "", pythonExec))
+ "tracker_conf" -> TrackerConf(0L, "", pythonExec))
val xgbExecParams1 = getXGBoostExecutionParams(paramMap1)
val tracker1 = XGBoost.getTracker(xgbExecParams1.numWorkers, xgbExecParams1.trackerConf)
tracker1 match {
@@ -66,7 +65,7 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
val paramMap2 = Map(
"num_workers" -> numWorkers,
- "tracker_conf" -> TrackerConf(0L, "python", hostIp, pythonExec))
+ "tracker_conf" -> TrackerConf(0L, hostIp, pythonExec))
val xgbExecParams2 = getXGBoostExecutionParams(paramMap2)
val tracker2 = XGBoost.getTracker(xgbExecParams2.numWorkers, xgbExecParams2.trackerConf)
tracker2 match {
@@ -78,58 +77,6 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
}
}
- test("training with Scala-implemented Rabit tracker") {
- val eval = new EvalError()
- val training = buildDataFrame(Classification.train)
- val testDM = new DMatrix(Classification.test.iterator)
- val paramMap = Map("eta" -> "1", "max_depth" -> "6",
- "objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
- "tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala"))
- val model = new XGBoostClassifier(paramMap).fit(training)
- assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
- }
-
- test("test Communicator allreduce to validate Scala-implemented Rabit tracker") {
- val vectorLength = 100
- val rdd = sc.parallelize(
- (1 to numWorkers * vectorLength).toArray.map { _ => Random.nextFloat() }, numWorkers).cache()
-
- val tracker = new ScalaRabitTracker(numWorkers)
- tracker.start(0)
- val trackerEnvs = tracker.getWorkerEnvs
- val collectedAllReduceResults = new LinkedBlockingDeque[Array[Float]]()
-
- val rawData = rdd.mapPartitions { iter =>
- Iterator(iter.toArray)
- }.collect()
-
- val maxVec = (0 until vectorLength).toArray.map { j =>
- (0 until numWorkers).toArray.map { i => rawData(i)(j) }.max
- }
-
- val allReduceResults = rdd.mapPartitions { iter =>
- Communicator.init(trackerEnvs)
- val arr = iter.toArray
- val results = Communicator.allReduce(arr, Communicator.OpType.MAX)
- Communicator.shutdown()
- Iterator(results)
- }.cache()
-
- val sparkThread = new Thread() {
- override def run(): Unit = {
- allReduceResults.foreachPartition(() => _)
- val byPartitionResults = allReduceResults.collect()
- assert(byPartitionResults(0).length == vectorLength)
- collectedAllReduceResults.put(byPartitionResults(0))
- }
- }
- sparkThread.start()
- assert(tracker.waitFor(0L) == 0)
- sparkThread.join()
-
- assert(collectedAllReduceResults.poll().sameElements(maxVec))
- }
-
test("test Java RabitTracker wrapper's exception handling: it should not hang forever.") {
/*
Deliberately create new instances of SparkContext in each unit test to avoid reusing the
@@ -193,68 +140,6 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
assert(tracker.waitFor(0) != 0)
}
- test("test Scala RabitTracker's exception handling: it should not hang forever.") {
- val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
-
- val tracker = new ScalaRabitTracker(numWorkers)
- tracker.start(0)
- val trackerEnvs = tracker.getWorkerEnvs
-
- val workerCount: Int = numWorkers
- val dummyTasks = rdd.mapPartitions { iter =>
- Communicator.init(trackerEnvs)
- val index = iter.next()
- Thread.sleep(100 + index * 10)
- if (index == workerCount) {
- // kill the worker by throwing an exception
- throw new RuntimeException("Worker exception.")
- }
- Communicator.shutdown()
- Iterator(index)
- }.cache()
-
- val sparkThread = new Thread() {
- override def run(): Unit = {
- // forces a Spark job.
- dummyTasks.foreachPartition(() => _)
- }
- }
- sparkThread.setUncaughtExceptionHandler(tracker)
- sparkThread.start()
- assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
- }
-
- test("test Scala RabitTracker's workerConnectionTimeout") {
- val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
-
- val tracker = new ScalaRabitTracker(numWorkers)
- tracker.start(500)
- val trackerEnvs = tracker.getWorkerEnvs
-
- val dummyTasks = rdd.mapPartitions { iter =>
- val index = iter.next()
- // simulate that the first worker cannot connect to tracker due to network issues.
- if (index != 1) {
- Communicator.init(trackerEnvs)
- Thread.sleep(1000)
- Communicator.shutdown()
- }
-
- Iterator(index)
- }.cache()
-
- val sparkThread = new Thread() {
- override def run(): Unit = {
- // forces a Spark job.
- dummyTasks.foreachPartition(() => _)
- }
- }
- sparkThread.setUncaughtExceptionHandler(tracker)
- sparkThread.start()
- // should fail due to connection timeout
- assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
- }
-
test("should allow the dataframe containing communicator calls to be partially evaluated for" +
" multiple times (ISSUE-4406)") {
val paramMap = Map(
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/DeterministicPartitioningSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/DeterministicPartitioningSuite.scala
index 61766b755..8d9723bb6 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/DeterministicPartitioningSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/DeterministicPartitioningSuite.scala
@@ -17,13 +17,13 @@
package ml.dmlc.xgboost4j.scala.spark
import org.apache.spark.ml.linalg.Vectors
-import org.scalatest.FunSuite
+import org.scalatest.funsuite.AnyFunSuite
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.PackedParams
import org.apache.spark.sql.functions._
-class DeterministicPartitioningSuite extends FunSuite with TmpFolderPerSuite with PerTest {
+class DeterministicPartitioningSuite extends AnyFunSuite with TmpFolderPerSuite with PerTest {
test("perform deterministic partitioning when checkpointInternal and" +
" checkpointPath is set (Classifier)") {
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ExternalCheckpointManagerSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ExternalCheckpointManagerSuite.scala
index cdcfd76f5..adc9c1068 100755
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ExternalCheckpointManagerSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ExternalCheckpointManagerSuite.scala
@@ -19,10 +19,10 @@ package ml.dmlc.xgboost4j.scala.spark
import java.io.File
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, ExternalCheckpointManager, XGBoost => SXGBoost}
-import org.scalatest.FunSuite
+import org.scalatest.funsuite.AnyFunSuite
import org.apache.hadoop.fs.{FileSystem, Path}
-class ExternalCheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTest {
+class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite with PerTest {
private def produceParamMap(checkpointPath: String, checkpointInterval: Int):
Map[String, Any] = {
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/FeatureSizeValidatingSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/FeatureSizeValidatingSuite.scala
index e0151dde3..789fd162b 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/FeatureSizeValidatingSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/FeatureSizeValidatingSuite.scala
@@ -18,12 +18,12 @@ package ml.dmlc.xgboost4j.scala.spark
import org.apache.spark.Partitioner
import org.apache.spark.ml.feature.VectorAssembler
-import org.scalatest.FunSuite
+import org.scalatest.funsuite.AnyFunSuite
import org.apache.spark.sql.functions._
import scala.util.Random
-class FeatureSizeValidatingSuite extends FunSuite with PerTest {
+class FeatureSizeValidatingSuite extends AnyFunSuite with PerTest {
test("transform throwing exception if feature size of dataset is greater than model's") {
val modelPath = getClass.getResource("/model/0.82/model").getPath
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/MissingValueHandlingSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/MissingValueHandlingSuite.scala
index 5863e2ace..6a7f7129d 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/MissingValueHandlingSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/MissingValueHandlingSuite.scala
@@ -19,12 +19,12 @@ package ml.dmlc.xgboost4j.scala.spark
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.DataFrame
-import org.scalatest.FunSuite
+import org.scalatest.funsuite.AnyFunSuite
import scala.util.Random
import org.apache.spark.SparkException
-class MissingValueHandlingSuite extends FunSuite with PerTest {
+class MissingValueHandlingSuite extends AnyFunSuite with PerTest {
test("dense vectors containing missing value") {
def buildDenseDataFrame(): DataFrame = {
val numRows = 100
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala
index e3468b811..11b60e74d 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala
@@ -16,12 +16,13 @@
package ml.dmlc.xgboost4j.scala.spark
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.funsuite.AnyFunSuite
import org.apache.spark.SparkException
import org.apache.spark.ml.param.ParamMap
-class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll {
+class ParameterSuite extends AnyFunSuite with PerTest with BeforeAndAfterAll {
test("XGBoost and Spark parameters synchronize correctly") {
val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic",
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
index e96618c51..24bc00e18 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
@@ -22,13 +22,14 @@ import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.spark.SparkContext
import org.apache.spark.sql._
-import org.scalatest.{BeforeAndAfterEach, FunSuite}
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.funsuite.AnyFunSuite
import scala.math.min
import scala.util.Random
import org.apache.commons.io.IOUtils
-trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
+trait PerTest extends BeforeAndAfterEach { self: AnyFunSuite =>
protected val numWorkers: Int = min(Runtime.getRuntime.availableProcessors(), 4)
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala
index cf8dcca57..5425b8647 100755
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala
@@ -25,9 +25,9 @@ import scala.util.Random
import org.apache.spark.ml.feature._
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql.functions._
-import org.scalatest.FunSuite
+import org.scalatest.funsuite.AnyFunSuite
-class PersistenceSuite extends FunSuite with TmpFolderPerSuite with PerTest {
+class PersistenceSuite extends AnyFunSuite with TmpFolderPerSuite with PerTest {
test("test persistence of XGBoostClassifier and XGBoostClassificationModel") {
val eval = new EvalError()
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/TmpFolderPerSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/TmpFolderPerSuite.scala
index 96b74d679..bb523ffdf 100755
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/TmpFolderPerSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/TmpFolderPerSuite.scala
@@ -19,9 +19,10 @@ package ml.dmlc.xgboost4j.scala.spark
import java.nio.file.{Files, Path}
import org.apache.spark.network.util.JavaUtils
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.funsuite.AnyFunSuite
-trait TmpFolderPerSuite extends BeforeAndAfterAll { self: FunSuite =>
+trait TmpFolderPerSuite extends BeforeAndAfterAll { self: AnyFunSuite =>
protected var tempDir: Path = _
override def beforeAll(): Unit = {
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala
index f31207b9f..0031be9c7 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala
@@ -22,13 +22,13 @@ import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
import org.apache.spark.ml.linalg._
import org.apache.spark.sql._
-import org.scalatest.FunSuite
+import org.scalatest.funsuite.AnyFunSuite
import org.apache.commons.io.IOUtils
import org.apache.spark.Partitioner
import org.apache.spark.ml.feature.VectorAssembler
-class XGBoostClassifierSuite extends FunSuite with PerTest with TmpFolderPerSuite {
+class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite {
protected val treeMethod: String = "auto"
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostCommunicatorRegressionSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostCommunicatorRegressionSuite.scala
index a7310f1ab..86b82e63c 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostCommunicatorRegressionSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostCommunicatorRegressionSuite.scala
@@ -21,11 +21,11 @@ import ml.dmlc.xgboost4j.scala.Booster
import scala.collection.JavaConverters._
import org.apache.spark.sql._
-import org.scalatest.FunSuite
+import org.scalatest.funsuite.AnyFunSuite
import org.apache.spark.SparkException
-class XGBoostCommunicatorRegressionSuite extends FunSuite with PerTest {
+class XGBoostCommunicatorRegressionSuite extends AnyFunSuite with PerTest {
val predictionErrorMin = 0.00001f
val maxFailure = 2;
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostConfigureSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostConfigureSuite.scala
index 7d588d97c..086fda2d7 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostConfigureSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostConfigureSuite.scala
@@ -19,9 +19,9 @@ package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
import org.apache.spark.sql._
-import org.scalatest.FunSuite
+import org.scalatest.funsuite.AnyFunSuite
-class XGBoostConfigureSuite extends FunSuite with PerTest {
+class XGBoostConfigureSuite extends AnyFunSuite with PerTest {
override def sparkSessionBuilder: SparkSession.Builder = super.sparkSessionBuilder
.config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
index 0bf8c2fbb..c1e34224c 100755
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
@@ -22,12 +22,12 @@ import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import ml.dmlc.xgboost4j.scala.DMatrix
import org.apache.spark.{SparkException, TaskContext}
-import org.scalatest.FunSuite
+import org.scalatest.funsuite.AnyFunSuite
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.functions.lit
-class XGBoostGeneralSuite extends FunSuite with TmpFolderPerSuite with PerTest {
+class XGBoostGeneralSuite extends AnyFunSuite with TmpFolderPerSuite with PerTest {
test("distributed training with the specified worker number") {
val trainingRDD = sc.parallelize(Classification.train)
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala
index 4e3d59b25..efcb38cf6 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala
@@ -23,11 +23,11 @@ import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{DataFrame, Row}
-import org.scalatest.FunSuite
+import org.scalatest.funsuite.AnyFunSuite
import org.apache.spark.ml.feature.VectorAssembler
-class XGBoostRegressorSuite extends FunSuite with PerTest with TmpFolderPerSuite {
+class XGBoostRegressorSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite {
protected val treeMethod: String = "auto"
test("XGBoost-Spark XGBoostRegressor output should match XGBoost4j") {
diff --git a/jvm-packages/xgboost4j-tester/generate_pom.py b/jvm-packages/xgboost4j-tester/generate_pom.py
index edc9759bd..06372e9b2 100644
--- a/jvm-packages/xgboost4j-tester/generate_pom.py
+++ b/jvm-packages/xgboost4j-tester/generate_pom.py
@@ -69,7 +69,7 @@ pom_template = """
      <groupId>org.scalactic</groupId>
      <artifactId>scalactic_${{scala.binary.version}}</artifactId>
-      <version>3.0.8</version>
+      <version>3.2.15</version>
      <scope>test</scope>
diff --git a/jvm-packages/xgboost4j/pom.xml b/jvm-packages/xgboost4j/pom.xml
index aa8694751..3a1c4b2cf 100644
--- a/jvm-packages/xgboost4j/pom.xml
+++ b/jvm-packages/xgboost4j/pom.xml
@@ -31,22 +31,10 @@
      <version>4.13.2</version>
      <scope>test</scope>
    </dependency>
-    <dependency>
-      <groupId>com.typesafe.akka</groupId>
-      <artifactId>akka-actor_${scala.binary.version}</artifactId>
-      <version>2.6.20</version>
-      <scope>compile</scope>
-    </dependency>
-    <dependency>
-      <groupId>com.typesafe.akka</groupId>
-      <artifactId>akka-testkit_${scala.binary.version}</artifactId>
-      <version>2.6.20</version>
-      <scope>test</scope>
-    </dependency>
    <dependency>
      <groupId>org.scalatest</groupId>
      <artifactId>scalatest_${scala.binary.version}</artifactId>
-      <version>3.0.5</version>
+      <version>3.2.15</version>
      <scope>provided</scope>
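Note: with the Akka-based Scala Rabit tracker deleted below, the `akka-actor` and `akka-testkit` dependencies become dead weight and are dropped, while scalatest/scalactic move to the 3.2 line. For readers who think in sbt terms, a rough equivalent of the new test-side coordinates would be the following (hypothetical sbt snippet; the project itself builds with Maven, so this is illustration only):

```scala
// Hypothetical sbt rendering of the Maven dependency changes above.
libraryDependencies ++= Seq(
  "org.scalatest" %% "scalatest" % "3.2.15" % Provided,
  "org.scalactic" %% "scalactic" % "3.2.15" % Test
)
```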
diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTracker.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTracker.scala
deleted file mode 100644
index fb388d083..000000000
--- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTracker.scala
+++ /dev/null
@@ -1,195 +0,0 @@
-/*
- Copyright (c) 2014 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package ml.dmlc.xgboost4j.scala.rabit
-
-import java.net.{InetAddress, InetSocketAddress}
-
-import akka.actor.ActorSystem
-import akka.pattern.ask
-import ml.dmlc.xgboost4j.java.{IRabitTracker, TrackerProperties}
-import ml.dmlc.xgboost4j.scala.rabit.handler.RabitTrackerHandler
-
-import scala.concurrent.duration._
-import scala.concurrent.{Await, Future}
-import scala.util.{Failure, Success, Try}
-
-/**
- * Scala implementation of the Rabit tracker interface without Python dependency.
- * The Scala Rabit tracker fully implements the timeout logic, effectively preventing the tracker
- * (and thus any distributed tasks) from hanging indefinitely due to network issues or worker node
- * failures.
- *
- * Note that this implementation is currently experimental, and should be used at your own risk.
- *
- * Example usage:
- * {{{
- * import scala.concurrent.duration._
- *
- * val tracker = new RabitTracker(32)
- * // allow up to 10 minutes for all workers to connect to the tracker.
- * tracker.start(10 minutes)
- *
- * /* ...
- * launching workers in parallel
- * ...
- * */
- *
- * // wait for worker execution up to 6 hours.
- * // providing a finite timeout prevents a long-running task from hanging forever in
- * // catastrophic events, like the loss of an executor during model training.
- * tracker.waitFor(6 hours)
- * }}}
- *
- * @param numWorkers Number of distributed workers from which the tracker expects connections.
- * @param port The minimum port number that the tracker binds to.
- * If port is omitted, or given as None, a random ephemeral port is chosen at runtime.
- * @param maxPortTrials The maximum number of trials of socket binding, by sequentially
- * increasing the port number.
- */
-private[scala] class RabitTracker(numWorkers: Int, port: Option[Int] = None,
- maxPortTrials: Int = 1000)
- extends IRabitTracker {
-
- import scala.collection.JavaConverters._
-
- require(numWorkers >=1, "numWorkers must be greater than or equal to one (1).")
-
- val system = ActorSystem.create("RabitTracker")
- val handler = system.actorOf(RabitTrackerHandler.props(numWorkers), "Handler")
- implicit val askTimeout: akka.util.Timeout = akka.util.Timeout(30 seconds)
- private[this] val tcpBindingTimeout: Duration = 1 minute
-
- var workerEnvs: Map[String, String] = Map.empty
-
- override def uncaughtException(t: Thread, e: Throwable): Unit = {
- handler ? RabitTrackerHandler.InterruptTracker(e)
- }
-
- /**
- * Start the Rabit tracker.
- *
- * @param timeout The timeout for awaiting connections from worker nodes.
- * Note that when used in Spark applications, because all Spark transformations are
- * lazily executed, the I/O time for loading RDDs/DataFrames from external sources
- * (local disk, HDFS, S3 etc.) must be taken into account for the timeout value.
- * If the timeout value is too small, the Rabit tracker will likely time out before workers
- * establish connections to the tracker, due to the overhead of loading data.
- * Using a finite timeout is encouraged, as it prevents the tracker (thus the Spark driver
- * running it) from hanging indefinitely due to worker connection issues (e.g. firewall.)
- * @return Boolean flag indicating if the Rabit tracker starts successfully.
- */
- private def start(timeout: Duration): Boolean = {
- val hostAddress = Option(TrackerProperties.getInstance().getHostIp)
- .map(InetAddress.getByName).getOrElse(InetAddress.getLocalHost)
-
- handler ? RabitTrackerHandler.StartTracker(
- new InetSocketAddress(hostAddress, port.getOrElse(0)), maxPortTrials, timeout)
-
- // block by waiting for the actor to bind to a port
- Try(Await.result(handler ? RabitTrackerHandler.RequestBoundFuture, askTimeout.duration)
- .asInstanceOf[Future[Map[String, String]]]) match {
- case Success(futurePortBound) =>
- // The success of the Future is contingent on binding to an InetSocketAddress.
- val isBound = Try(Await.ready(futurePortBound, tcpBindingTimeout)).isSuccess
- if (isBound) {
- workerEnvs = Await.result(futurePortBound, 0 nano)
- }
- isBound
- case Failure(ex: Throwable) =>
- false
- }
- }
-
- /**
- * Start the Rabit tracker.
- *
- * @param connectionTimeoutMillis Timeout, in milliseconds, for the tracker to wait for worker
- * connections. If a non-positive value is provided, the tracker
- * waits for incoming worker connections indefinitely.
- * @return Boolean flag indicating if the Rabit tracker starts successfully.
- */
- def start(connectionTimeoutMillis: Long): Boolean = {
- if (connectionTimeoutMillis <= 0) {
- start(Duration.Inf)
- } else {
- start(Duration.fromNanos(connectionTimeoutMillis * 1e6))
- }
- }
-
- def stop(): Unit = {
- system.terminate()
- }
-
- /**
- * Get a Map of necessary environment variables to initiate Rabit workers.
- *
- * @return HashMap containing tracker information.
- */
- def getWorkerEnvs: java.util.Map[String, String] = {
- new java.util.HashMap((workerEnvs ++ Map(
- "DMLC_NUM_WORKER" -> numWorkers.toString,
- "DMLC_NUM_SERVER" -> "0"
- )).asJava)
- }
-
- /**
- * Await workers to complete assigned tasks for at most 'atMostMillis' milliseconds.
- * This method blocks until timeout or task completion.
- *
- * @param atMost the maximum execution time for the workers. By default,
- * the tracker waits for the workers indefinitely.
- * @return 0 if the tasks complete successfully, and non-zero otherwise.
- */
- private def waitFor(atMost: Duration): Int = {
- // request the completion Future from the tracker actor
- Try(Await.result(handler ? RabitTrackerHandler.RequestCompletionFuture, askTimeout.duration)
- .asInstanceOf[Future[Int]]) match {
- case Success(futureCompleted) =>
- // wait for all workers to complete synchronously.
- val statusCode = Try(Await.result(futureCompleted, atMost)) match {
- case Success(n) if n == numWorkers =>
- IRabitTracker.TrackerStatus.SUCCESS.getStatusCode
- case Success(n) if n < numWorkers =>
- IRabitTracker.TrackerStatus.TIMEOUT.getStatusCode
- case Failure(e) =>
- IRabitTracker.TrackerStatus.FAILURE.getStatusCode
- }
- system.terminate()
- statusCode
- case Failure(ex: Throwable) =>
- system.terminate()
- IRabitTracker.TrackerStatus.FAILURE.getStatusCode
- }
- }
-
- /**
- * Await workers to complete assigned tasks for at most 'atMostMillis' milliseconds.
- * This method blocks until timeout or task completion.
- *
- * @param atMostMillis Number of milliseconds for the tracker to wait for workers. If a
- * non-positive number is given, the tracker waits indefinitely.
- * @return 0 if the tasks complete successfully, and non-zero otherwise
- */
- def waitFor(atMostMillis: Long): Int = {
- if (atMostMillis <= 0) {
- waitFor(Duration.Inf)
- } else {
- waitFor(Duration.fromNanos(atMostMillis * 1e6))
- }
- }
-}
-
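Note: the deleted `RabitTracker` was an Akka-based implementation of the `IRabitTracker` interface; drivers only ever touched it through `start`, `getWorkerEnvs`, and `waitFor`. The sketch below restates that usage pattern against the interface (assuming an alternative `IRabitTracker` implementation is supplied in its place; `runDistributed` and `launchWorkers` are hypothetical names, not part of this diff):

```scala
import ml.dmlc.xgboost4j.java.IRabitTracker

// Sketch of the driver-side pattern the deleted class supported; the
// tracker instance must now come from another IRabitTracker implementation.
def runDistributed(tracker: IRabitTracker,
                   launchWorkers: java.util.Map[String, String] => Unit): Int = {
  tracker.start(10 * 60 * 1000L)        // wait up to 10 minutes for workers to connect
  launchWorkers(tracker.getWorkerEnvs)  // hand the DMLC_* env vars to the workers
  tracker.waitFor(6 * 60 * 60 * 1000L)  // wait up to 6 hours for training to finish
}
```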
diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/handler/RabitTrackerHandler.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/handler/RabitTrackerHandler.scala
deleted file mode 100644
index f9de71746..000000000
--- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/handler/RabitTrackerHandler.scala
+++ /dev/null
@@ -1,361 +0,0 @@
-/*
- Copyright (c) 2014 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package ml.dmlc.xgboost4j.scala.rabit.handler
-
-import java.net.InetSocketAddress
-import java.util.UUID
-
-import scala.concurrent.duration._
-import scala.collection.mutable
-import scala.concurrent.{Promise, TimeoutException}
-import akka.io.{IO, Tcp}
-import akka.actor._
-import ml.dmlc.xgboost4j.java.XGBoostError
-import ml.dmlc.xgboost4j.scala.rabit.util.{AssignedRank, LinkMap}
-
-import scala.util.{Failure, Random, Success, Try}
-
-/** The Akka actor for handling and coordinating Rabit worker connections.
- * This is the main actor for handling socket connections, interacting with the synchronous
- * tracker interface, and resolving tree/ring/parent dependencies between workers.
- *
- * @param numWorkers Number of workers to track.
- */
-private[scala] class RabitTrackerHandler(numWorkers: Int)
- extends Actor with ActorLogging {
-
- import context.system
- import RabitWorkerHandler._
- import RabitTrackerHandler._
-
- private[this] val promisedWorkerEnvs = Promise[Map[String, String]]()
- private[this] val promisedShutdownWorkers = Promise[Int]()
- private[this] val tcpManager = IO(Tcp)
-
- // resolves worker connection dependency.
- val resolver = context.actorOf(Props(classOf[WorkerDependencyResolver], self), "Resolver")
-
- // workers that have sent "shutdown" signal
- private[this] val shutdownWorkers = mutable.Set.empty[Int]
- private[this] val jobToRankMap = mutable.HashMap.empty[String, Int]
- private[this] val actorRefToHost = mutable.HashMap.empty[ActorRef, String]
- private[this] val ranksToAssign = mutable.ListBuffer(0 until numWorkers: _*)
- private[this] var maxPortTrials = 0
- private[this] var workerConnectionTimeout: Duration = Duration.Inf
- private[this] var portTrials = 0
- private[this] val startedWorkers = mutable.Set.empty[Int]
-
- val linkMap = new LinkMap(numWorkers)
-
- def decideRank(rank: Int, jobId: String = "NULL"): Option[Int] = {
- rank match {
- case r if r >= 0 => Some(r)
- case _ =>
- jobId match {
- case "NULL" => None
- case jid => jobToRankMap.get(jid)
- }
- }
- }
-
- /**
- * Handler for all Akka Tcp connection/binding events. Read/write over the socket is handled
- * by the RabitWorkerHandler.
- *
- * @param event Generic Tcp.Event
- */
- private def handleTcpEvents(event: Tcp.Event): Unit = event match {
- case Tcp.Bound(local) =>
- // expect all workers to connect within timeout
- log.info(s"Tracker listening @ ${local.getAddress.getHostAddress}:${local.getPort}")
- log.info(s"Worker connection timeout is $workerConnectionTimeout.")
-
- context.setReceiveTimeout(workerConnectionTimeout)
- promisedWorkerEnvs.success(Map(
- "DMLC_TRACKER_URI" -> local.getAddress.getHostAddress,
- "DMLC_TRACKER_PORT" -> local.getPort.toString,
- // not required because the world size will be communicated to the
- // worker node after the rank is assigned.
- "rabit_world_size" -> numWorkers.toString
- ))
-
- case Tcp.CommandFailed(cmd: Tcp.Bind) =>
- if (portTrials < maxPortTrials) {
- portTrials += 1
- tcpManager ! Tcp.Bind(self,
- new InetSocketAddress(cmd.localAddress.getAddress, cmd.localAddress.getPort + 1),
- backlog = 256)
- }
-
- case Tcp.Connected(remote, local) =>
- log.debug(s"Incoming connection from worker @ ${remote.getAddress.getHostAddress}")
- // revoke timeout if all workers have connected.
- val workerHandler = context.actorOf(RabitWorkerHandler.props(
- remote.getAddress.getHostAddress, numWorkers, self, sender()
- ), s"ConnectionHandler-${UUID.randomUUID().toString}")
- val connection = sender()
- connection ! Tcp.Register(workerHandler, keepOpenOnPeerClosed = true)
-
- actorRefToHost.put(workerHandler, remote.getAddress.getHostName)
- }
-
- /**
- * Handles external tracker control messages sent by RabitTracker (usually in ask patterns)
- * to interact with the tracker interface.
- *
- * @param trackerMsg control messages sent by RabitTracker class.
- */
- private def handleTrackerControlMessage(trackerMsg: TrackerControlMessage): Unit =
- trackerMsg match {
-
- case msg: StartTracker =>
- maxPortTrials = msg.maxPortTrials
- workerConnectionTimeout = msg.connectionTimeout
-
- // if the port number is missing, try binding to a random ephemeral port.
- if (msg.addr.getPort == 0) {
- tcpManager ! Tcp.Bind(self,
- new InetSocketAddress(msg.addr.getAddress, new Random().nextInt(61000 - 32768) + 32768),
- backlog = 256)
- } else {
- tcpManager ! Tcp.Bind(self, msg.addr, backlog = 256)
- }
- sender() ! true
-
- case RequestBoundFuture =>
- sender() ! promisedWorkerEnvs.future
-
- case RequestCompletionFuture =>
- sender() ! promisedShutdownWorkers.future
-
- case InterruptTracker(e) =>
- log.error(e, "Uncaught exception thrown by worker.")
- // make sure that waitFor() does not hang indefinitely.
- promisedShutdownWorkers.failure(e)
- context.stop(self)
- }
-
- /**
- * Handles messages sent by child actors representing connecting Rabit workers, by brokering
- * messages to the dependency resolver, and processing worker commands.
- *
- * @param workerMsg Message sent by RabitWorkerHandler actors.
- */
- private def handleRabitWorkerMessage(workerMsg: RabitWorkerRequest): Unit = workerMsg match {
- case req @ RequestAwaitConnWorkers(_, _) =>
- // since the requester may request to connect to other workers
- // that have not fully set up, delegate this request to the
- // dependency resolver which handles the dependencies properly.
- resolver forward req
-
- // ---- Rabit worker commands: start/recover/shutdown/print ----
- case WorkerTrackerPrint(_, _, _, msg) =>
- log.info(msg.trim)
-
- case WorkerShutdown(rank, _, _) =>
- assert(rank >= 0, "Invalid rank.")
- assert(!shutdownWorkers.contains(rank))
- shutdownWorkers.add(rank)
-
- log.info(s"Received shutdown signal from $rank")
-
- if (shutdownWorkers.size == numWorkers) {
- promisedShutdownWorkers.success(shutdownWorkers.size)
- }
-
- case WorkerRecover(prevRank, worldSize, jobId) =>
- assert(prevRank >= 0)
- sender() ! linkMap.assignRank(prevRank)
-
- case WorkerStart(rank, worldSize, jobId) =>
- assert(worldSize == numWorkers || worldSize == -1,
- s"Purported worldSize ($worldSize) does not match worker count ($numWorkers)."
- )
-
- Try(decideRank(rank, jobId).getOrElse(ranksToAssign.remove(0))) match {
- case Success(wkRank) =>
- if (jobId != "NULL") {
- jobToRankMap.put(jobId, wkRank)
- }
-
- val assignedRank = linkMap.assignRank(wkRank)
- sender() ! assignedRank
- resolver ! assignedRank
-
- log.info("Received start signal from " +
- s"${actorRefToHost.getOrElse(sender(), "")} [rank: $wkRank]")
-
- case Failure(ex: IndexOutOfBoundsException) =>
- // More than worldSize workers have connected, likely due to executor loss.
- // Since Rabit currently does not support crash recovery (because the Allreduce results
- // are not cached by the tracker, and because existing workers cannot reestablish
-            // connections to newly spawned executor/worker), the most reasonable action here is to
-            // interrupt the tracker immediately with a failure state.
- log.error("Received invalid start signal from " +
- s"${actorRefToHost.getOrElse(sender(), "")}: all $worldSize workers have started."
- )
- promisedShutdownWorkers.failure(new XGBoostError("Invalid start signal" +
- " received from worker, likely due to executor loss."))
-
- case Failure(ex) =>
- log.error(ex, "Unexpected error")
- promisedShutdownWorkers.failure(ex)
- }
-
-
- // ---- Dependency resolving related messages ----
- case msg @ WorkerStarted(host, rank, awaitingAcceptance) =>
- log.info(s"Worker $host (rank: $rank) has started.")
- resolver forward msg
-
- startedWorkers.add(rank)
- if (startedWorkers.size == numWorkers) {
- log.info("All workers have started.")
- }
-
- case req @ DropFromWaitingList(_) =>
- // all peer workers in dependency link map have connected;
- // forward message to resolver to update dependencies.
- resolver forward req
-
- case _ =>
- }
-
- def receive: Actor.Receive = {
- case tcpEvent: Tcp.Event => handleTcpEvents(tcpEvent)
- case trackerMsg: TrackerControlMessage => handleTrackerControlMessage(trackerMsg)
- case workerMsg: RabitWorkerRequest => handleRabitWorkerMessage(workerMsg)
-
- case akka.actor.ReceiveTimeout =>
- if (startedWorkers.size < numWorkers) {
- promisedShutdownWorkers.failure(
- new TimeoutException("Timed out waiting for workers to connect: " +
- s"${numWorkers - startedWorkers.size} of $numWorkers did not start/connect.")
- )
- context.stop(self)
- }
-
- context.setReceiveTimeout(Duration.Undefined)
- }
-}
-
-/**
- * Resolve the dependency between nodes as they connect to the tracker.
- * The enforced dependency is that a worker of rank K depends on its neighbors (from the treeMap
- * and ringMap) whose ranks are smaller than K. Since ranks are assigned in the order in which
- * workers connect, this dependency constraint assumes that a worker node that connects first
- * is likely to finish setup first.
- */
-private[rabit] class WorkerDependencyResolver(handler: ActorRef) extends Actor with ActorLogging {
- import RabitWorkerHandler._
-
- context.watch(handler)
-
- case class Fulfillment(toConnectSet: Set[Int], promise: Promise[AwaitingConnections])
-
-  // worker nodes that have connected, but have not sent the WorkerStarted message.
- private val dependencyMap = mutable.Map.empty[Int, Set[Int]]
- private val startedWorkers = mutable.Set.empty[Int]
-  // worker nodes that have started, and are awaiting connections.
- private val awaitConnWorkers = mutable.Map.empty[Int, ActorRef]
- private val pendingFulfillment = mutable.Map.empty[Int, Fulfillment]
-
- def awaitingWorkers(linkSet: Set[Int]): AwaitingConnections = {
- val connSet = awaitConnWorkers.toMap
- .filterKeys(k => linkSet.contains(k))
- AwaitingConnections(connSet, linkSet.size - connSet.size)
- }
-
- def receive: Actor.Receive = {
- // a copy of the AssignedRank message that is also sent to the worker
- case AssignedRank(rank, tree_neighbors, ring, parent) =>
- // the workers that the worker of given `rank` depends on:
- // worker of rank K only depends on workers with rank smaller than K.
- val dependentWorkers = (tree_neighbors.toSet ++ Set(ring._1, ring._2))
- .filter{ r => r != -1 && r < rank}
-
- log.debug(s"Rank $rank connected, dependencies: $dependentWorkers")
- dependencyMap.put(rank, dependentWorkers)
-
- case RequestAwaitConnWorkers(rank, toConnectSet) =>
- val promise = Promise[AwaitingConnections]()
-
- assert(dependencyMap.contains(rank))
-
- val updatedDependency = dependencyMap(rank) diff startedWorkers
- if (updatedDependency.isEmpty) {
- // all dependencies are satisfied
- log.debug(s"Rank $rank has all dependencies satisfied.")
- promise.success(awaitingWorkers(toConnectSet))
- } else {
- log.debug(s"Rank $rank's request for AwaitConnWorkers is pending fulfillment.")
- // promise is pending fulfillment due to unresolved dependency
- pendingFulfillment.put(rank, Fulfillment(toConnectSet, promise))
- }
-
- sender() ! promise.future
-
- case WorkerStarted(_, started, awaitingAcceptance) =>
- startedWorkers.add(started)
- if (awaitingAcceptance > 0) {
- awaitConnWorkers.put(started, sender())
- }
-
- // remove the started rank from all dependencies.
- dependencyMap.remove(started)
- dependencyMap.foreach { case (r, dset) =>
- val updatedDependency = dset diff startedWorkers
- // fulfill the future if all dependencies are met (started.)
- if (updatedDependency.isEmpty) {
- log.debug(s"Rank $r has all dependencies satisfied.")
- pendingFulfillment.remove(r).map{
- case Fulfillment(toConnectSet, promise) =>
- promise.success(awaitingWorkers(toConnectSet))
- }
- }
-
- dependencyMap.update(r, updatedDependency)
- }
-
- case DropFromWaitingList(rank) =>
- assert(awaitConnWorkers.remove(rank).isDefined)
-
- case Terminated(ref) =>
- if (ref.equals(handler)) {
- context.stop(self)
- }
- }
-}
-
-private[scala] object RabitTrackerHandler {
- // Messages sent by RabitTracker to this RabitTrackerHandler actor
- trait TrackerControlMessage
- case object RequestCompletionFuture extends TrackerControlMessage
- case object RequestBoundFuture extends TrackerControlMessage
- // Start the Rabit tracker at given socket address awaiting worker connections.
- // All workers must connect to the tracker before connectionTimeout, otherwise the tracker will
- // shut down due to timeout.
- case class StartTracker(addr: InetSocketAddress,
- maxPortTrials: Int,
- connectionTimeout: Duration) extends TrackerControlMessage
- // To interrupt the tracker handler due to uncaught exception thrown by the thread acting as
- // driver for the distributed training.
- case class InterruptTracker(e: Throwable) extends TrackerControlMessage
-
- def props(numWorkers: Int): Props =
- Props(new RabitTrackerHandler(numWorkers))
-}
diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/handler/RabitWorkerHandler.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/handler/RabitWorkerHandler.scala
deleted file mode 100644
index 234c4d25a..000000000
--- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/handler/RabitWorkerHandler.scala
+++ /dev/null
@@ -1,467 +0,0 @@
-/*
- Copyright (c) 2014 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package ml.dmlc.xgboost4j.scala.rabit.handler
-
-import java.nio.{ByteBuffer, ByteOrder}
-
-import akka.io.Tcp
-import akka.actor._
-import akka.util.ByteString
-import ml.dmlc.xgboost4j.scala.rabit.util.{AssignedRank, RabitTrackerHelpers}
-
-import scala.concurrent.{Await, Future}
-import scala.concurrent.duration._
-import scala.util.Try
-
-/**
- * Actor to handle socket communication from worker node.
- * To handle fragmentation in received data, this class acts like a FSM
- * (finite-state machine) to keep track of the internal states.
- *
- * @param host IP address of the remote worker
- * @param worldSize number of total workers
- * @param tracker the RabitTrackerHandler actor reference
- */
-private[scala] class RabitWorkerHandler(host: String, worldSize: Int, tracker: ActorRef,
- connection: ActorRef)
- extends FSM[RabitWorkerHandler.State, RabitWorkerHandler.DataStruct]
- with ActorLogging with Stash {
-
- import RabitWorkerHandler._
- import RabitTrackerHelpers._
-
- private[this] var rank: Int = 0
- private[this] var port: Int = 0
-
- // indicate if the connection is transient (like "print" or "shutdown")
- private[this] var transient: Boolean = false
- private[this] var peerClosed: Boolean = false
-
- // number of workers pending acceptance of current worker
- private[this] var awaitingAcceptance: Int = 0
- private[this] var neighboringWorkers = Set.empty[Int]
-
- // TODO: use a single memory allocation to host all buffers,
- // including the transient ones for writing.
- private[this] val readBuffer = ByteBuffer.allocate(4096)
- .order(ByteOrder.nativeOrder())
- // in case the received message is longer than needed,
- // stash the spilled over part in this buffer, and send
- // to self when transition occurs.
- private[this] val spillOverBuffer = ByteBuffer.allocate(4096)
- .order(ByteOrder.nativeOrder())
- // when setup is complete, need to notify peer handlers
- // to reduce the awaiting-connection counter.
- private[this] var pendingAcknowledgement: Option[AcknowledgeAcceptance] = None
-
- private def resetBuffers(): Unit = {
- readBuffer.clear()
- if (spillOverBuffer.position() > 0) {
- spillOverBuffer.flip()
- self ! Tcp.Received(ByteString.fromByteBuffer(spillOverBuffer))
- spillOverBuffer.clear()
- }
- }
-
- private def stashSpillOver(buf: ByteBuffer): Unit = {
- if (buf.remaining() > 0) spillOverBuffer.put(buf)
- }
-
- def getNeighboringWorkers: Set[Int] = neighboringWorkers
-
- def decodeCommand(buffer: ByteBuffer): TrackerCommand = {
- val readBuffer = buffer.duplicate().order(ByteOrder.nativeOrder())
- readBuffer.flip()
-
- val rank = readBuffer.getInt()
- val worldSize = readBuffer.getInt()
- val jobId = readBuffer.getString
-
- val command = readBuffer.getString
- val trackerCommand = command match {
- case "start" => WorkerStart(rank, worldSize, jobId)
- case "shutdown" =>
- transient = true
- WorkerShutdown(rank, worldSize, jobId)
- case "recover" =>
- require(rank >= 0, "Invalid rank for recovering worker.")
- WorkerRecover(rank, worldSize, jobId)
- case "print" =>
- transient = true
- WorkerTrackerPrint(rank, worldSize, jobId, readBuffer.getString)
- }
-
- stashSpillOver(readBuffer)
- trackerCommand
- }
-
- startWith(AwaitingHandshake, DataStruct())
-
- when(AwaitingHandshake) {
- case Event(Tcp.Received(magic), _) =>
- assert(magic.length == 4)
- val purportedMagic = magic.asNativeOrderByteBuffer.getInt
- assert(purportedMagic == MAGIC_NUMBER, s"invalid magic number $purportedMagic from $host")
-
- // echo back the magic number
- connection ! Tcp.Write(magic)
- goto(AwaitingCommand) using StructTrackerCommand
- }
-
- when(AwaitingCommand) {
- case Event(Tcp.Received(bytes), validator) =>
- bytes.asByteBuffers.foreach { buf => readBuffer.put(buf) }
- if (validator.verify(readBuffer)) {
- Try(decodeCommand(readBuffer)) match {
- case scala.util.Success(decodedCommand) =>
- tracker ! decodedCommand
- case scala.util.Failure(th: java.nio.BufferUnderflowException) =>
- // BufferUnderflowException would occur if the message to print has not arrived yet.
- // Do nothing, wait for next Tcp.Received event
- case scala.util.Failure(th: Throwable) => throw th
- }
- }
-
- stay
- // when rank for a worker is assigned, send encoded rank information
- // back to worker over Tcp socket.
- case Event(aRank @ AssignedRank(assignedRank, neighbors, ring, parent), _) =>
- log.debug(s"Assigned rank [$assignedRank] for $host, T: $neighbors, R: $ring, P: $parent")
-
- rank = assignedRank
- // ranks from the ring
- val ringRanks = List(
- // ringPrev
- if (ring._1 != -1 && ring._1 != rank) ring._1 else -1,
- // ringNext
- if (ring._2 != -1 && ring._2 != rank) ring._2 else -1
- )
-
- // update the set of all linked workers to current worker.
- neighboringWorkers = neighbors.toSet ++ ringRanks.filterNot(_ == -1).toSet
-
- connection ! Tcp.Write(ByteString.fromByteBuffer(aRank.toByteBuffer(worldSize)))
- // to prevent reading before state transition
- connection ! Tcp.SuspendReading
- goto(BuildingLinkMap) using StructNodes
- }
-
- when(BuildingLinkMap) {
- case Event(Tcp.Received(bytes), validator) =>
- bytes.asByteBuffers.foreach { buf =>
- readBuffer.put(buf)
- }
-
- if (validator.verify(readBuffer)) {
- readBuffer.flip()
- // for a freshly started worker, numConnected should be 0.
- val numConnected = readBuffer.getInt()
- val toConnectSet = neighboringWorkers.diff(
- (0 until numConnected).map { index => readBuffer.getInt() }.toSet)
-
- // check which workers are currently awaiting connections
- tracker ! RequestAwaitConnWorkers(rank, toConnectSet)
- }
- stay
-
- // got a Future from the tracker (resolver) about workers that are
- // currently awaiting connections (particularly from this node.)
- case Event(future: Future[_], _) =>
-      // blocks execution until all dependencies for the current worker are resolved.
- Await.result(future, 1 minute).asInstanceOf[AwaitingConnections] match {
- // numNotReachable is the number of workers that currently
- // cannot be connected to (pending connection or setup). Instead, this worker will AWAIT
- // connections from those currently non-reachable nodes in the future.
- case AwaitingConnections(waitConnNodes, numNotReachable) =>
- log.debug(s"Rank $rank needs to connect to: $waitConnNodes, # bad: $numNotReachable")
- val buf = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder())
- buf.putInt(waitConnNodes.size).putInt(numNotReachable)
- buf.flip()
-
- // cache this message until the final state (SetupComplete)
- pendingAcknowledgement = Some(AcknowledgeAcceptance(
- waitConnNodes, numNotReachable))
-
- connection ! Tcp.Write(ByteString.fromByteBuffer(buf))
- if (waitConnNodes.isEmpty) {
- connection ! Tcp.SuspendReading
- goto(AwaitingErrorCount)
- }
- else {
- waitConnNodes.foreach { case (peerRank, peerRef) =>
- peerRef ! RequestWorkerHostPort
- }
-
- // a countdown for DivulgedHostPort messages.
- stay using DataStruct(Seq.empty[DataField], waitConnNodes.size - 1)
- }
- }
-
- case Event(DivulgedWorkerHostPort(peerRank, peerHost, peerPort), data) =>
- val hostBytes = peerHost.getBytes()
- val buffer = ByteBuffer.allocate(4 * 3 + hostBytes.length)
- .order(ByteOrder.nativeOrder())
- buffer.putInt(peerHost.length).put(hostBytes)
- .putInt(peerPort).putInt(peerRank)
-
- buffer.flip()
- connection ! Tcp.Write(ByteString.fromByteBuffer(buffer))
-
- if (data.counter == 0) {
- // to prevent reading before state transition
- connection ! Tcp.SuspendReading
- goto(AwaitingErrorCount)
- }
- else {
- stay using data.decrement()
- }
- }
-
- when(AwaitingErrorCount) {
- case Event(Tcp.Received(numErrors), _) =>
- val buf = numErrors.asNativeOrderByteBuffer
-
- buf.getInt match {
- case 0 =>
- stashSpillOver(buf)
- goto(AwaitingPortNumber)
- case _ =>
- stashSpillOver(buf)
- goto(BuildingLinkMap) using StructNodes
- }
- }
-
- when(AwaitingPortNumber) {
- case Event(Tcp.Received(assignedPort), _) =>
- assert(assignedPort.length == 4)
- port = assignedPort.asNativeOrderByteBuffer.getInt
- log.debug(s"Rank $rank listening @ $host:$port")
- // wait until the worker closes connection.
- if (peerClosed) goto(SetupComplete) else stay
-
- case Event(Tcp.PeerClosed, _) =>
- peerClosed = true
- if (port == 0) stay else goto(SetupComplete)
- }
-
- when(SetupComplete) {
- case Event(ReduceWaitCount(count: Int), _) =>
- awaitingAcceptance -= count
- // check peerClosed to avoid prematurely stopping this actor (which sends RST to worker)
- if (awaitingAcceptance == 0 && peerClosed) {
- tracker ! DropFromWaitingList(rank)
- // no longer needed.
- context.stop(self)
- }
- stay
-
- case Event(AcknowledgeAcceptance(peers, numBad), _) =>
- awaitingAcceptance = numBad
- tracker ! WorkerStarted(host, rank, awaitingAcceptance)
- peers.values.foreach { peer =>
- peer ! ReduceWaitCount(1)
- }
-
- if (awaitingAcceptance == 0 && peerClosed) self ! PoisonPill
-
- stay
-
- // can only divulge the complete host and port information
- // when this worker is declared fully connected (otherwise
- // port information is still missing.)
- case Event(RequestWorkerHostPort, _) =>
- sender() ! DivulgedWorkerHostPort(rank, host, port)
- stay
- }
-
- onTransition {
- // reset buffer when state transitions as data becomes stale
- case _ -> SetupComplete =>
- connection ! Tcp.ResumeReading
- resetBuffers()
- if (pendingAcknowledgement.isDefined) {
- self ! pendingAcknowledgement.get
- }
- case _ =>
- connection ! Tcp.ResumeReading
- resetBuffers()
- }
-
- // default message handler
- whenUnhandled {
- case Event(Tcp.PeerClosed, _) =>
- peerClosed = true
- if (transient) context.stop(self)
- stay
- }
-}
-
-private[scala] object RabitWorkerHandler {
- val MAGIC_NUMBER = 0xff99
-
- // Finite states of this actor, which acts like a FSM.
- // The following states are defined in order as the FSM progresses.
- sealed trait State
-
- // [1] Initial state, awaiting worker to send magic number per protocol.
- case object AwaitingHandshake extends State
- // [2] Awaiting worker to send command (start/print/recover/shutdown etc.)
- case object AwaitingCommand extends State
- // [3] Brokers connections between workers per ring/tree/parent link map.
- case object BuildingLinkMap extends State
- // [4] A transient state in which the worker reports the number of errors in establishing
- // connections to other peer workers. If no errors, transition to next state.
- case object AwaitingErrorCount extends State
- // [5] Awaiting the worker to report its port number for accepting connections from peer workers.
- // This port number information is later forwarded to linked workers.
- case object AwaitingPortNumber extends State
- // [6] Final state after completing the setup with the connecting worker. At this stage, the
- // worker will have closed the Tcp connection. The actor remains alive to handle messages from
- // peer actors representing workers with pending setups.
- case object SetupComplete extends State
-
- sealed trait DataField
- case object IntField extends DataField
- // an integer preceding the actual string
- case object StringField extends DataField
- case object IntSeqField extends DataField
-
- object DataStruct {
- def apply(): DataStruct = DataStruct(Seq.empty[DataField], 0)
- }
-
- // Internal data pertaining to individual state, used to verify the validity of packets sent by
- // workers.
- case class DataStruct(fields: Seq[DataField], counter: Int) {
- /**
- * Validate whether the provided buffer is complete (i.e., contains
- * all data fields specified for this DataStruct.)
- *
- * @param buf a byte buffer containing received data.
- */
- def verify(buf: ByteBuffer): Boolean = {
- if (fields.isEmpty) return true
-
- val dupBuf = buf.duplicate().order(ByteOrder.nativeOrder())
- dupBuf.flip()
-
- Try(fields.foldLeft(true) {
- case (complete, field) =>
- val remBytes = dupBuf.remaining()
- complete && (remBytes > 0) && (remBytes >= (field match {
- case IntField =>
- dupBuf.position(dupBuf.position() + 4)
- 4
- case StringField =>
- val strLen = dupBuf.getInt
- dupBuf.position(dupBuf.position() + strLen)
- 4 + strLen
- case IntSeqField =>
- val seqLen = dupBuf.getInt
- dupBuf.position(dupBuf.position() + seqLen * 4)
- 4 + seqLen * 4
- }))
- }).getOrElse(false)
- }
-
- def increment(): DataStruct = DataStruct(fields, counter + 1)
- def decrement(): DataStruct = DataStruct(fields, counter - 1)
- }
-
- val StructNodes = DataStruct(List(IntSeqField), 0)
- val StructTrackerCommand = DataStruct(List(
- IntField, IntField, StringField, StringField
- ), 0)
-
- // ---- Messages between RabitTrackerHandler and RabitTrackerConnectionHandler ----
-
- // RabitWorkerHandler --> RabitTrackerHandler
- sealed trait RabitWorkerRequest
- // RabitWorkerHandler <-- RabitTrackerHandler
- sealed trait RabitWorkerResponse
-
- // Representations of decoded worker commands.
- abstract class TrackerCommand(val command: String) extends RabitWorkerRequest {
- def rank: Int
- def worldSize: Int
- def jobId: String
-
- def encode: ByteString = {
- val buf = ByteBuffer.allocate(4 * 4 + jobId.length + command.length)
- .order(ByteOrder.nativeOrder())
-
- buf.putInt(rank).putInt(worldSize).putInt(jobId.length).put(jobId.getBytes())
- .putInt(command.length).put(command.getBytes()).flip()
-
- ByteString.fromByteBuffer(buf)
- }
- }
-
- case class WorkerStart(rank: Int, worldSize: Int, jobId: String)
- extends TrackerCommand("start")
- case class WorkerShutdown(rank: Int, worldSize: Int, jobId: String)
- extends TrackerCommand("shutdown")
- case class WorkerRecover(rank: Int, worldSize: Int, jobId: String)
- extends TrackerCommand("recover")
- case class WorkerTrackerPrint(rank: Int, worldSize: Int, jobId: String, msg: String)
- extends TrackerCommand("print") {
-
- override def encode: ByteString = {
- val buf = ByteBuffer.allocate(4 * 5 + jobId.length + command.length + msg.length)
- .order(ByteOrder.nativeOrder())
-
- buf.putInt(rank).putInt(worldSize).putInt(jobId.length).put(jobId.getBytes())
- .putInt(command.length).put(command.getBytes())
- .putInt(msg.length).put(msg.getBytes()).flip()
-
- ByteString.fromByteBuffer(buf)
- }
- }
-
- // Request to remove the worker of given rank from the list of workers awaiting peer connections.
- case class DropFromWaitingList(rank: Int) extends RabitWorkerRequest
- // Notify the tracker that the worker of given rank has finished setup and started.
- case class WorkerStarted(host: String, rank: Int, awaitingAcceptance: Int)
- extends RabitWorkerRequest
- // Request the set of workers to connect to, according to the LinkMap structure.
- case class RequestAwaitConnWorkers(rank: Int, toConnectSet: Set[Int])
- extends RabitWorkerRequest
-
- // Request, from the tracker, the set of nodes to connect.
- case class AwaitingConnections(workers: Map[Int, ActorRef], numBad: Int)
- extends RabitWorkerResponse
-
- // ---- Messages between ConnectionHandler actors ----
- sealed trait IntraWorkerMessage
-
- // Notify neighboring workers to decrease the counter of awaiting workers by `count`.
- case class ReduceWaitCount(count: Int) extends IntraWorkerMessage
-  // Request host and port information from peer ConnectionHandler actors (acting on behalf of
- // connecting workers.) This message will be brokered by RabitTrackerHandler.
- case object RequestWorkerHostPort extends IntraWorkerMessage
- // Response to the above request
- case class DivulgedWorkerHostPort(rank: Int, host: String, port: Int) extends IntraWorkerMessage
- // A reminder to send ReduceWaitCount messages once the actor is in state "SetupComplete".
- case class AcknowledgeAcceptance(peers: Map[Int, ActorRef], numBad: Int)
- extends IntraWorkerMessage
-
- // ---- End of message definitions ----
-
- def props(host: String, worldSize: Int, tracker: ActorRef, connection: ActorRef): Props = {
- Props(new RabitWorkerHandler(host, worldSize, tracker, connection))
- }
-}
diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/util/LinkMap.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/util/LinkMap.scala
deleted file mode 100644
index edec4931b..000000000
--- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/util/LinkMap.scala
+++ /dev/null
@@ -1,136 +0,0 @@
-/*
- Copyright (c) 2014 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package ml.dmlc.xgboost4j.scala.rabit.util
-
-import java.nio.{ByteBuffer, ByteOrder}
-
-/**
- * The assigned rank to a connecting Rabit worker, along with the information of the ranks of
- * its linked peer workers, which are critical to perform Allreduce.
- * When RabitWorkerHandler delegates "start" or "recover" commands from the connecting worker
- * client, RabitTrackerHandler utilizes LinkMap to figure out linkage relationships, and respond
- * with this class as a message, which is later encoded to byte string, and sent over socket
- * connection to the worker client.
- *
- * @param rank assigned rank (ranked by worker connection order: first worker connecting to the
- * tracker is assigned rank 0, second with rank 1, etc.)
- * @param neighbors ranks of neighboring workers in a tree map.
- * @param ring ranks of neighboring workers in a ring map.
- * @param parent rank of the parent worker.
- */
-private[rabit] case class AssignedRank(rank: Int, neighbors: Seq[Int],
- ring: (Int, Int), parent: Int) {
- /**
- * Encode the AssignedRank message into byte sequence for socket communication with Rabit worker
- * client.
- * @param worldSize the number of total distributed workers. Must match `numWorkers` used in
- * LinkMap.
- * @return a ByteBuffer containing encoded data.
- */
- def toByteBuffer(worldSize: Int): ByteBuffer = {
- val buffer = ByteBuffer.allocate(4 * (neighbors.length + 6)).order(ByteOrder.nativeOrder())
- buffer.putInt(rank).putInt(parent).putInt(worldSize).putInt(neighbors.length)
- // neighbors in tree structure
- neighbors.foreach { n => buffer.putInt(n) }
- buffer.putInt(if (ring._1 != -1 && ring._1 != rank) ring._1 else -1)
- buffer.putInt(if (ring._2 != -1 && ring._2 != rank) ring._2 else -1)
-
- buffer.flip()
- buffer
- }
-}
-
-private[rabit] class LinkMap(numWorkers: Int) {
- private def getNeighbors(rank: Int): Seq[Int] = {
- val rank1 = rank + 1
- Vector(rank1 / 2 - 1, rank1 * 2 - 1, rank1 * 2).filter { r =>
- r >= 0 && r < numWorkers
- }
- }
-
- /**
- * Construct a ring structure that tends to share nodes with the tree.
- *
- * @param treeMap
- * @param parentMap
- * @param rank
- * @return Seq[Int] instance starting from rank.
- */
- private def constructShareRing(treeMap: Map[Int, Seq[Int]],
- parentMap: Map[Int, Int],
- rank: Int = 0): Seq[Int] = {
- treeMap(rank).toSet - parentMap(rank) match {
- case emptySet if emptySet.isEmpty =>
- List(rank)
- case connectionSet =>
- connectionSet.zipWithIndex.foldLeft(List(rank)) {
- case (ringSeq, (v, cnt)) =>
- val vConnSeq = constructShareRing(treeMap, parentMap, v)
- vConnSeq match {
- case vconn if vconn.size == cnt + 1 =>
- ringSeq ++ vconn.reverse
- case vconn =>
- ringSeq ++ vconn
- }
- }
- }
- }
- /**
- * Construct a ring connection used to recover local data.
- *
- * @param treeMap
- * @param parentMap
- */
- private def constructRingMap(treeMap: Map[Int, Seq[Int]], parentMap: Map[Int, Int]) = {
- assert(parentMap(0) == -1)
-
- val sharedRing = constructShareRing(treeMap, parentMap, 0).toVector
- assert(sharedRing.length == treeMap.size)
-
- (0 until numWorkers).map { r =>
- val rPrev = (r + numWorkers - 1) % numWorkers
- val rNext = (r + 1) % numWorkers
- sharedRing(r) -> (sharedRing(rPrev), sharedRing(rNext))
- }.toMap
- }
-
- private[this] val treeMap_ = (0 until numWorkers).map { r => r -> getNeighbors(r) }.toMap
- private[this] val parentMap_ = (0 until numWorkers).map{ r => r -> ((r + 1) / 2 - 1) }.toMap
- private[this] val ringMap_ = constructRingMap(treeMap_, parentMap_)
- val rMap_ = (0 until (numWorkers - 1)).foldLeft((Map(0 -> 0), 0)) {
- case ((rmap, k), i) =>
- val kNext = ringMap_(k)._2
- (rmap ++ Map(kNext -> (i + 1)), kNext)
- }._1
-
- val ringMap = ringMap_.map {
- case (k, (v0, v1)) => rMap_(k) -> (rMap_(v0), rMap_(v1))
- }
- val treeMap = treeMap_.map {
- case (k, vSeq) => rMap_(k) -> vSeq.map{ v => rMap_(v) }
- }
- val parentMap = parentMap_.map {
- case (k, v) if k == 0 =>
- rMap_(k) -> -1
- case (k, v) =>
- rMap_(k) -> rMap_(v)
- }
-
- def assignRank(rank: Int): AssignedRank = {
- AssignedRank(rank, treeMap(rank), ringMap(rank), parentMap(rank))
- }
-}
diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/util/RabitTrackerHelpers.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/util/RabitTrackerHelpers.scala
deleted file mode 100644
index 3d7be618d..000000000
--- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/util/RabitTrackerHelpers.scala
+++ /dev/null
@@ -1,39 +0,0 @@
-/*
- Copyright (c) 2014 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package ml.dmlc.xgboost4j.scala.rabit.util
-
-import java.nio.{ByteOrder, ByteBuffer}
-import akka.util.ByteString
-
-private[rabit] object RabitTrackerHelpers {
- implicit class ByteStringHelplers(bs: ByteString) {
- // Java by default uses big endian. Enforce native endian so that
- // the byte order is consistent with the workers.
- def asNativeOrderByteBuffer: ByteBuffer = {
- bs.asByteBuffer.order(ByteOrder.nativeOrder())
- }
- }
-
- implicit class ByteBufferHelpers(buf: ByteBuffer) {
- def getString: String = {
- val len = buf.getInt()
- val stringBuffer = ByteBuffer.allocate(len).order(ByteOrder.nativeOrder())
- buf.get(stringBuffer.array(), 0, len)
- new String(stringBuffer.array(), "utf-8")
- }
- }
-}
diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java
index cce1254d0..20a243f5b 100644
--- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java
+++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java
@@ -30,8 +30,8 @@ import org.junit.Test;
* @author hzx
*/
public class BoosterImplTest {
- private String train_uri = "../../demo/data/agaricus.txt.train?indexing_mode=1";
- private String test_uri = "../../demo/data/agaricus.txt.test?indexing_mode=1";
+ private String train_uri = "../../demo/data/agaricus.txt.train?indexing_mode=1&format=libsvm";
+ private String test_uri = "../../demo/data/agaricus.txt.test?indexing_mode=1&format=libsvm";
public static class EvalError implements IEvaluation {
@Override
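Note: the test URIs above gain an explicit `format=libsvm` query parameter; text-file construction of a `DMatrix` now expects the format to be spelled out in the URI rather than guessed from the file contents. A minimal Scala-side sketch against the same demo dataset (assuming it is run from a directory where the relative path resolves):

```scala
import ml.dmlc.xgboost4j.scala.DMatrix

// Text input with the format stated explicitly; indexing_mode=1 matches
// the setting used by BoosterImplTest above.
val train = new DMatrix("../../demo/data/agaricus.txt.train?indexing_mode=1&format=libsvm")
println(s"rows loaded: ${train.rowNum}")
```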
diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java
index cf174c6dd..d658c5529 100644
--- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java
+++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java
@@ -4,7 +4,7 @@
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
-
+
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
@@ -88,7 +88,7 @@ public class DMatrixTest {
public void testCreateFromFile() throws XGBoostError {
//create DMatrix from file
String filePath = writeResourceIntoTempFile("/agaricus.txt.test");
- DMatrix dmat = new DMatrix(filePath);
+ DMatrix dmat = new DMatrix(filePath + "?format=libsvm");
//get label
float[] labels = dmat.getLabel();
//check length
diff --git a/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/DMatrixSuite.scala b/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/DMatrixSuite.scala
index 05200f49e..53325effa 100644
--- a/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/DMatrixSuite.scala
+++ b/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/DMatrixSuite.scala
@@ -20,12 +20,12 @@ import java.util.Arrays
import scala.util.Random
-import org.scalatest.FunSuite
+import org.scalatest.funsuite.AnyFunSuite
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
-class DMatrixSuite extends FunSuite {
+class DMatrixSuite extends AnyFunSuite {
test("create DMatrix from File") {
- val dmat = new DMatrix("../../demo/data/agaricus.txt.test")
+ val dmat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
// get label
val labels: Array[Float] = dmat.getLabel
// check length
diff --git a/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala b/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala
index 157971f82..2eda1fa2d 100644
--- a/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala
+++ b/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala
@@ -20,11 +20,11 @@ import java.io.{FileOutputStream, FileInputStream, File}
import junit.framework.TestCase
import org.apache.commons.logging.LogFactory
-import org.scalatest.FunSuite
+import org.scalatest.funsuite.AnyFunSuite
import ml.dmlc.xgboost4j.java.XGBoostError
-class ScalaBoosterImplSuite extends FunSuite {
+class ScalaBoosterImplSuite extends AnyFunSuite {
private class EvalError extends EvalTrait {
@@ -95,8 +95,8 @@ class ScalaBoosterImplSuite extends FunSuite {
}
test("basic operation of booster") {
- val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
- val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
+ val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
+ val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val booster = trainBooster(trainMat, testMat)
val predicts = booster.predict(testMat, true)
@@ -106,8 +106,8 @@ class ScalaBoosterImplSuite extends FunSuite {
test("save/load model with path") {
- val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
- val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
+ val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
+ val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val eval = new EvalError
val booster = trainBooster(trainMat, testMat)
// save and load
@@ -123,8 +123,8 @@ class ScalaBoosterImplSuite extends FunSuite {
}
test("save/load model with stream") {
- val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
- val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
+ val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
+ val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val eval = new EvalError
val booster = trainBooster(trainMat, testMat)
// save and load
@@ -139,7 +139,7 @@ class ScalaBoosterImplSuite extends FunSuite {
}
test("cross validation") {
- val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
+ val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val params = List("eta" -> "1.0", "max_depth" -> "3", "silent" -> "1", "nthread" -> "6",
"objective" -> "binary:logistic", "gamma" -> "1.0", "eval_metric" -> "error").toMap
val round = 2
@@ -148,8 +148,8 @@ class ScalaBoosterImplSuite extends FunSuite {
}
test("test with quantile histo depthwise") {
- val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
- val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
+ val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
+ val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val paramMap = List("max_depth" -> "3", "silent" -> "0",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "depthwise", "eval_metric" -> "auc").toMap
@@ -158,8 +158,8 @@ class ScalaBoosterImplSuite extends FunSuite {
}
test("test with quantile histo lossguide") {
- val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
- val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
+ val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
+ val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val paramMap = List("max_depth" -> "3", "silent" -> "0",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "lossguide", "max_leaves" -> "8", "eval_metric" -> "auc").toMap
@@ -168,8 +168,8 @@ class ScalaBoosterImplSuite extends FunSuite {
}
test("test with quantile histo lossguide with max bin") {
- val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
- val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
+ val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
+ val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val paramMap = List("max_depth" -> "3", "silent" -> "0",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "lossguide", "max_leaves" -> "8", "max_bin" -> "16",
@@ -179,8 +179,8 @@ class ScalaBoosterImplSuite extends FunSuite {
}
test("test with quantile histo depthwidth with max depth") {
- val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
- val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
+ val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
+ val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val paramMap = List("max_depth" -> "0", "silent" -> "0",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "depthwise", "max_leaves" -> "8", "max_depth" -> "2",
@@ -190,8 +190,8 @@ class ScalaBoosterImplSuite extends FunSuite {
}
test("test with quantile histo depthwidth with max depth and max bin") {
- val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
- val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
+ val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
+ val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val paramMap = List("max_depth" -> "0", "silent" -> "0",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "depthwise", "max_depth" -> "2", "max_bin" -> "2",
@@ -201,7 +201,7 @@ class ScalaBoosterImplSuite extends FunSuite {
}
test("test training from existing model in scala") {
- val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
+ val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val paramMap = List("max_depth" -> "0", "silent" -> "0",
"objective" -> "binary:logistic", "tree_method" -> "hist",
"grow_policy" -> "depthwise", "max_depth" -> "2", "max_bin" -> "2",
@@ -213,8 +213,8 @@ class ScalaBoosterImplSuite extends FunSuite {
}
test("test getting number of features from a booster") {
- val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
- val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
+ val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
+ val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val booster = trainBooster(trainMat, testMat)
TestCase.assertEquals(booster.getNumFeature, 127)
diff --git a/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTrackerConnectionHandlerTest.scala b/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTrackerConnectionHandlerTest.scala
deleted file mode 100644
index cd9016812..000000000
--- a/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTrackerConnectionHandlerTest.scala
+++ /dev/null
@@ -1,255 +0,0 @@
-/*
- Copyright (c) 2014 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package ml.dmlc.xgboost4j.scala.rabit
-
-import java.nio.{ByteBuffer, ByteOrder}
-
-import akka.actor.{ActorRef, ActorSystem}
-import akka.io.Tcp
-import akka.testkit.{ImplicitSender, TestFSMRef, TestKit, TestProbe}
-import akka.util.ByteString
-import ml.dmlc.xgboost4j.scala.rabit.handler.RabitWorkerHandler
-import ml.dmlc.xgboost4j.scala.rabit.handler.RabitWorkerHandler._
-import ml.dmlc.xgboost4j.scala.rabit.util.LinkMap
-import org.junit.runner.RunWith
-import org.scalatest.junit.JUnitRunner
-import org.scalatest.{FlatSpecLike, Matchers}
-
-import scala.concurrent.Promise
-
-object RabitTrackerConnectionHandlerTest {
- def intSeqToByteString(seq: Seq[Int]): ByteString = {
- val buf = ByteBuffer.allocate(seq.length * 4).order(ByteOrder.nativeOrder())
- seq.foreach { i => buf.putInt(i) }
- buf.flip()
- ByteString.fromByteBuffer(buf)
- }
-}
-
-@RunWith(classOf[JUnitRunner])
-class RabitTrackerConnectionHandlerTest
- extends TestKit(ActorSystem("RabitTrackerConnectionHandlerTest"))
- with FlatSpecLike with Matchers with ImplicitSender {
-
- import RabitTrackerConnectionHandlerTest._
-
- val magic = intSeqToByteString(List(0xff99))
-
- "RabitTrackerConnectionHandler" should "handle Rabit client 'start' command properly" in {
- val trackerProbe = TestProbe()
- val connProbe = TestProbe()
-
- val worldSize = 4
-
- val fsm = TestFSMRef(new RabitWorkerHandler("localhost", worldSize,
- trackerProbe.ref, connProbe.ref))
- fsm.stateName shouldEqual RabitWorkerHandler.AwaitingHandshake
-
- // send mock magic number
- fsm ! Tcp.Received(magic)
- connProbe.expectMsg(Tcp.Write(magic))
-
- fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand
- fsm.stateData shouldEqual RabitWorkerHandler.StructTrackerCommand
- // ResumeReading should be seen once state transitions
- connProbe.expectMsg(Tcp.ResumeReading)
-
- // send mock tracker command in fragments: the handler should be able to handle it.
- val bufRank = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder())
- bufRank.putInt(0).putInt(worldSize).flip()
-
- val bufJobId = ByteBuffer.allocate(5).order(ByteOrder.nativeOrder())
- bufJobId.putInt(1).put(Array[Byte]('0')).flip()
-
- val bufCmd = ByteBuffer.allocate(9).order(ByteOrder.nativeOrder())
- bufCmd.putInt(5).put("start".getBytes()).flip()
-
- fsm ! Tcp.Received(ByteString.fromByteBuffer(bufRank))
- fsm ! Tcp.Received(ByteString.fromByteBuffer(bufJobId))
-
- // the state should not change for incomplete command data.
- fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand
-
- // send the last fragment, and expect message at tracker actor.
- fsm ! Tcp.Received(ByteString.fromByteBuffer(bufCmd))
- trackerProbe.expectMsg(WorkerStart(0, worldSize, "0"))
-
- val linkMap = new LinkMap(worldSize)
- val assignedRank = linkMap.assignRank(0)
- trackerProbe.reply(assignedRank)
-
- connProbe.expectMsg(Tcp.Write(ByteString.fromByteBuffer(
- assignedRank.toByteBuffer(worldSize)
- )))
-
- // reading should be suspended upon transitioning to BuildingLinkMap
- connProbe.expectMsg(Tcp.SuspendReading)
- // state should transition with according state data changes.
- fsm.stateName shouldEqual RabitWorkerHandler.BuildingLinkMap
- fsm.stateData shouldEqual RabitWorkerHandler.StructNodes
- connProbe.expectMsg(Tcp.ResumeReading)
-
- // since the connection handler in test has rank 0, it will not have any nodes to connect to.
- fsm ! Tcp.Received(intSeqToByteString(List(0)))
- trackerProbe.expectMsg(RequestAwaitConnWorkers(0, fsm.underlyingActor.getNeighboringWorkers))
-
- // return mock response to the connection handler
- val awaitConnPromise = Promise[AwaitingConnections]()
- awaitConnPromise.success(AwaitingConnections(Map.empty[Int, ActorRef],
- fsm.underlyingActor.getNeighboringWorkers.size
- ))
- fsm ! awaitConnPromise.future
- connProbe.expectMsg(Tcp.Write(
- intSeqToByteString(List(0, fsm.underlyingActor.getNeighboringWorkers.size))
- ))
- connProbe.expectMsg(Tcp.SuspendReading)
- fsm.stateName shouldEqual RabitWorkerHandler.AwaitingErrorCount
- connProbe.expectMsg(Tcp.ResumeReading)
-
- // send mock error count (0)
- fsm ! Tcp.Received(intSeqToByteString(List(0)))
-
- fsm.stateName shouldEqual RabitWorkerHandler.AwaitingPortNumber
- connProbe.expectMsg(Tcp.ResumeReading)
-
- // simulate Tcp.PeerClosed event first, then Tcp.Received to test handling of async events.
- fsm ! Tcp.PeerClosed
- // state should not transition
- fsm.stateName shouldEqual RabitWorkerHandler.AwaitingPortNumber
- fsm ! Tcp.Received(intSeqToByteString(List(32768)))
-
- fsm.stateName shouldEqual RabitWorkerHandler.SetupComplete
- connProbe.expectMsg(Tcp.ResumeReading)
-
- trackerProbe.expectMsg(RabitWorkerHandler.WorkerStarted("localhost", 0, 2))
-
- val handlerStopProbe = TestProbe()
- handlerStopProbe watch fsm
-
- // simulate connections from other workers by mocking ReduceWaitCount commands
- fsm ! RabitWorkerHandler.ReduceWaitCount(1)
- fsm.stateName shouldEqual RabitWorkerHandler.SetupComplete
- fsm ! RabitWorkerHandler.ReduceWaitCount(1)
- trackerProbe.expectMsg(RabitWorkerHandler.DropFromWaitingList(0))
- handlerStopProbe.expectTerminated(fsm)
-
- // all done.
- }
-
- it should "forward print command to tracker" in {
- val trackerProbe = TestProbe()
- val connProbe = TestProbe()
-
- val fsm = TestFSMRef(new RabitWorkerHandler("localhost", 4,
- trackerProbe.ref, connProbe.ref))
- fsm.stateName shouldEqual RabitWorkerHandler.AwaitingHandshake
-
- fsm ! Tcp.Received(magic)
- connProbe.expectMsg(Tcp.Write(magic))
-
- fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand
- fsm.stateData shouldEqual RabitWorkerHandler.StructTrackerCommand
- // ResumeReading should be seen once state transitions
- connProbe.expectMsg(Tcp.ResumeReading)
-
- val printCmd = WorkerTrackerPrint(0, 4, "print", "hello world!")
- fsm ! Tcp.Received(printCmd.encode)
-
- trackerProbe.expectMsg(printCmd)
- }
-
- it should "handle fragmented print command without throwing exception" in {
- val trackerProbe = TestProbe()
- val connProbe = TestProbe()
-
- val fsm = TestFSMRef(new RabitWorkerHandler("localhost", 4,
- trackerProbe.ref, connProbe.ref))
- fsm.stateName shouldEqual RabitWorkerHandler.AwaitingHandshake
-
- fsm ! Tcp.Received(magic)
- connProbe.expectMsg(Tcp.Write(magic))
-
- fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand
- fsm.stateData shouldEqual RabitWorkerHandler.StructTrackerCommand
- // ResumeReading should be seen once state transitions
- connProbe.expectMsg(Tcp.ResumeReading)
-
- val printCmd = WorkerTrackerPrint(0, 4, "0", "fragmented!")
- // 4 (rank: Int) + 4 (worldSize: Int) + (4+1) (jobId: String) + (4+5) (command: String) = 22
- val (partialMessage, remainder) = printCmd.encode.splitAt(22)
-
- // make sure that the partialMessage in itself is a valid command
- val partialMsgBuf = ByteBuffer.allocate(22).order(ByteOrder.nativeOrder())
- partialMsgBuf.put(partialMessage.asByteBuffer)
- RabitWorkerHandler.StructTrackerCommand.verify(partialMsgBuf) shouldBe true
-
- fsm ! Tcp.Received(partialMessage)
- fsm ! Tcp.Received(remainder)
-
- trackerProbe.expectMsg(printCmd)
- }
-
- it should "handle spill-over Tcp data correctly between state transition" in {
- val trackerProbe = TestProbe()
- val connProbe = TestProbe()
-
- val worldSize = 4
-
- val fsm = TestFSMRef(new RabitWorkerHandler("localhost", worldSize,
- trackerProbe.ref, connProbe.ref))
- fsm.stateName shouldEqual RabitWorkerHandler.AwaitingHandshake
-
- // send mock magic number
- fsm ! Tcp.Received(magic)
- connProbe.expectMsg(Tcp.Write(magic))
-
- fsm.stateName shouldEqual RabitWorkerHandler.AwaitingCommand
- fsm.stateData shouldEqual RabitWorkerHandler.StructTrackerCommand
- // ResumeReading should be seen once state transitions
- connProbe.expectMsg(Tcp.ResumeReading)
-
- // send mock tracker command in fragments: the handler should be able to handle it.
- val bufCmd = ByteBuffer.allocate(26).order(ByteOrder.nativeOrder())
- bufCmd.putInt(0).putInt(worldSize).putInt(1).put(Array[Byte]('0'))
- .putInt(5).put("start".getBytes())
- // spilled-over data
- .putInt(0).flip()
-
- // send data with 4 extra bytes corresponding to the next state.
- fsm ! Tcp.Received(ByteString.fromByteBuffer(bufCmd))
-
- trackerProbe.expectMsg(WorkerStart(0, worldSize, "0"))
-
- val linkMap = new LinkMap(worldSize)
- val assignedRank = linkMap.assignRank(0)
- trackerProbe.reply(assignedRank)
-
- connProbe.expectMsg(Tcp.Write(ByteString.fromByteBuffer(
- assignedRank.toByteBuffer(worldSize)
- )))
-
- // reading should be suspended upon transitioning to BuildingLinkMap
- connProbe.expectMsg(Tcp.SuspendReading)
- // state should transition with according state data changes.
- fsm.stateName shouldEqual RabitWorkerHandler.BuildingLinkMap
- fsm.stateData shouldEqual RabitWorkerHandler.StructNodes
- connProbe.expectMsg(Tcp.ResumeReading)
-
- // the handler should be able to handle spill-over data, and stash it until state transition.
- trackerProbe.expectMsg(RequestAwaitConnWorkers(0, fsm.underlyingActor.getNeighboringWorkers))
- }
-}
diff --git a/plugin/federated/README.md b/plugin/federated/README.md
index d83db6be1..631c44cee 100644
--- a/plugin/federated/README.md
+++ b/plugin/federated/README.md
@@ -19,7 +19,7 @@ cmake .. -GNinja \
-DUSE_NCCL=ON
ninja
cd ../python-package
-pip install -e . # or equivalently python setup.py develop
+pip install -e .
```
If CMake fails to locate gRPC, you may need to pass `-DCMAKE_PREFIX_PATH=` to CMake.
diff --git a/python-package/MANIFEST.in b/python-package/MANIFEST.in
deleted file mode 100644
index 23f2684c2..000000000
--- a/python-package/MANIFEST.in
+++ /dev/null
@@ -1,56 +0,0 @@
-include README.rst
-include xgboost/LICENSE
-include xgboost/VERSION
-include xgboost/CMakeLists.txt
-
-include xgboost/py.typed
-recursive-include xgboost *.py
-recursive-include xgboost/cmake *
-exclude xgboost/cmake/RPackageInstall.cmake.in
-exclude xgboost/cmake/RPackageInstallTargetSetup.cmake
-exclude xgboost/cmake/Sanitizer.cmake
-exclude xgboost/cmake/modules/FindASan.cmake
-exclude xgboost/cmake/modules/FindLSan.cmake
-exclude xgboost/cmake/modules/FindLibR.cmake
-exclude xgboost/cmake/modules/FindTSan.cmake
-exclude xgboost/cmake/modules/FindUBSan.cmake
-recursive-include xgboost/include *
-recursive-include xgboost/plugin *
-recursive-include xgboost/src *
-
-recursive-include xgboost/gputreeshap/GPUTreeShap *
-
-include xgboost/rabit/CMakeLists.txt
-recursive-include xgboost/rabit/include *
-recursive-include xgboost/rabit/src *
-prune xgboost/rabit/doc
-prune xgboost/rabit/guide
-
-include xgboost/dmlc-core/CMakeLists.txt
-
-recursive-include xgboost/dmlc-core/cmake *
-exclude xgboost/dmlc-core/cmake/gtest_cmake.in
-exclude xgboost/dmlc-core/cmake/lint.cmake
-exclude xgboost/dmlc-core/cmake/Sanitizer.cmake
-exclude xgboost/dmlc-core/cmake/Modules/FindASan.cmake
-exclude xgboost/dmlc-core/cmake/Modules/FindLSan.cmake
-exclude xgboost/dmlc-core/cmake/Modules/FindTSan.cmake
-exclude xgboost/dmlc-core/cmake/Modules/FindUBSan.cmake
-
-recursive-include xgboost/dmlc-core/include *
-recursive-include xgboost/dmlc-core/include *
-recursive-include xgboost/dmlc-core/make *
-recursive-include xgboost/dmlc-core/src *
-include xgboost/dmlc-core/tracker/dmlc-submit
-recursive-include xgboost/dmlc-core/tracker/dmlc_tracker *.py
-include xgboost/dmlc-core/tracker/yarn/build.bat
-include xgboost/dmlc-core/tracker/yarn/build.sh
-include xgboost/dmlc-core/tracker/yarn/pom.xml
-recursive-include xgboost/dmlc-core/tracker/yarn/src *
-include xgboost/dmlc-core/windows/dmlc.sln
-include xgboost/dmlc-core/windows/dmlc/dmlc.vcxproj
-
-prune xgboost/dmlc-core/doc
-prune xgboost/dmlc-core/scripts/
-
-global-exclude *.py[oc]
diff --git a/python-package/hatch_build.py b/python-package/hatch_build.py
new file mode 100644
index 000000000..696787fa2
--- /dev/null
+++ b/python-package/hatch_build.py
@@ -0,0 +1,22 @@
+"""
+Custom hook to customize the behavior of Hatchling.
+Here, we customize the tag of the generated wheels.
+"""
+import sysconfig
+from typing import Any, Dict
+
+from hatchling.builders.hooks.plugin.interface import BuildHookInterface
+
+
+def get_tag() -> str:
+ """Get appropriate wheel tag according to system"""
+ tag_platform = sysconfig.get_platform().replace("-", "_").replace(".", "_")
+ return f"py3-none-{tag_platform}"
+
+
+class CustomBuildHook(BuildHookInterface):
+ """A custom build hook"""
+
+ def initialize(self, version: str, build_data: Dict[str, Any]) -> None:
+ """This step ccurs immediately before each build."""
+ build_data["tag"] = get_tag()
diff --git a/python-package/packager/__init__.py b/python-package/packager/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/python-package/packager/build_config.py b/python-package/packager/build_config.py
new file mode 100644
index 000000000..290cf15db
--- /dev/null
+++ b/python-package/packager/build_config.py
@@ -0,0 +1,56 @@
+"""Build configuration"""
+import dataclasses
+from typing import Any, Dict, List, Optional
+
+
+@dataclasses.dataclass
+class BuildConfiguration: # pylint: disable=R0902
+ """Configurations use when building libxgboost"""
+
+ # Whether to hide C++ symbols in libxgboost.so
+ hide_cxx_symbols: bool = True
+ # Whether to enable OpenMP
+ use_openmp: bool = True
+ # Whether to enable CUDA
+ use_cuda: bool = False
+ # Whether to enable NCCL
+ use_nccl: bool = False
+ # Whether to enable HDFS
+ use_hdfs: bool = False
+ # Whether to enable Azure Storage
+ use_azure: bool = False
+ # Whether to enable AWS S3
+ use_s3: bool = False
+ # Whether to enable the dense parser plugin
+ plugin_dense_parser: bool = False
+ # Special option: See explanation below
+ use_system_libxgboost: bool = False
+
+ def _set_config_setting(
+ self, config_settings: Dict[str, Any], field_name: str
+ ) -> None:
+ if field_name not in [x.name for x in dataclasses.fields(self)]:
+ raise ValueError(f"Field {field_name} is not a valid config_settings")
+ setattr(
+ self,
+ field_name,
+ (config_settings[field_name].lower() in ["true", "1", "on"]),
+ )
+
+ def update(self, config_settings: Optional[Dict[str, Any]]) -> None:
+ """Parse config_settings from Pip (or another PEP 517 frontend)"""
+ if config_settings is not None:
+ for field_name in config_settings:
+ self._set_config_setting(config_settings, field_name)
+
+ def get_cmake_args(self) -> List[str]:
+ """Convert build configuration to CMake args"""
+ cmake_args = []
+ for field_name in [x.name for x in dataclasses.fields(self)]:
+ if field_name in ["use_system_libxgboost"]:
+ continue
+ cmake_option = field_name.upper()
+ cmake_value = "ON" if getattr(self, field_name) is True else "OFF"
+ cmake_args.append(f"-D{cmake_option}={cmake_value}")
+ return cmake_args
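A hedged usage sketch of the configuration object: the field names double as the option keys accepted from a PEP 517 frontend, and each field becomes a `-D<OPTION>=ON/OFF` CMake define (except `use_system_libxgboost`, which is handled separately). The dictionary below stands in for whatever `--config-settings` values pip would forward:

```python
from packager.build_config import BuildConfiguration

# Assumes it is run from python-package/ so the packager module is importable.
config = BuildConfiguration()
config.update({"use_cuda": "True", "use_nccl": "on"})  # string values, case-insensitive
print(config.get_cmake_args())
# e.g. ['-DHIDE_CXX_SYMBOLS=ON', '-DUSE_OPENMP=ON', '-DUSE_CUDA=ON', '-DUSE_NCCL=ON', ...]
```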
diff --git a/python-package/packager/nativelib.py b/python-package/packager/nativelib.py
new file mode 100644
index 000000000..f7f5b4e79
--- /dev/null
+++ b/python-package/packager/nativelib.py
@@ -0,0 +1,157 @@
+"""
+Functions for building libxgboost
+"""
+import logging
+import os
+import pathlib
+import shutil
+import subprocess
+import sys
+from platform import system
+from typing import Optional
+
+from .build_config import BuildConfiguration
+
+
+def _lib_name() -> str:
+ """Return platform dependent shared object name."""
+ if system() in ["Linux", "OS400"] or system().upper().endswith("BSD"):
+ name = "libxgboost.so"
+ elif system() == "Darwin":
+ name = "libxgboost.dylib"
+ elif system() == "Windows":
+ name = "xgboost.dll"
+ else:
+ raise NotImplementedError(f"System {system()} not supported")
+ return name
+
+
+def build_libxgboost(
+ cpp_src_dir: pathlib.Path,
+ build_dir: pathlib.Path,
+ build_config: BuildConfiguration,
+) -> pathlib.Path:
+ """Build libxgboost in a temporary directory and obtain the path to built libxgboost"""
+ logger = logging.getLogger("xgboost.packager.build_libxgboost")
+
+ if not cpp_src_dir.is_dir():
+ raise RuntimeError(f"Expected {cpp_src_dir} to be a directory")
+ logger.info(
+ "Building %s from the C++ source files in %s...", _lib_name(), str(cpp_src_dir)
+ )
+
+ def _build(*, generator: str) -> None:
+ cmake_cmd = [
+ "cmake",
+ str(cpp_src_dir),
+ generator,
+ "-DKEEP_BUILD_ARTIFACTS_IN_BINARY_DIR=ON",
+ ]
+ cmake_cmd.extend(build_config.get_cmake_args())
+
+ # Flag for cross-compiling for Apple Silicon
+ # We use environment variable because it's the only way to pass down custom flags
+ # through the cibuildwheel package, which calls `pip wheel` command.
+ if "CIBW_TARGET_OSX_ARM64" in os.environ:
+ cmake_cmd.append("-DCMAKE_OSX_ARCHITECTURES=arm64")
+
+ logger.info("CMake args: %s", str(cmake_cmd))
+ subprocess.check_call(cmake_cmd, cwd=build_dir)
+
+ if system() == "Windows":
+ subprocess.check_call(
+ ["cmake", "--build", ".", "--config", "Release"], cwd=build_dir
+ )
+ else:
+ nproc = os.cpu_count()
+ assert build_tool is not None
+ subprocess.check_call([build_tool, f"-j{nproc}"], cwd=build_dir)
+
+ if system() == "Windows":
+ supported_generators = (
+ "-GVisual Studio 17 2022",
+ "-GVisual Studio 16 2019",
+ "-GVisual Studio 15 2017",
+ "-GMinGW Makefiles",
+ )
+ for generator in supported_generators:
+ try:
+ _build(generator=generator)
+ logger.info(
+ "Successfully built %s using generator %s", _lib_name(), generator
+ )
+ break
+ except subprocess.CalledProcessError as e:
+ logger.info(
+ "Tried building with generator %s but failed with exception %s",
+ generator,
+ str(e),
+ )
+ # Empty build directory
+ shutil.rmtree(build_dir)
+ build_dir.mkdir()
+ else:
+ raise RuntimeError(
+ "None of the supported generators produced a successful build!"
+ f"Supported generators: {supported_generators}"
+ )
+ else:
+ build_tool = "ninja" if shutil.which("ninja") else "make"
+ generator = "-GNinja" if build_tool == "ninja" else "-GUnix Makefiles"
+ try:
+ _build(generator=generator)
+ except subprocess.CalledProcessError as e:
+ logger.info("Failed to build with OpenMP. Exception: %s", str(e))
+ build_config.use_openmp = False
+ _build(generator=generator)
+
+ return build_dir / "lib" / _lib_name()
+
+
+def locate_local_libxgboost(
+ toplevel_dir: pathlib.Path,
+ logger: logging.Logger,
+) -> Optional[pathlib.Path]:
+ """
+ Locate libxgboost from the local project directory's lib/ subdirectory.
+ """
+ libxgboost = toplevel_dir.parent / "lib" / _lib_name()
+ if libxgboost.exists():
+ logger.info("Found %s at %s", libxgboost.name, str(libxgboost.parent))
+ return libxgboost
+ return None
+
+
+def locate_or_build_libxgboost(
+ toplevel_dir: pathlib.Path,
+ build_dir: pathlib.Path,
+ build_config: BuildConfiguration,
+) -> pathlib.Path:
+ """Locate libxgboost; if not exist, build it"""
+ logger = logging.getLogger("xgboost.packager.locate_or_build_libxgboost")
+
+ libxgboost = locate_local_libxgboost(toplevel_dir, logger=logger)
+ if libxgboost is not None:
+ return libxgboost
+ if build_config.use_system_libxgboost:
+ # Find libxgboost from system prefix
+ sys_prefix = pathlib.Path(sys.prefix).absolute().resolve()
+ libxgboost = sys_prefix / "lib" / _lib_name()
+ if not libxgboost.exists():
+ raise RuntimeError(
+ f"use_system_libxgboost was specified but {_lib_name()} is "
+ f"not found in {libxgboost.parent}"
+ )
+
+ logger.info("Using system XGBoost: %s", str(libxgboost))
+ return libxgboost
+
+ if toplevel_dir.joinpath("cpp_src").exists():
+ # Source distribution; all C++ source files to be found in cpp_src/
+ cpp_src_dir = toplevel_dir.joinpath("cpp_src")
+ else:
+ # Probably running "pip install ." from python-package/
+ cpp_src_dir = toplevel_dir.parent
+ if not cpp_src_dir.joinpath("CMakeLists.txt").exists():
+ raise RuntimeError(f"Did not find CMakeLists.txt from {cpp_src_dir}")
+ return build_libxgboost(cpp_src_dir, build_dir=build_dir, build_config=build_config)
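The lookup order implemented above is: a `lib/` directory next to the Python package left behind by a previous CMake build, then the interpreter prefix when `use_system_libxgboost` is set, and finally a fresh CMake build in a scratch directory. A sketch of driving it directly, assuming the working directory is `python-package/` inside a source checkout:

```python
import logging
import pathlib
import tempfile

from packager.build_config import BuildConfiguration
from packager.nativelib import locate_or_build_libxgboost

logging.basicConfig(level=logging.INFO)

toplevel = pathlib.Path(".").absolute().resolve()  # python-package/
with tempfile.TemporaryDirectory() as td:
    lib = locate_or_build_libxgboost(
        toplevel, build_dir=pathlib.Path(td), build_config=BuildConfiguration()
    )
    print("libxgboost at:", lib)
```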
diff --git a/python-package/packager/pep517.py b/python-package/packager/pep517.py
new file mode 100644
index 000000000..56583e117
--- /dev/null
+++ b/python-package/packager/pep517.py
@@ -0,0 +1,157 @@
+"""
+Custom build backend for XGBoost Python package.
+Builds source distribution and binary wheels, following PEP 517 / PEP 660.
+Reuses components of Hatchling (https://github.com/pypa/hatch/tree/master/backend) for the sake
+of brevity.
+"""
+import dataclasses
+import logging
+import os
+import pathlib
+import tempfile
+from contextlib import contextmanager
+from typing import Any, Dict, Iterator, Optional, Union
+
+import hatchling.build
+
+from .build_config import BuildConfiguration
+from .nativelib import locate_local_libxgboost, locate_or_build_libxgboost
+from .sdist import copy_cpp_src_tree
+from .util import copy_with_logging, copytree_with_logging
+
+
+@contextmanager
+def cd(path: Union[str, pathlib.Path]) -> Iterator[str]: # pylint: disable=C0103
+ """
+ Temporarily change working directory.
+ TODO(hcho3): Remove this once we adopt Python 3.11, which implements contextlib.chdir.
+ """
+ path = str(path)
+ path = os.path.realpath(path)
+ cwd = os.getcwd()
+ os.chdir(path)
+ try:
+ yield path
+ finally:
+ os.chdir(cwd)
+
+
+TOPLEVEL_DIR = pathlib.Path(__file__).parent.parent.absolute().resolve()
+logging.basicConfig(level=logging.INFO)
+
+
+# Aliases
+get_requires_for_build_sdist = hatchling.build.get_requires_for_build_sdist
+get_requires_for_build_wheel = hatchling.build.get_requires_for_build_wheel
+get_requires_for_build_editable = hatchling.build.get_requires_for_build_editable
+
+
+def build_wheel(
+ wheel_directory: str,
+ config_settings: Optional[Dict[str, Any]] = None,
+ metadata_directory: Optional[str] = None,
+) -> str:
+ """Build a wheel"""
+ logger = logging.getLogger("xgboost.packager.build_wheel")
+
+ build_config = BuildConfiguration()
+ build_config.update(config_settings)
+ logger.info("Parsed build configuration: %s", dataclasses.asdict(build_config))
+
+ # Create tempdir with Python package + libxgboost
+ with tempfile.TemporaryDirectory() as td:
+ td_path = pathlib.Path(td)
+ build_dir = td_path / "libbuild"
+ build_dir.mkdir()
+
+ workspace = td_path / "whl_workspace"
+ workspace.mkdir()
+ logger.info("Copying project files to temporary directory %s", str(workspace))
+
+ copy_with_logging(TOPLEVEL_DIR / "pyproject.toml", workspace, logger=logger)
+ copy_with_logging(TOPLEVEL_DIR / "hatch_build.py", workspace, logger=logger)
+ copy_with_logging(TOPLEVEL_DIR / "README.rst", workspace, logger=logger)
+
+ pkg_path = workspace / "xgboost"
+ copytree_with_logging(TOPLEVEL_DIR / "xgboost", pkg_path, logger=logger)
+ lib_path = pkg_path / "lib"
+ lib_path.mkdir()
+ libxgboost = locate_or_build_libxgboost(
+ TOPLEVEL_DIR, build_dir=build_dir, build_config=build_config
+ )
+ copy_with_logging(libxgboost, lib_path, logger=logger)
+
+ with cd(workspace):
+ wheel_name = hatchling.build.build_wheel(
+ wheel_directory, config_settings, metadata_directory
+ )
+ return wheel_name
+
+
+def build_sdist(
+ sdist_directory: str,
+ config_settings: Optional[Dict[str, Any]] = None,
+) -> str:
+ """Build a source distribution"""
+ logger = logging.getLogger("xgboost.packager.build_sdist")
+
+ if config_settings:
+ raise NotImplementedError(
+ "XGBoost's custom build backend doesn't support config_settings option "
+ f"when building sdist. {config_settings=}"
+ )
+
+ cpp_src_dir = TOPLEVEL_DIR.parent
+ if not cpp_src_dir.joinpath("CMakeLists.txt").exists():
+ raise RuntimeError(f"Did not find CMakeLists.txt from {cpp_src_dir}")
+
+ # Create tempdir with Python package + C++ sources
+ with tempfile.TemporaryDirectory() as td:
+ td_path = pathlib.Path(td)
+
+ workspace = td_path / "sdist_workspace"
+ workspace.mkdir()
+ logger.info("Copying project files to temporary directory %s", str(workspace))
+
+ copy_with_logging(TOPLEVEL_DIR / "pyproject.toml", workspace, logger=logger)
+ copy_with_logging(TOPLEVEL_DIR / "hatch_build.py", workspace, logger=logger)
+ copy_with_logging(TOPLEVEL_DIR / "README.rst", workspace, logger=logger)
+
+ copytree_with_logging(
+ TOPLEVEL_DIR / "xgboost", workspace / "xgboost", logger=logger
+ )
+ copytree_with_logging(
+ TOPLEVEL_DIR / "packager", workspace / "packager", logger=logger
+ )
+
+ temp_cpp_src_dir = workspace / "cpp_src"
+ copy_cpp_src_tree(cpp_src_dir, target_dir=temp_cpp_src_dir, logger=logger)
+
+ with cd(workspace):
+ sdist_name = hatchling.build.build_sdist(sdist_directory, config_settings)
+ return sdist_name
+
+
+def build_editable(
+ wheel_directory: str,
+ config_settings: Optional[Dict[str, Any]] = None,
+ metadata_directory: Optional[str] = None,
+) -> str:
+ """Build an editable installation. We mostly delegate to Hatchling."""
+ logger = logging.getLogger("xgboost.packager.build_editable")
+
+ if config_settings:
+ raise NotImplementedError(
+ "XGBoost's custom build backend doesn't support config_settings option "
+ f"when building editable installation. {config_settings=}"
+ )
+
+ if locate_local_libxgboost(TOPLEVEL_DIR, logger=logger) is None:
+ raise RuntimeError(
+ "To use the editable installation, first build libxgboost with CMake. "
+ "See https://xgboost.readthedocs.io/en/latest/build.html for detailed instructions."
+ )
+
+ return hatchling.build.build_editable(
+ wheel_directory, config_settings, metadata_directory
+ )
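Frontends such as pip or `python -m build` reach these hooks through `pyproject.toml`, but they can also be invoked directly, which is convenient when debugging the packaging logic. A sketch, assuming it runs from `python-package/` (building the wheel triggers a full native build unless libxgboost is already present):

```python
import tempfile

from packager import pep517

with tempfile.TemporaryDirectory() as out_dir:
    # Same entry points a PEP 517 frontend would call.
    sdist_name = pep517.build_sdist(out_dir)
    print("built sdist:", sdist_name)

    # config_settings mirrors pip's --config-settings key=value pairs.
    wheel_name = pep517.build_wheel(out_dir, config_settings={"use_openmp": "True"})
    print("built wheel:", wheel_name)
```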
diff --git a/python-package/packager/sdist.py b/python-package/packager/sdist.py
new file mode 100644
index 000000000..af9fbca0d
--- /dev/null
+++ b/python-package/packager/sdist.py
@@ -0,0 +1,27 @@
+"""
+Functions for building sdist
+"""
+import logging
+import pathlib
+
+from .util import copy_with_logging, copytree_with_logging
+
+
+def copy_cpp_src_tree(
+ cpp_src_dir: pathlib.Path, target_dir: pathlib.Path, logger: logging.Logger
+) -> None:
+ """Copy C++ source tree into build directory"""
+
+ for subdir in [
+ "src",
+ "include",
+ "dmlc-core",
+ "gputreeshap",
+ "rabit",
+ "cmake",
+ "plugin",
+ ]:
+ copytree_with_logging(cpp_src_dir / subdir, target_dir / subdir, logger=logger)
+
+ for filename in ["CMakeLists.txt", "LICENSE"]:
+ copy_with_logging(cpp_src_dir.joinpath(filename), target_dir, logger=logger)
diff --git a/python-package/packager/util.py b/python-package/packager/util.py
new file mode 100644
index 000000000..0fff062d7
--- /dev/null
+++ b/python-package/packager/util.py
@@ -0,0 +1,25 @@
+"""
+Utility functions for implementing PEP 517 backend
+"""
+import logging
+import pathlib
+import shutil
+
+
+def copytree_with_logging(
+ src: pathlib.Path, dest: pathlib.Path, logger: logging.Logger
+) -> None:
+ """Call shutil.copytree() with logging"""
+ logger.info("Copying %s -> %s", str(src), str(dest))
+ shutil.copytree(src, dest)
+
+
+def copy_with_logging(
+ src: pathlib.Path, dest: pathlib.Path, logger: logging.Logger
+) -> None:
+ """Call shutil.copy() with logging"""
+ if dest.is_dir():
+ logger.info("Copying %s -> %s", str(src), str(dest / src.name))
+ else:
+ logger.info("Copying %s -> %s", str(src), str(dest))
+ shutil.copy(src, dest)
diff --git a/python-package/pyproject.toml b/python-package/pyproject.toml
new file mode 100644
index 000000000..8f120df5d
--- /dev/null
+++ b/python-package/pyproject.toml
@@ -0,0 +1,42 @@
+[build-system]
+requires = [
+ "hatchling>=1.12.1"
+]
+backend-path = ["."]
+build-backend = "packager.pep517"
+
+[project]
+name = "xgboost"
+version = "2.0.0-dev"
+authors = [
+ {name = "Hyunsu Cho", email = "chohyu01@cs.washington.edu"},
+ {name = "Jiaming Yuan", email = "jm.yuan@outlook.com"}
+]
+description = "XGBoost Python Package"
+readme = {file = "README.rst", content-type = "text/x-rst"}
+requires-python = ">=3.8"
+license = {text = "Apache-2.0"}
+classifiers = [
+ "License :: OSI Approved :: Apache Software License",
+ "Development Status :: 5 - Production/Stable",
+ "Operating System :: OS Independent",
+ "Programming Language :: Python",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.8",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10"
+]
+dependencies = [
+ "numpy",
+ "scipy"
+]
+
+[project.optional-dependencies]
+pandas = ["pandas"]
+scikit-learn = ["scikit-learn"]
+dask = ["dask", "pandas", "distributed"]
+datatable = ["datatable"]
+plotting = ["graphviz", "matplotlib"]
+pyspark = ["pyspark", "scikit-learn", "cloudpickle"]
+
+[tool.hatch.build.targets.wheel.hooks.custom]
diff --git a/python-package/xgboost/config.py b/python-package/xgboost/config.py
index c08a13150..1691d473f 100644
--- a/python-package/xgboost/config.py
+++ b/python-package/xgboost/config.py
@@ -16,7 +16,7 @@ def config_doc(
extra_note: Optional[str] = None,
parameters: Optional[str] = None,
returns: Optional[str] = None,
- see_also: Optional[str] = None
+ see_also: Optional[str] = None,
) -> Callable[[_F], _F]:
"""Decorator to format docstring for config functions.
diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py
index 88bd1c819..35c5c009f 100644
--- a/python-package/xgboost/dask.py
+++ b/python-package/xgboost/dask.py
@@ -73,6 +73,7 @@ from .core import (
_deprecate_positional_args,
_expect,
)
+from .data import _is_cudf_ser, _is_cupy_array
from .sklearn import (
XGBClassifier,
XGBClassifierBase,
@@ -1894,10 +1895,15 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierMixIn, XGBClassifierBa
)
# pylint: disable=attribute-defined-outside-init
- if isinstance(y, (da.Array)):
+ if isinstance(y, da.Array):
self.classes_ = await self.client.compute(da.unique(y))
else:
self.classes_ = await self.client.compute(y.drop_duplicates())
+ if _is_cudf_ser(self.classes_):
+ self.classes_ = self.classes_.to_cupy()
+ if _is_cupy_array(self.classes_):
+ self.classes_ = self.classes_.get()
+ self.classes_ = numpy.array(self.classes_)
self.n_classes_ = len(self.classes_)
if self.n_classes_ > 2:
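The added branches normalize `classes_` to a host-side NumPy array whether Dask computed it from a cuDF series, a CuPy array, or CPU data. A hedged stand-alone illustration of the same conversion chain (the real code uses the internal `_is_cudf_ser`/`_is_cupy_array` helpers; the module-name checks below are only a stand-in):

```python
import numpy

def to_numpy_classes(classes):
    """Normalize class labels to a NumPy array, as in DaskXGBClassifier above."""
    if type(classes).__module__.startswith("cudf"):
        classes = classes.to_cupy()  # cuDF Series -> CuPy array
    if type(classes).__module__.startswith("cupy"):
        classes = classes.get()      # CuPy array -> host NumPy array
    return numpy.array(classes)
```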
diff --git a/python-package/xgboost/plotting.py b/python-package/xgboost/plotting.py
index 71058e8c9..d9eb14d0f 100644
--- a/python-package/xgboost/plotting.py
+++ b/python-package/xgboost/plotting.py
@@ -30,7 +30,7 @@ def plot_importance(
grid: bool = True,
show_values: bool = True,
values_format: str = "{v}",
- **kwargs: Any
+ **kwargs: Any,
) -> Axes:
"""Plot importance based on fitted trees.
@@ -155,7 +155,7 @@ def to_graphviz(
no_color: Optional[str] = None,
condition_node_params: Optional[dict] = None,
leaf_node_params: Optional[dict] = None,
- **kwargs: Any
+ **kwargs: Any,
) -> GraphvizSource:
"""Convert specified tree to graphviz instance. IPython can automatically plot
the returned graphviz instance. Otherwise, you should call .render() method
@@ -250,7 +250,7 @@ def plot_tree(
num_trees: int = 0,
rankdir: Optional[str] = None,
ax: Optional[Axes] = None,
- **kwargs: Any
+ **kwargs: Any,
) -> Axes:
"""Plot specified tree.
diff --git a/python-package/xgboost/spark/data.py b/python-package/xgboost/spark/data.py
index f2c5e1197..8f84459d7 100644
--- a/python-package/xgboost/spark/data.py
+++ b/python-package/xgboost/spark/data.py
@@ -219,7 +219,9 @@ def create_dmatrix_from_partitions( # pylint: disable=too-many-arguments
array: Optional[np.ndarray] = part[feature_cols]
elif part[name].shape[0] > 0:
array = part[name]
- array = stack_series(array)
+ if name == alias.data:
+ # For the array/vector typed case.
+ array = stack_series(array)
else:
array = None
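The guard above restricts `stack_series` to the data column: in the Spark path each partition arrives as a pandas DataFrame whose feature column can hold one array per row (the array/vector-typed case), and only that column needs stacking into a 2-D matrix; label and weight columns are already flat. A hedged illustration using plain pandas/NumPy (not the internal `stack_series` itself):

```python
import numpy as np
import pandas as pd

# One partition: the "data" column holds a feature vector per row.
data_col = pd.Series([np.array([1.0, 2.0]), np.array([3.0, 4.0])])
features = np.stack(data_col.to_list())    # (2, 2) matrix after stacking
labels = pd.Series([0.0, 1.0]).to_numpy()  # scalar columns need no stacking
print(features.shape, labels.shape)
```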
diff --git a/python-package/xgboost/spark/params.py b/python-package/xgboost/spark/params.py
index 78a35eee0..7c3231431 100644
--- a/python-package/xgboost/spark/params.py
+++ b/python-package/xgboost/spark/params.py
@@ -1,4 +1,6 @@
"""Xgboost pyspark integration submodule for params."""
+from typing import Dict
+
# pylint: disable=too-few-public-methods
from pyspark.ml.param import TypeConverters
from pyspark.ml.param.shared import Param, Params
@@ -11,7 +13,7 @@ class HasArbitraryParamsDict(Params):
input.
"""
- arbitrary_params_dict: Param[dict] = Param(
+ arbitrary_params_dict: "Param[Dict]" = Param(
Params._dummy(),
"arbitrary_params_dict",
"arbitrary_params_dict This parameter holds all of the additional parameters which are "
diff --git a/python-package/xgboost/testing/__init__.py b/python-package/xgboost/testing/__init__.py
index 20a4c681e..7bf3cf45b 100644
--- a/python-package/xgboost/testing/__init__.py
+++ b/python-package/xgboost/testing/__init__.py
@@ -317,13 +317,15 @@ class TestDataset:
enable_categorical=True,
)
- def get_device_dmat(self) -> xgb.QuantileDMatrix:
+ def get_device_dmat(self, max_bin: Optional[int]) -> xgb.QuantileDMatrix:
import cupy as cp
w = None if self.w is None else cp.array(self.w)
X = cp.array(self.X, dtype=np.float32)
y = cp.array(self.y, dtype=np.float32)
- return xgb.QuantileDMatrix(X, y, weight=w, base_margin=self.margin)
+ return xgb.QuantileDMatrix(
+ X, y, weight=w, base_margin=self.margin, max_bin=max_bin
+ )
def get_external_dmat(self) -> xgb.DMatrix:
n_samples = self.X.shape[0]
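Passing `max_bin` through to `QuantileDMatrix` lets the tests keep the pre-quantized matrix in sync with the `max_bin` training parameter. A small CPU-side sketch of that pairing (synthetic data; the device variant would build the arrays with CuPy instead):

```python
import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
X, y = rng.normal(size=(256, 8)), rng.integers(0, 2, size=256)

max_bin = 16
# QuantileDMatrix quantizes up front, so its max_bin should match training.
Xy = xgb.QuantileDMatrix(X, y, max_bin=max_bin)
booster = xgb.train({"tree_method": "hist", "max_bin": max_bin}, Xy, num_boost_round=5)
```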
@@ -431,8 +433,11 @@ def make_ltr(
"""Make a dataset for testing LTR."""
rng = np.random.default_rng(1994)
X = rng.normal(0, 1.0, size=n_samples * n_features).reshape(n_samples, n_features)
- y = rng.integers(0, max_rel, size=n_samples)
- qid = rng.integers(0, n_query_groups, size=n_samples)
+ y = np.sum(X, axis=1)
+ y -= y.min()
+ y = np.round(y / y.max() * max_rel).astype(np.int32)
+
+ qid = rng.integers(0, n_query_groups, size=n_samples, dtype=np.int32)
w = rng.normal(0, 1.0, size=n_query_groups)
w -= np.min(w)
w /= np.max(w)
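The relevance labels are now derived from the feature rows instead of being sampled independently, so rankers trained on the synthetic data have signal to learn. A small numeric sketch of the scaling (values illustrative):

```python
import numpy as np

rng = np.random.default_rng(1994)
n_samples, n_features, max_rel = 6, 3, 4

X = rng.normal(0, 1.0, size=(n_samples, n_features))
y = X.sum(axis=1)
y -= y.min()                                          # shift minimum to 0
y = np.round(y / y.max() * max_rel).astype(np.int32)  # integer relevance in [0, max_rel]
print(y)  # labels correlated with the features, each between 0 and 4
```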
@@ -879,5 +884,12 @@ def data_dir(path: str) -> str:
return os.path.join(demo_dir(path), "data")
+def load_agaricus(path: str) -> Tuple[xgb.DMatrix, xgb.DMatrix]:
+ dpath = data_dir(path)
+ dtrain = xgb.DMatrix(os.path.join(dpath, "agaricus.txt.train?format=libsvm"))
+ dtest = xgb.DMatrix(os.path.join(dpath, "agaricus.txt.test?format=libsvm"))
+ return dtrain, dtest
+
+
def project_root(path: str) -> str:
return normpath(os.path.join(demo_dir(path), os.path.pardir))
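A usage sketch of the new helper from a test module; passing `__file__` is an assumption about how the surrounding tests anchor the relative lookup of `demo/data`:

```python
import xgboost.testing as tm

# In a test file living inside the repository tree:
dtrain, dtest = tm.load_agaricus(__file__)
print(dtrain.num_row(), dtest.num_row())
```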
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 74a0107e1..a09a5499c 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -3,30 +3,50 @@
*/
#include "xgboost/c_api.h"
-#include
+#include <algorithm>   // for copy
+#include <cinttypes>   // for strtoimax
+#include <cmath>       // for nan
+#include <cstring>     // for strcmp
+#include <sstream>     // for operator<<, basic_ostream, ios, stringstream
+#include <functional>  // for less
+#include <limits>      // for numeric_limits
+#include