merge 23Mar01

amdsc21 2023-05-02 00:05:58 +02:00
commit 5446c501af
258 changed files with 7471 additions and 5379 deletions

View File

@ -40,7 +40,7 @@ jobs:
key: ${{ runner.os }}-m2-${{ hashFiles('./jvm-packages/pom.xml') }}
restore-keys: ${{ runner.os }}-m2-${{ hashFiles('./jvm-packages/pom.xml') }}
- name: Test XGBoost4J
- name: Test XGBoost4J (Core)
run: |
cd jvm-packages
mvn test -B -pl :xgboost4j_2.12
@ -67,7 +67,7 @@ jobs:
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY_IAM_S3_UPLOADER }}
- name: Test XGBoost4J-Spark
- name: Test XGBoost4J (Core, Spark, Examples)
run: |
rm -rfv build/
cd jvm-packages

View File

@ -65,7 +65,7 @@ jobs:
run: |
cd python-package
python --version
python setup.py sdist
python -m build --sdist
pip install -v ./dist/xgboost-*.tar.gz
cd ..
python -c 'import xgboost'
@ -92,6 +92,9 @@ jobs:
auto-update-conda: true
python-version: ${{ matrix.python-version }}
activate-environment: test
- name: Install build
run: |
conda install -c conda-forge python-build
- name: Display Conda env
run: |
conda info
@ -100,7 +103,7 @@ jobs:
run: |
cd python-package
python --version
python setup.py sdist
python -m build --sdist
pip install -v ./dist/xgboost-*.tar.gz
cd ..
python -c 'import xgboost'
@ -147,7 +150,7 @@ jobs:
run: |
cd python-package
python --version
python setup.py install
pip install -v .
- name: Test Python package
run: |
@ -194,7 +197,7 @@ jobs:
run: |
cd python-package
python --version
python setup.py bdist_wheel --universal
pip wheel -v . --wheel-dir dist/
pip install ./dist/*.whl
- name: Test Python package
@ -238,7 +241,7 @@ jobs:
run: |
cd python-package
python --version
python setup.py install
pip install -v .
- name: Test Python package
run: |

View File

@ -54,7 +54,7 @@ jobs:
matrix:
config:
- {os: windows-latest, r: 'release', compiler: 'mingw', build: 'autotools'}
- {os: windows-latest, r: 'release', compiler: 'msvc', build: 'cmake'}
- {os: windows-latest, r: '4.2.0', compiler: 'msvc', build: 'cmake'}
env:
R_REMOTES_NO_ERRORS_FROM_WARNINGS: true
RSPM: ${{ matrix.config.rspm }}

View File

@ -47,6 +47,7 @@ option(USE_NVTX "Build with cuda profiling annotations. Developers only." OFF)
set(NVTX_HEADER_DIR "" CACHE PATH "Path to the stand-alone nvtx header")
option(RABIT_MOCK "Build rabit with mock" OFF)
option(HIDE_CXX_SYMBOLS "Build shared library and hide all C++ symbols" OFF)
option(KEEP_BUILD_ARTIFACTS_IN_BINARY_DIR "Output build artifacts in CMake binary dir" OFF)
## CUDA
option(USE_CUDA "Build with GPU acceleration" OFF)
option(USE_NCCL "Build with NCCL to enable distributed GPU support." OFF)
@ -312,8 +313,13 @@ if (JVM_BINDINGS)
xgboost_target_defs(xgboost4j)
endif (JVM_BINDINGS)
set_output_directory(runxgboost ${xgboost_SOURCE_DIR})
set_output_directory(xgboost ${xgboost_SOURCE_DIR}/lib)
if (KEEP_BUILD_ARTIFACTS_IN_BINARY_DIR)
set_output_directory(runxgboost ${xgboost_BINARY_DIR})
set_output_directory(xgboost ${xgboost_BINARY_DIR}/lib)
else ()
set_output_directory(runxgboost ${xgboost_SOURCE_DIR})
set_output_directory(xgboost ${xgboost_SOURCE_DIR}/lib)
endif ()
# Ensure these two targets do not build simultaneously, as they produce outputs with conflicting names
add_dependencies(xgboost runxgboost)

View File

@ -32,7 +32,7 @@ OBJECTS= \
$(PKGROOT)/src/objective/objective.o \
$(PKGROOT)/src/objective/regression_obj.o \
$(PKGROOT)/src/objective/multiclass_obj.o \
$(PKGROOT)/src/objective/rank_obj.o \
$(PKGROOT)/src/objective/lambdarank_obj.o \
$(PKGROOT)/src/objective/hinge.o \
$(PKGROOT)/src/objective/aft_obj.o \
$(PKGROOT)/src/objective/adaptive.o \

View File

@ -32,7 +32,7 @@ OBJECTS= \
$(PKGROOT)/src/objective/objective.o \
$(PKGROOT)/src/objective/regression_obj.o \
$(PKGROOT)/src/objective/multiclass_obj.o \
$(PKGROOT)/src/objective/rank_obj.o \
$(PKGROOT)/src/objective/lambdarank_obj.o \
$(PKGROOT)/src/objective/hinge.o \
$(PKGROOT)/src/objective/aft_obj.o \
$(PKGROOT)/src/objective/adaptive.o \

View File

@ -72,7 +72,7 @@ test_that("xgb.DMatrix: saving, loading", {
tmp <- c("0 1:1 2:1", "1 3:1", "0 1:1")
tmp_file <- tempfile(fileext = ".libsvm")
writeLines(tmp, tmp_file)
dtest4 <- xgb.DMatrix(tmp_file, silent = TRUE)
dtest4 <- xgb.DMatrix(paste(tmp_file, "?format=libsvm", sep = ""), silent = TRUE)
expect_equal(dim(dtest4), c(3, 4))
expect_equal(getinfo(dtest4, 'label'), c(0, 1, 0))

View File

@ -20,10 +20,10 @@ num_round = 2
# 0 means do not save any model except the final round model
save_period = 2
# The path of training data
data = "agaricus.txt.train"
data = "agaricus.txt.train?format=libsvm"
# The path of validation data, used to monitor training process, here [test] sets name of the validation set
eval[test] = "agaricus.txt.test"
eval[test] = "agaricus.txt.test?format=libsvm"
# evaluate on training data as well each round
eval_train = 1
# The path of test data
test:data = "agaricus.txt.test"
test:data = "agaricus.txt.test?format=libsvm"

View File

@ -21,8 +21,8 @@ num_round = 2
# 0 means do not save any model except the final round model
save_period = 0
# The path of training data
data = "machine.txt.train"
data = "machine.txt.train?format=libsvm"
# The path of validation data, used to monitor training process, here [test] sets name of the validation set
eval[test] = "machine.txt.test"
eval[test] = "machine.txt.test?format=libsvm"
# The path of test data
test:data = "machine.txt.test"
test:data = "machine.txt.test?format=libsvm"

View File

@ -42,8 +42,8 @@ int main() {
// load the data
DMatrixHandle dtrain, dtest;
safe_xgboost(XGDMatrixCreateFromFile("../../data/agaricus.txt.train", silent, &dtrain));
safe_xgboost(XGDMatrixCreateFromFile("../../data/agaricus.txt.test", silent, &dtest));
safe_xgboost(XGDMatrixCreateFromFile("../../data/agaricus.txt.train?format=libsvm", silent, &dtrain));
safe_xgboost(XGDMatrixCreateFromFile("../../data/agaricus.txt.test?format=libsvm", silent, &dtest));
// create the booster
BoosterHandle booster;

View File

@ -7,15 +7,19 @@ import os
import xgboost as xgb
CURRENT_DIR = os.path.dirname(__file__)
dtrain = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.train'))
dtest = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.test'))
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
dtrain = xgb.DMatrix(
os.path.join(CURRENT_DIR, "../data/agaricus.txt.train?format=libsvm")
)
dtest = xgb.DMatrix(
os.path.join(CURRENT_DIR, "../data/agaricus.txt.test?format=libsvm")
)
watchlist = [(dtest, "eval"), (dtrain, "train")]
###
# advanced: start from an initial base prediction
#
print('start running example to start from a initial prediction')
print("start running example to start from a initial prediction")
# specify parameters via map, definition are same as c++ version
param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}
param = {"max_depth": 2, "eta": 1, "objective": "binary:logistic"}
# train xgboost for 1 round
bst = xgb.train(param, dtrain, 1, watchlist)
# Note: we need the margin value instead of transformed prediction in
@ -27,5 +31,5 @@ ptest = bst.predict(dtest, output_margin=True)
dtrain.set_base_margin(ptrain)
dtest.set_base_margin(ptest)
print('this is result of running from initial prediction')
print("this is result of running from initial prediction")
bst = xgb.train(param, dtrain, 1, watchlist)

View File

@ -10,27 +10,45 @@ import xgboost as xgb
# load data in do training
CURRENT_DIR = os.path.dirname(__file__)
dtrain = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.train'))
param = {'max_depth':2, 'eta':1, 'objective':'binary:logistic'}
dtrain = xgb.DMatrix(
os.path.join(CURRENT_DIR, "../data/agaricus.txt.train?format=libsvm")
)
param = {"max_depth": 2, "eta": 1, "objective": "binary:logistic"}
num_round = 2
print('running cross validation')
print("running cross validation")
# do cross validation, this will print result out as
# [iteration] metric_name:mean_value+std_value
# std_value is standard deviation of the metric
xgb.cv(param, dtrain, num_round, nfold=5,
metrics={'error'}, seed=0,
callbacks=[xgb.callback.EvaluationMonitor(show_stdv=True)])
xgb.cv(
param,
dtrain,
num_round,
nfold=5,
metrics={"error"},
seed=0,
callbacks=[xgb.callback.EvaluationMonitor(show_stdv=True)],
)
print('running cross validation, disable standard deviation display')
print("running cross validation, disable standard deviation display")
# do cross validation, this will print result out as
# [iteration] metric_name:mean_value
res = xgb.cv(param, dtrain, num_boost_round=10, nfold=5,
metrics={'error'}, seed=0,
callbacks=[xgb.callback.EvaluationMonitor(show_stdv=False),
xgb.callback.EarlyStopping(3)])
res = xgb.cv(
param,
dtrain,
num_boost_round=10,
nfold=5,
metrics={"error"},
seed=0,
callbacks=[
xgb.callback.EvaluationMonitor(show_stdv=False),
xgb.callback.EarlyStopping(3),
],
)
print(res)
print('running cross validation, with preprocessing function')
print("running cross validation, with preprocessing function")
# define the preprocessing function
# used to return the preprocessed training, test data, and parameter
# we can use this to do weight rescale, etc.
@ -38,32 +56,36 @@ print('running cross validation, with preprocessing function')
def fpreproc(dtrain, dtest, param):
label = dtrain.get_label()
ratio = float(np.sum(label == 0)) / np.sum(label == 1)
param['scale_pos_weight'] = ratio
param["scale_pos_weight"] = ratio
return (dtrain, dtest, param)
# do cross validation, for each fold
# the dtrain, dtest, param will be passed into fpreproc
# then the return value of fpreproc will be used to generate
# results of that fold
xgb.cv(param, dtrain, num_round, nfold=5,
metrics={'auc'}, seed=0, fpreproc=fpreproc)
xgb.cv(param, dtrain, num_round, nfold=5, metrics={"auc"}, seed=0, fpreproc=fpreproc)
###
# you can also do cross validation with customized loss function
# See custom_objective.py
##
print('running cross validation, with customized loss function')
print("running cross validation, with customized loss function")
def logregobj(preds, dtrain):
labels = dtrain.get_label()
preds = 1.0 / (1.0 + np.exp(-preds))
grad = preds - labels
hess = preds * (1.0 - preds)
return grad, hess
def evalerror(preds, dtrain):
labels = dtrain.get_label()
return 'error', float(sum(labels != (preds > 0.0))) / len(labels)
return "error", float(sum(labels != (preds > 0.0))) / len(labels)
param = {'max_depth':2, 'eta':1}
param = {"max_depth": 2, "eta": 1}
# train with customized objective
xgb.cv(param, dtrain, num_round, nfold=5, seed=0,
obj=logregobj, feval=evalerror)
xgb.cv(param, dtrain, num_round, nfold=5, seed=0, obj=logregobj, feval=evalerror)

View File

@ -7,28 +7,37 @@ import os
import xgboost as xgb
CURRENT_DIR = os.path.dirname(__file__)
dtrain = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.train'))
dtest = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.test'))
dtrain = xgb.DMatrix(
os.path.join(CURRENT_DIR, "../data/agaricus.txt.train?format=libsvm")
)
dtest = xgb.DMatrix(
os.path.join(CURRENT_DIR, "../data/agaricus.txt.test?format=libsvm")
)
param = [('max_depth', 2), ('objective', 'binary:logistic'), ('eval_metric', 'logloss'), ('eval_metric', 'error')]
param = [
("max_depth", 2),
("objective", "binary:logistic"),
("eval_metric", "logloss"),
("eval_metric", "error"),
]
num_round = 2
watchlist = [(dtest,'eval'), (dtrain,'train')]
watchlist = [(dtest, "eval"), (dtrain, "train")]
evals_result = {}
bst = xgb.train(param, dtrain, num_round, watchlist, evals_result=evals_result)
print('Access logloss metric directly from evals_result:')
print(evals_result['eval']['logloss'])
print("Access logloss metric directly from evals_result:")
print(evals_result["eval"]["logloss"])
print('')
print('Access metrics through a loop:')
print("")
print("Access metrics through a loop:")
for e_name, e_mtrs in evals_result.items():
print('- {}'.format(e_name))
print("- {}".format(e_name))
for e_mtr_name, e_mtr_vals in e_mtrs.items():
print(' - {}'.format(e_mtr_name))
print(' - {}'.format(e_mtr_vals))
print(" - {}".format(e_mtr_name))
print(" - {}".format(e_mtr_vals))
print('')
print('Access complete dictionary:')
print("")
print("Access complete dictionary:")
print(evals_result)

View File

@ -11,14 +11,22 @@ import xgboost as xgb
# basically, we are using linear model, instead of tree for our boosters
##
CURRENT_DIR = os.path.dirname(__file__)
dtrain = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.train'))
dtest = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.test'))
dtrain = xgb.DMatrix(
os.path.join(CURRENT_DIR, "../data/agaricus.txt.train?format=libsvm")
)
dtest = xgb.DMatrix(
os.path.join(CURRENT_DIR, "../data/agaricus.txt.test?format=libsvm")
)
# change booster to gblinear, so that we are fitting a linear model
# alpha is the L1 regularizer
# lambda is the L2 regularizer
# you can also set lambda_bias which is L2 regularizer on the bias term
param = {'objective':'binary:logistic', 'booster':'gblinear',
'alpha': 0.0001, 'lambda': 1}
param = {
"objective": "binary:logistic",
"booster": "gblinear",
"alpha": 0.0001,
"lambda": 1,
}
# normally, you do not need to set eta (step_size)
# XGBoost uses a parallel coordinate descent algorithm (shotgun),
@ -29,9 +37,15 @@ param = {'objective':'binary:logistic', 'booster':'gblinear',
##
# the rest of settings are the same
##
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
watchlist = [(dtest, "eval"), (dtrain, "train")]
num_round = 4
bst = xgb.train(param, dtrain, num_round, watchlist)
preds = bst.predict(dtest)
labels = dtest.get_label()
print('error=%f' % (sum(1 for i in range(len(preds)) if int(preds[i] > 0.5) != labels[i]) / float(len(preds))))
print(
"error=%f"
% (
sum(1 for i in range(len(preds)) if int(preds[i] > 0.5) != labels[i])
/ float(len(preds))
)
)

View File

@ -16,8 +16,8 @@ test = os.path.join(CURRENT_DIR, "../data/agaricus.txt.test")
def native_interface():
# load data in do training
dtrain = xgb.DMatrix(train)
dtest = xgb.DMatrix(test)
dtrain = xgb.DMatrix(train + "?format=libsvm")
dtest = xgb.DMatrix(test + "?format=libsvm")
param = {"max_depth": 2, "eta": 1, "objective": "binary:logistic"}
watchlist = [(dtest, "eval"), (dtrain, "train")]
num_round = 3

View File

@ -8,14 +8,18 @@ import xgboost as xgb
# load data in do training
CURRENT_DIR = os.path.dirname(__file__)
dtrain = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.train'))
dtest = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.test'))
param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
dtrain = xgb.DMatrix(
os.path.join(CURRENT_DIR, "../data/agaricus.txt.train?format=libsvm")
)
dtest = xgb.DMatrix(
os.path.join(CURRENT_DIR, "../data/agaricus.txt.test?format=libsvm")
)
param = {"max_depth": 2, "eta": 1, "objective": "binary:logistic"}
watchlist = [(dtest, "eval"), (dtrain, "train")]
num_round = 3
bst = xgb.train(param, dtrain, num_round, watchlist)
print('start testing predict the leaf indices')
print("start testing predict the leaf indices")
# predict using first 2 tree
leafindex = bst.predict(
dtest, iteration_range=(0, 2), pred_leaf=True, strict_shape=True

View File

@ -3,61 +3,12 @@
This directory contains a demo of Federated Learning using
[NVFlare](https://nvidia.github.io/NVFlare/).
## Training with CPU only
## Horizontal Federated XGBoost
To run the demo, first build XGBoost with the federated learning plugin enabled (see the
[README](../../plugin/federated/README.md)).
For horizontal federated learning using XGBoost (data is split row-wise), check out the `horizontal` directory
(see the [README](horizontal/README.md)).
Install NVFlare (note that currently NVFlare only supports Python 3.8):
```shell
pip install nvflare
```
## Vertical Federated XGBoost
Prepare the data:
```shell
./prepare_data.sh
```
Start the NVFlare federated server:
```shell
/tmp/nvflare/poc/server/startup/start.sh
```
In another terminal, start the first worker:
```shell
/tmp/nvflare/poc/site-1/startup/start.sh
```
And the second worker:
```shell
/tmp/nvflare/poc/site-2/startup/start.sh
```
Then start the admin CLI:
```shell
/tmp/nvflare/poc/admin/startup/fl_admin.sh
```
In the admin CLI, run the following command:
```shell
submit_job hello-xgboost
```
Once the training finishes, the model file should be written into
`/tmp/nvlfare/poc/site-1/run_1/test.model.json` and `/tmp/nvflare/poc/site-2/run_1/test.model.json`
respectively.
Finally, shutdown everything from the admin CLI, using `admin` as password:
```shell
shutdown client
shutdown server
```
## Training with GPUs
To demo with Federated Learning using GPUs, make sure your machine has at least 2 GPUs.
Build XGBoost with the federated learning plugin enabled along with CUDA, but with NCCL
turned off (see the [README](../../plugin/federated/README.md)).
Modify `config/config_fed_client.json` and set `use_gpus` to `true`, then repeat the steps
above.
For vertical federated learning using XGBoost (data is split column-wise), check out the `vertical` directory
(see the [README](vertical/README.md)).

View File

@ -1,23 +0,0 @@
{
"format_version": 2,
"executors": [
{
"tasks": [
"train"
],
"executor": {
"path": "trainer.XGBoostTrainer",
"args": {
"server_address": "localhost:9091",
"world_size": 2,
"server_cert_path": "server-cert.pem",
"client_key_path": "client-key.pem",
"client_cert_path": "client-cert.pem",
"use_gpus": "false"
}
}
}
],
"task_result_filters": [],
"task_data_filters": []
}

View File

@ -1,22 +0,0 @@
{
"format_version": 2,
"server": {
"heart_beat_timeout": 600
},
"task_data_filters": [],
"task_result_filters": [],
"workflows": [
{
"id": "server_workflow",
"path": "controller.XGBoostController",
"args": {
"port": 9091,
"world_size": 2,
"server_key_path": "server-key.pem",
"server_cert_path": "server-cert.pem",
"client_cert_path": "client-cert.pem"
}
}
],
"components": []
}

View File

@ -0,0 +1,63 @@
# Experimental Support of Horizontal Federated XGBoost using NVFlare
This directory contains a demo of Horizontal Federated Learning using
[NVFlare](https://nvidia.github.io/NVFlare/).
## Training with CPU only
To run the demo, first build XGBoost with the federated learning plugin enabled (see the
[README](../../plugin/federated/README.md)).
Install NVFlare (note that currently NVFlare only supports Python 3.8):
```shell
pip install nvflare
```
Prepare the data:
```shell
./prepare_data.sh
```
Start the NVFlare federated server:
```shell
/tmp/nvflare/poc/server/startup/start.sh
```
In another terminal, start the first worker:
```shell
/tmp/nvflare/poc/site-1/startup/start.sh
```
And the second worker:
```shell
/tmp/nvflare/poc/site-2/startup/start.sh
```
Then start the admin CLI:
```shell
/tmp/nvflare/poc/admin/startup/fl_admin.sh
```
In the admin CLI, run the following command:
```shell
submit_job horizontal-xgboost
```
Once the training finishes, the model file should be written into
`/tmp/nvflare/poc/site-1/run_1/test.model.json` and `/tmp/nvflare/poc/site-2/run_1/test.model.json`
respectively.
Finally, shut down everything from the admin CLI, using `admin` as the password:
```shell
shutdown client
shutdown server
```
## Training with GPUs
To run the demo with GPUs, make sure your machine has at least 2 GPUs.
Build XGBoost with the federated learning plugin enabled along with CUDA, but with NCCL
turned off (see the [README](../../plugin/federated/README.md)).
Modify `config/config_fed_client.json` and set `use_gpus` to `true`, then repeat the steps
above.

View File

@ -15,8 +15,8 @@ split -n l/${world_size} --numeric-suffixes=1 -a 1 ../data/agaricus.txt.train ag
split -n l/${world_size} --numeric-suffixes=1 -a 1 ../data/agaricus.txt.test agaricus.txt.test-site-
nvflare poc -n 2 --prepare
mkdir -p /tmp/nvflare/poc/admin/transfer/hello-xgboost
cp -fr config custom /tmp/nvflare/poc/admin/transfer/hello-xgboost
mkdir -p /tmp/nvflare/poc/admin/transfer/horizontal-xgboost
cp -fr config custom /tmp/nvflare/poc/admin/transfer/horizontal-xgboost
cp server-*.pem client-cert.pem /tmp/nvflare/poc/server/
for id in $(eval echo "{1..$world_size}"); do
cp server-cert.pem client-*.pem /tmp/nvflare/poc/site-"$id"/

View File

@ -0,0 +1,59 @@
# Experimental Support of Vertical Federated XGBoost using NVFlare
This directory contains a demo of Vertical Federated Learning using
[NVFlare](https://nvidia.github.io/NVFlare/).
## Training with CPU only
To run the demo, first build XGBoost with the federated learning plugin enabled (see the
[README](../../plugin/federated/README.md)).
Install NVFlare (note that currently NVFlare only supports Python 3.8):
```shell
pip install nvflare
```
Prepare the data (note that this step will download the HIGGS dataset, which is 2.6GB compressed, and 7.5GB
uncompressed, so make sure you have enough disk space and are on a fast internet connection):
```shell
./prepare_data.sh
```
Start the NVFlare federated server:
```shell
/tmp/nvflare/poc/server/startup/start.sh
```
In another terminal, start the first worker:
```shell
/tmp/nvflare/poc/site-1/startup/start.sh
```
And the second worker:
```shell
/tmp/nvflare/poc/site-2/startup/start.sh
```
Then start the admin CLI:
```shell
/tmp/nvflare/poc/admin/startup/fl_admin.sh
```
In the admin CLI, run the following command:
```shell
submit_job vertical-xgboost
```
Once the training finishes, the model file should be written into
`/tmp/nvflare/poc/site-1/run_1/test.model.json` and `/tmp/nvflare/poc/site-2/run_1/test.model.json`
respectively.
Finally, shut down everything from the admin CLI, using `admin` as the password:
```shell
shutdown client
shutdown server
```
## Training with GPUs
Currently GPUs are not yet supported by vertical federated XGBoost.

View File

@ -0,0 +1,68 @@
"""
Example of training controller with NVFlare
===========================================
"""
import multiprocessing
from nvflare.apis.client import Client
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import Controller, Task
from nvflare.apis.shareable import Shareable
from nvflare.apis.signal import Signal
from trainer import SupportedTasks
import xgboost.federated
class XGBoostController(Controller):
def __init__(self, port: int, world_size: int, server_key_path: str,
server_cert_path: str, client_cert_path: str):
"""Controller for federated XGBoost.
Args:
port: the port for the gRPC server to listen on.
world_size: the number of sites.
server_key_path: the path to the server key file.
server_cert_path: the path to the server certificate file.
client_cert_path: the path to the client certificate file.
"""
super().__init__()
self._port = port
self._world_size = world_size
self._server_key_path = server_key_path
self._server_cert_path = server_cert_path
self._client_cert_path = client_cert_path
self._server = None
def start_controller(self, fl_ctx: FLContext):
self._server = multiprocessing.Process(
target=xgboost.federated.run_federated_server,
args=(self._port, self._world_size, self._server_key_path,
self._server_cert_path, self._client_cert_path))
self._server.start()
def stop_controller(self, fl_ctx: FLContext):
if self._server:
self._server.terminate()
def process_result_of_unknown_task(self, client: Client, task_name: str,
client_task_id: str, result: Shareable,
fl_ctx: FLContext):
self.log_warning(fl_ctx, f"Unknown task: {task_name} from client {client.name}.")
def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
self.log_info(fl_ctx, "XGBoost training control flow started.")
if abort_signal.triggered:
return
task = Task(name=SupportedTasks.TRAIN, data=Shareable())
self.broadcast_and_wait(
task=task,
min_responses=self._world_size,
fl_ctx=fl_ctx,
wait_time_after_min_received=1,
abort_signal=abort_signal,
)
if abort_signal.triggered:
return
self.log_info(fl_ctx, "XGBoost training control flow finished.")

View File

@ -0,0 +1,97 @@
import os
from nvflare.apis.executor import Executor
from nvflare.apis.fl_constant import FLContextKey, ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.signal import Signal
import xgboost as xgb
from xgboost import callback
class SupportedTasks(object):
TRAIN = "train"
class XGBoostTrainer(Executor):
def __init__(self, server_address: str, world_size: int, server_cert_path: str,
client_key_path: str, client_cert_path: str):
"""Trainer for federated XGBoost.
Args:
server_address: address for the gRPC server to connect to.
world_size: the number of sites.
server_cert_path: the path to the server certificate file.
client_key_path: the path to the client key file.
client_cert_path: the path to the client certificate file.
"""
super().__init__()
self._server_address = server_address
self._world_size = world_size
self._server_cert_path = server_cert_path
self._client_key_path = client_key_path
self._client_cert_path = client_cert_path
def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext,
abort_signal: Signal) -> Shareable:
self.log_info(fl_ctx, f"Executing {task_name}")
try:
if task_name == SupportedTasks.TRAIN:
self._do_training(fl_ctx)
return make_reply(ReturnCode.OK)
else:
self.log_error(fl_ctx, f"{task_name} is not a supported task.")
return make_reply(ReturnCode.TASK_UNKNOWN)
except BaseException as e:
self.log_exception(fl_ctx,
f"Task {task_name} failed. Exception: {e.__str__()}")
return make_reply(ReturnCode.EXECUTION_EXCEPTION)
def _do_training(self, fl_ctx: FLContext):
client_name = fl_ctx.get_prop(FLContextKey.CLIENT_NAME)
rank = int(client_name.split('-')[1]) - 1
communicator_env = {
'xgboost_communicator': 'federated',
'federated_server_address': self._server_address,
'federated_world_size': self._world_size,
'federated_rank': rank,
'federated_server_cert': self._server_cert_path,
'federated_client_key': self._client_key_path,
'federated_client_cert': self._client_cert_path
}
with xgb.collective.CommunicatorContext(**communicator_env):
# Load file, file will not be sharded in federated mode.
if rank == 0:
label = '&label_column=0'
else:
label = ''
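# data_split_mode=1 loads the data column-wise (feature split), as used for vertical federated learning.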
dtrain = xgb.DMatrix(f'higgs.train.csv?format=csv{label}', data_split_mode=1)
dtest = xgb.DMatrix(f'higgs.test.csv?format=csv{label}', data_split_mode=1)
# specify parameters via map
param = {
'validate_parameters': True,
'eta': 0.1,
'gamma': 1.0,
'max_depth': 8,
'min_child_weight': 100,
'tree_method': 'approx',
'grow_policy': 'depthwise',
'objective': 'binary:logistic',
'eval_metric': 'auc',
}
# specify validations set to watch performance
watchlist = [(dtest, "eval"), (dtrain, "train")]
# number of boosting rounds
num_round = 10
bst = xgb.train(param, dtrain, num_round, evals=watchlist, early_stopping_rounds=2)
# Save the model.
workspace = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT)
run_number = fl_ctx.get_prop(FLContextKey.CURRENT_RUN)
run_dir = workspace.get_run_dir(run_number)
bst.save_model(os.path.join(run_dir, "higgs.model.federated.vertical.json"))
xgb.collective.communicator_print("Finished training\n")

View File

@ -0,0 +1,65 @@
#!/bin/bash
set -e
rm -fr ./*.pem /tmp/nvflare/poc
world_size=2
# Generate server and client certificates.
openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout server-key.pem -out server-cert.pem -subj "/C=US/CN=localhost"
openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout client-key.pem -out client-cert.pem -subj "/C=US/CN=localhost"
# Download HIGGS dataset.
if [ -f "HIGGS.csv" ]; then
echo "HIGGS.csv exists, skipping download."
else
echo "Downloading HIGGS dataset."
wget https://archive.ics.uci.edu/ml/machine-learning-databases/00280/HIGGS.csv.gz
gunzip HIGGS.csv.gz
fi
# Split into train/test.
if [[ -f higgs.train.csv && -f higgs.test.csv ]]; then
echo "higgs.train.csv and higgs.test.csv exist, skipping split."
else
echo "Splitting HIGGS dataset into train/test."
head -n 10450000 HIGGS.csv > higgs.train.csv
tail -n 550000 HIGGS.csv > higgs.test.csv
fi
# Split train and test files by column to simulate a federated environment.
site_files=(higgs.{train,test}.csv-site-*)
if [ ${#site_files[@]} -eq $((world_size*2)) ]; then
echo "Site files exist, skipping split."
else
echo "Splitting train/test into site files."
total_cols=28 # plus label
cols=$((total_cols/world_size))
echo "Columns per site: $cols"
for (( site=1; site<=world_size; site++ )); do
if (( site == 1 )); then
start=$((cols*(site-1)+1))
else
start=$((cols*(site-1)+2))
fi
if (( site == world_size )); then
end=$((total_cols+1))
else
end=$((cols*site+1))
fi
echo "Site $site, columns $start-$end"
cut -d, -f${start}-${end} higgs.train.csv > higgs.train.csv-site-"${site}"
cut -d, -f${start}-${end} higgs.test.csv > higgs.test.csv-site-"${site}"
done
fi
nvflare poc -n 2 --prepare
mkdir -p /tmp/nvflare/poc/admin/transfer/vertical-xgboost
cp -fr config custom /tmp/nvflare/poc/admin/transfer/vertical-xgboost
cp server-*.pem client-cert.pem /tmp/nvflare/poc/server/
for (( site=1; site<=world_size; site++ )); do
cp server-cert.pem client-*.pem /tmp/nvflare/poc/site-"${site}"/
ln -s "${PWD}"/higgs.train.csv-site-"${site}" /tmp/nvflare/poc/site-"${site}"/higgs.train.csv
ln -s "${PWD}"/higgs.test.csv-site-"${site}" /tmp/nvflare/poc/site-"${site}"/higgs.test.csv
done

View File

@ -105,7 +105,7 @@ def make_pysrc_wheel(release: str, outdir: str) -> None:
os.mkdir(dist)
with DirectoryExcursion(os.path.join(ROOT, "python-package")):
subprocess.check_call(["python", "setup.py", "sdist"])
subprocess.check_call(["python", "-m", "build", "--sdist"])
src = os.path.join(DIST, f"xgboost-{release}.tar.gz")
subprocess.check_call(["twine", "check", src])
shutil.move(src, os.path.join(dist, f"xgboost-{release}.tar.gz"))

File diff suppressed because it is too large

View File

@ -12,6 +12,7 @@ systems. If the instructions do not work for you, please feel free to ask quest
Consider installing XGBoost from a pre-built binary, to avoid the trouble of building XGBoost from the source. Checkout :doc:`Installation Guide </install>`.
.. contents:: Contents
:local:
.. _get_source:
@ -152,11 +153,11 @@ On Windows, run CMake as follows:
mkdir build
cd build
cmake .. -G"Visual Studio 14 2015 Win64" -DUSE_CUDA=ON
cmake .. -G"Visual Studio 17 2022" -A x64 -DUSE_CUDA=ON
(Change the ``-G`` option appropriately if you have a different version of Visual Studio installed.)
The above cmake configuration run will create an ``xgboost.sln`` solution file in the build directory. Build this solution in release mode as a x64 build, either from Visual studio or from command line:
The above CMake configuration run will create an ``xgboost.sln`` solution file in the build directory. Build this solution in Release mode, either from Visual Studio or from the command line:
.. code-block:: bash
@ -176,111 +177,104 @@ Building Python Package with Default Toolchains
===============================================
There are several ways to build and install the package from source:
1. Use Python setuptools directly
1. Build C++ core with CMake first
The XGBoost Python package supports most of the setuptools commands, here is a list of tested commands:
You can first build the C++ library using CMake as described in :ref:`build_shared_lib`.
After compilation, a shared library will appear in the ``lib/`` directory.
On Linux distributions, the shared library is ``lib/libxgboost.so``.
The install script ``pip install .`` will reuse the shared library instead of compiling
it from scratch, making it quite fast to run.
.. code-block:: console
$ cd python-package/
$ pip install . # Will re-use lib/libxgboost.so
2. Install the Python package directly
You can navigate to the ``python-package/`` directory and install the Python package directly
by running
.. code-block:: console
$ cd python-package/
$ pip install -v .
which will compile XGBoost's native (C++) code using default CMake flags.
To enable additional compilation options, pass corresponding ``--config-settings``:
.. code-block:: console
$ pip install -v . --config-settings use_cuda=True --config-settings use_nccl=True
Use pip 22.1 or later to use the ``--config-settings`` option.
Here are the available options for ``--config-settings``:
.. literalinclude:: ../python-package/packager/build_config.py
:language: python
:start-at: @dataclasses.dataclass
:end-before: def _set_config_setting(
``use_system_libxgboost`` is a special option. See Item 4 below for
detailed description.
.. note:: Verbose flag recommended
As ``pip install .`` will build C++ code, it will take a while to complete.
To ensure that the build is progressing successfully, we suggest that
you add the verbose flag (``-v``) when invoking ``pip install``.
3. Editable installation
To further enable rapid development and iteration, we provide an **editable installation**.
In an editable installation, the installed package is simply a symbolic link to your
working copy of the XGBoost source code. So every change you make to your source
directory will be immediately visible to the Python interpreter. Here is how to
install XGBoost as an editable installation:
.. code-block:: bash
python setup.py install # Install the XGBoost to your current Python environment.
python setup.py build # Build the Python package.
python setup.py build_ext # Build only the C++ core.
python setup.py sdist # Create a source distribution
python setup.py bdist # Create a binary distribution
python setup.py bdist_wheel # Create a binary distribution with wheel format
Running ``python setup.py install`` will compile XGBoost using default CMake flags. For
passing additional compilation options, append the flags to the command. For example,
to enable CUDA acceleration and NCCL (distributed GPU) support:
.. code-block:: bash
python setup.py install --use-cuda --use-nccl
Please refer to ``setup.py`` for a complete list of available options. Some other
options used for development are only available for using CMake directly. See next
section on how to use CMake with setuptools manually.
You can install the created distribution packages using pip. For example, after running
``sdist`` setuptools command, a tar ball similar to ``xgboost-1.0.0.tar.gz`` will be
created under the ``dist`` directory. Then you can install it by invoking the following
command under ``dist`` directory:
.. code-block:: bash
# under python-package directory
cd dist
pip install ./xgboost-1.0.0.tar.gz
For details about these commands, please refer to the official document of `setuptools
<https://setuptools.readthedocs.io/en/latest/>`_, or just Google "how to install Python
package from source". XGBoost Python package follows the general convention.
Setuptools is usually available with your Python distribution, if not you can install it
via system command. For example on Debian or Ubuntu:
.. code-block:: bash
sudo apt-get install python-setuptools
For cleaning up the directory after running above commands, ``python setup.py clean`` is
not sufficient. After copying out the build result, simply running ``git clean -xdf``
under ``python-package`` is an efficient way to remove generated cache files. If you
find weird behaviors in Python build or running linter, it might be caused by those
cached files.
For using develop command (editable installation), see next section.
.. code-block::
python setup.py develop # Create a editable installation.
pip install -e . # Same as above, but carried out by pip.
2. Build C++ core with CMake first
This is mostly for C++ developers who don't want to go through the hooks in Python
setuptools. You can build C++ library directly using CMake as described in above
sections. After compilation, a shared object (or called dynamic linked library, jargon
depending on your platform) will appear in XGBoost's source tree under ``lib/``
directory. On Linux distributions it's ``lib/libxgboost.so``. From there all Python
setuptools commands will reuse that shared object instead of compiling it again. This
is especially convenient if you are using the editable installation, where the installed
package is simply a link to the source tree. We can perform rapid testing during
development. Here is a simple bash script that does that:
.. code-block:: bash
# Under xgboost source tree.
# Under xgboost source directory
mkdir build
cd build
cmake ..
make -j$(nproc)
# Build shared library libxgboost.so
cmake .. -GNinja
ninja
# Install as editable installation
cd ../python-package
pip install -e . # or equivalently python setup.py develop
pip install -e .
3. Use ``libxgboost.so`` on system path.
4. Use ``libxgboost.so`` on system path.
This is for distributing xgboost in a language independent manner, where
``libxgboost.so`` is separately packaged with Python package. Assuming `libxgboost.so`
is already presented in system library path, which can be queried via:
This option is useful for package managers that wish to separately package
``libxgboost.so`` and the XGBoost Python package. For example, Conda
publishes ``libxgboost`` (for the shared library) and ``py-xgboost``
(for the Python package).
To use this option, first make sure that ``libxgboost.so`` exists in the system library path:
.. code-block:: python
import sys
import os
os.path.join(sys.prefix, 'lib')
import pathlib
libpath = pathlib.Path(sys.prefix).joinpath("lib", "libxgboost.so")
assert libpath.exists()
Then one only needs to provide an user option when installing Python package to reuse the
shared object in system path:
Then pass ``use_system_libxgboost=True`` option to ``pip install``:
.. code-block:: bash
cd xgboost/python-package
python setup.py install --use-system-libxgboost
cd python-package
pip install . --config-settings use_system_libxgboost=True
.. note::
See :doc:`contrib/python_packaging` for instructions on packaging
and distributing XGBoost as Python distributions.
.. _python_mingw:
Building Python Package for Windows with MinGW-w64 (Advanced)
@ -297,7 +291,7 @@ So you may want to build XGBoost with GCC at your own risk. This presents some
2. ``-O3`` is OK.
3. ``-mtune=native`` is also OK.
4. Don't use ``-march=native`` gcc flag. Using it causes the Python interpreter to crash if the DLL was actually used.
5. You may need to provide the lib with the runtime libs. If ``mingw32/bin`` is not in ``PATH``, build a wheel (``python setup.py bdist_wheel``), open it with an archiver and put the needed dlls to the directory where ``xgboost.dll`` is situated. Then you can install the wheel with ``pip``.
5. You may need to provide the lib with the runtime libs. If ``mingw32/bin`` is not in ``PATH``, build a wheel (``pip wheel``), open it with an archiver and put the needed DLLs into the directory where ``xgboost.dll`` is situated. Then you can install the wheel with ``pip``.
******************************
Building R Package From Source

View File

@ -35,8 +35,9 @@ calls ``cibuildwheel`` to build the wheel. The ``cibuildwheel`` is a library tha
suitable Python environment for each OS and processor target. Since we don't have an Apple Silicon
machine in GitHub Actions, cross-compilation is needed; ``cibuildwheel`` takes care of the complex
task of cross-compiling a Python wheel. (Note that ``cibuildwheel`` will call
``setup.py bdist_wheel``. Since XGBoost has a native library component, ``setup.py`` contains
a glue code to call CMake and a C++ compiler to build the native library on the fly.)
``pip wheel``. Since XGBoost has a native library component, we created a customized build
backend that hooks into ``pip``. The customized backend contains the glue code to compile the native
library on the fly.)
*********************************************************
Reproduce CI testing environments using Docker containers

View File

@ -23,6 +23,7 @@ Here are guidelines for contributing to various aspect of the XGBoost project:
Community Guideline <community>
donate
coding_guide
python_packaging
unit_tests
Docs and Examples <docs>
git_guide

View File

@ -0,0 +1,83 @@
###########################################
Notes on packaging XGBoost's Python package
###########################################
.. contents:: Contents
:local:
.. _packaging_python_xgboost:
***************************************************
How to build binary wheels and source distributions
***************************************************
Wheels and source distributions (sdist for short) are the two main
mechanisms for packaging and distributing Python packages.
* A **source distribution** (sdist) is a tarball (``.tar.gz`` extension) that
contains the source code.
* A **wheel** is a ZIP-compressed archive (with ``.whl`` extension)
representing a *built* distribution. Unlike an sdist, a wheel can contain
compiled components. The compiled components are compiled prior to distribution,
making it more convenient for end-users to install a wheel. Wheels containing
compiled components are referred to as **binary wheels**.
See `Python Packaging User Guide <https://packaging.python.org/en/latest/>`_
to learn more about how Python packages in general are packaged and
distributed.
For the remainder of this document, we will focus on packaging and
distributing XGBoost.
Building sdists
===============
In the case of XGBoost, an sdist contains both the Python code as well as
the C++ code, so that the core part of XGBoost can be compiled into the
shared library ``libxgboost.so`` [#shared_lib_name]_.
You can obtain an sdist as follows:
.. code-block:: console
$ python -m build --sdist .
(You'll need to install the ``build`` package first:
``pip install build`` or ``conda install python-build``.)
Running ``pip install`` with an sdist will launch CMake and a C++ compiler
to compile the bundled C++ code into ``libxgboost.so``:
.. code-block:: console
$ pip install -v xgboost-2.0.0.tar.gz # Add -v to show build progress
Building binary wheels
======================
You can also build a wheel as follows:
.. code-block:: console
$ pip wheel --no-deps -v .
Notably, the resulting wheel contains a copy of the shared library
``libxgboost.so`` [#shared_lib_name]_. The wheel is a **binary wheel**,
since it contains a compiled binary.
Running ``pip install`` with the binary wheel will extract the content of
the wheel into the current Python environment. Since the wheel already
contains a pre-built copy of ``libxgboost.so``, it does not have to be
built at the time of install. So ``pip install`` with the binary wheel
completes quickly:
.. code-block:: console
$ pip install xgboost-2.0.0-py3-none-linux_x86_64.whl # Completes quickly
.. rubric:: Footnotes
.. [#shared_lib_name] The name of the shared library file will differ
depending on the operating system in use. See :ref:`build_shared_lib`.

View File

@ -16,15 +16,28 @@ Stable Release
Python
------
Pre-built binary are uploaded to PyPI (Python Package Index) for each release. Supported platforms are Linux (x86_64, aarch64), Windows (x86_64) and MacOS (x86_64, Apple Silicon).
Pre-built binary wheels are uploaded to PyPI (Python Package Index) for each release. Supported platforms are Linux (x86_64, aarch64), Windows (x86_64) and MacOS (x86_64, Apple Silicon).
.. code-block:: bash
# Pip 21.3+ is required
pip install xgboost
You might need to run the command with ``--user`` flag or use ``virtualenv`` if you run
into permission errors. Python pre-built binary capability for each platform:
into permission errors.
.. note:: Windows users need to install Visual C++ Redistributable
XGBoost requires DLLs from `Visual C++ Redistributable
<https://www.microsoft.com/en-us/download/details.aspx?id=48145>`_
in order to function, so make sure to install it. Exception: If
you have Visual Studio installed, you already have access to
necessary libraries and thus don't need to install Visual C++
Redistributable.
Capabilities of binary wheels for each platform:
.. |tick| unicode:: U+2714
.. |cross| unicode:: U+2718

View File

@ -41,3 +41,7 @@ Contents
XGBoost4J Scala API <scaladocs/xgboost4j/index>
XGBoost4J-Spark Scala API <scaladocs/xgboost4j-spark/index>
XGBoost4J-Flink Scala API <scaladocs/xgboost4j-flink/index>
.. note::
Please note that the Flink interface is still under construction.

View File

@ -219,6 +219,16 @@
"num_pairsample": { "type": "string" },
"fix_list_weight": { "type": "string" }
}
},
"lambdarank_param": {
"type": "object",
"properties": {
"lambdarank_num_pair_per_sample": { "type": "string" },
"lambdarank_pair_method": { "type": "string" },
"lambdarank_unbiased": {"type": "string" },
"lambdarank_bias_norm": {"type": "string" },
"ndcg_exp_gain": {"type": "string"}
}
}
},
"type": "object",
@ -477,22 +487,22 @@
"type": "object",
"properties": {
"name": { "const": "rank:pairwise" },
"lambda_rank_param": { "$ref": "#/definitions/lambda_rank_param"}
"lambda_rank_param": { "$ref": "#/definitions/lambdarank_param"}
},
"required": [
"name",
"lambda_rank_param"
"lambdarank_param"
]
},
{
"type": "object",
"properties": {
"name": { "const": "rank:ndcg" },
"lambda_rank_param": { "$ref": "#/definitions/lambda_rank_param"}
"lambda_rank_param": { "$ref": "#/definitions/lambdarank_param"}
},
"required": [
"name",
"lambda_rank_param"
"lambdarank_param"
]
},
{

View File

@ -233,7 +233,7 @@ Parameters for Tree Booster
.. note:: This parameter is a work in progress.
- The strategy used for training multi-target models, including multi-target regression
and multi-class classification. See :doc:`/tutorials/multioutput` for more information.
and multi-class classification. See :doc:`/tutorials/multioutput` for more information.
- ``one_output_per_tree``: One model for each target.
- ``multi_output_tree``: Use multi-target trees.
@ -380,9 +380,9 @@ Specify the learning task and the corresponding learning objective. The objectiv
See :doc:`/tutorials/aft_survival_analysis` for details.
- ``multi:softmax``: set XGBoost to do multiclass classification using the softmax objective, you also need to set num_class(number of classes)
- ``multi:softprob``: same as softmax, but output a vector of ``ndata * nclass``, which can be further reshaped to ``ndata * nclass`` matrix. The result contains predicted probability of each data point belonging to each class.
- ``rank:pairwise``: Use LambdaMART to perform pairwise ranking where the pairwise loss is minimized
- ``rank:ndcg``: Use LambdaMART to perform list-wise ranking where `Normalized Discounted Cumulative Gain (NDCG) <http://en.wikipedia.org/wiki/NDCG>`_ is maximized
- ``rank:map``: Use LambdaMART to perform list-wise ranking where `Mean Average Precision (MAP) <http://en.wikipedia.org/wiki/Mean_average_precision#Mean_average_precision>`_ is maximized
- ``rank:ndcg``: Use LambdaMART to perform pair-wise ranking where `Normalized Discounted Cumulative Gain (NDCG) <http://en.wikipedia.org/wiki/NDCG>`_ is maximized. This objective supports position debiasing for click data.
- ``rank:map``: Use LambdaMART to perform pair-wise ranking where `Mean Average Precision (MAP) <http://en.wikipedia.org/wiki/Mean_average_precision#Mean_average_precision>`_ is maximized
- ``rank:pairwise``: Use LambdaRank to perform pair-wise ranking using the `ranknet` objective.
- ``reg:gamma``: gamma regression with log-link. Output is a mean of gamma distribution. It might be useful, e.g., for modeling insurance claims severity, or for any outcome that might be `gamma-distributed <https://en.wikipedia.org/wiki/Gamma_distribution#Occurrence_and_applications>`_.
- ``reg:tweedie``: Tweedie regression with log-link. It might be useful, e.g., for modeling total loss in insurance, or for any outcome that might be `Tweedie-distributed <https://en.wikipedia.org/wiki/Tweedie_distribution#Occurrence_and_applications>`_.
@ -395,8 +395,9 @@ Specify the learning task and the corresponding learning objective. The objectiv
* ``eval_metric`` [default according to objective]
- Evaluation metrics for validation data, a default metric will be assigned according to objective (rmse for regression, and logloss for classification, mean average precision for ranking)
- User can add multiple evaluation metrics. Python users: remember to pass the metrics in as list of parameters pairs instead of map, so that latter ``eval_metric`` won't override previous one
- Evaluation metrics for validation data; a default metric will be assigned according to the objective (rmse for regression, logloss for classification, `mean average precision` for ``rank:map``, etc.)
- User can add multiple evaluation metrics. Python users: remember to pass the metrics in as a list of parameter pairs instead of a map, so that a later ``eval_metric`` won't override previous ones
- The choices are listed below:
- ``rmse``: `root mean square error <http://en.wikipedia.org/wiki/Root_mean_square_error>`_
@ -480,6 +481,36 @@ Parameter for using AFT Survival Loss (``survival:aft``) and Negative Log Likeli
* ``aft_loss_distribution``: Probability Density Function, ``normal``, ``logistic``, or ``extreme``.
.. _ltr-param:
Parameters for learning to rank (``rank:ndcg``, ``rank:map``, ``rank:pairwise``)
================================================================================
These are parameters specific to the learning-to-rank task. See :doc:`Learning to Rank </tutorials/learning_to_rank>` for an in-depth explanation; a short usage sketch follows the parameter list below.
* ``lambdarank_pair_method`` [default = ``mean``]
How to construct pairs for pair-wise learning.
- ``mean``: Sample ``lambdarank_num_pair_per_sample`` pairs for each document in the query list.
- ``topk``: Focus on the top-``lambdarank_num_pair_per_sample`` documents. Construct :math:`|query|` pairs for each of the top-``lambdarank_num_pair_per_sample`` documents ranked by the model.
* ``lambdarank_num_pair_per_sample`` [range = :math:`[1, \infty]`]
It specifies the number of pairs sampled for each document when the pair method is ``mean``, or the truncation level for queries when the pair method is ``topk``. For example, to train with ``ndcg@6``, set ``lambdarank_num_pair_per_sample`` to :math:`6` and ``lambdarank_pair_method`` to ``topk``.
* ``lambdarank_unbiased`` [default = ``false``]
Specify whether to debias the input click data.
* ``lambdarank_bias_norm`` [default = 2.0]
:math:`L_p` normalization for position debiasing, default is :math:`L_2`. Only relevant when ``lambdarank_unbiased`` is set to true.
* ``ndcg_exp_gain`` [default = ``true``]
Whether to use the exponential gain function for ``NDCG``. There are two forms of gain function for ``NDCG``: one uses the relevance value directly, while the other uses :math:`2^{rel} - 1` to emphasize retrieving relevant documents. When ``ndcg_exp_gain`` is true (the default), the relevance degree cannot be greater than 31.
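Below is a minimal, illustrative sketch of these parameters with the native training interface. It uses synthetic data and assumes a recent XGBoost release where the ``lambdarank_*`` parameters and the ``qid`` argument are available; the values shown are examples rather than tuned recommendations.

.. code-block:: python

import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 10))        # 100 documents, 10 features
y = rng.integers(0, 4, size=100)      # graded relevance labels in [0, 3]
qid = np.repeat(np.arange(10), 10)    # 10 queries with 10 documents each, sorted by query id

dtrain = xgb.DMatrix(X, label=y, qid=qid)
params = {
    "objective": "rank:ndcg",
    "eval_metric": "ndcg@6",
    "lambdarank_pair_method": "topk",
    "lambdarank_num_pair_per_sample": 6,  # focus roughly on the top 6 positions
}
booster = xgb.train(params, dtrain, num_boost_round=10)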
***********************
Command Line Parameters
***********************

View File

@ -23,7 +23,7 @@ Requirements
Dask can be installed using either pip or conda (see the dask `installation
documentation <https://docs.dask.org/en/latest/install.html>`_ for more information). For
accelerating XGBoost with GPUs, `dask-cuda <https://github.com/rapidsai/dask-cuda>`_ is
accelerating XGBoost with GPUs, `dask-cuda <https://github.com/rapidsai/dask-cuda>`__ is
recommended for creating GPU clusters.

View File

@ -77,7 +77,7 @@ The external memory version takes in the following `URI <https://en.wikipedia.or
.. code-block:: none
filename#cacheprefix
filename?format=libsvm#cacheprefix
The ``filename`` is the normal path to LIBSVM format file you want to load in, and
``cacheprefix`` is a path to a cache file that XGBoost will use for caching preprocessed
@ -97,13 +97,13 @@ you have a dataset stored in a file similar to ``agaricus.txt.train`` with LIBSV
.. code-block:: python
dtrain = DMatrix('../data/agaricus.txt.train#dtrain.cache')
dtrain = DMatrix('../data/agaricus.txt.train?format=libsvm#dtrain.cache')
XGBoost will first load ``agaricus.txt.train`` in, preprocess it, then write to a new file named
``dtrain.cache`` as an on disk cache for storing preprocessed data in an internal binary format. For
more notes about text input formats, see :doc:`/tutorials/input_format`.
For CLI version, simply add the cache suffix, e.g. ``"../data/agaricus.txt.train#dtrain.cache"``.
For CLI version, simply add the cache suffix, e.g. ``"../data/agaricus.txt.train?format=libsvm#dtrain.cache"``.
**********************************

View File

@ -2,10 +2,15 @@
Text Input Format of DMatrix
############################
.. _basic_input_format:
Here we will briefly describe the text input formats for XGBoost. However, for users with access to a supported language environment like Python or R, it's recommended to use data parsers from that ecosystem instead. For instance, :py:func:`sklearn.datasets.load_svmlight_file`.
******************
Basic Input Format
******************
XGBoost currently supports two text formats for ingesting data: LIBSVM and CSV. The rest of this document will describe the LIBSVM format. (See `this Wikipedia article <https://en.wikipedia.org/wiki/Comma-separated_values>`_ for a description of the CSV format.). Please be careful that, XGBoost does **not** understand file extensions, nor try to guess the file format, as there is no universal agreement upon file extension of LIBSVM or CSV. Instead it employs `URI <https://en.wikipedia.org/wiki/Uniform_Resource_Identifier>`_ format for specifying the precise input file type. For example if you provide a `csv` file ``./data.train.csv`` as input, XGBoost will blindly use the default LIBSVM parser to digest it and generate a parser error. Instead, users need to provide an URI in the form of ``train.csv?format=csv``. For external memory input, the URI should of a form similar to ``train.csv?format=csv#dtrain.cache``. See :ref:`python_data_interface` and :doc:`/tutorials/external_memory` also.
XGBoost currently supports two text formats for ingesting data: LIBSVM and CSV. The rest of this document will describe the LIBSVM format. (See `this Wikipedia article <https://en.wikipedia.org/wiki/Comma-separated_values>`_ for a description of the CSV format.) Please be aware that XGBoost does **not** understand file extensions, nor does it try to guess the file format, as there is no universal agreement on the file extension for LIBSVM or CSV. Instead, it employs the `URI <https://en.wikipedia.org/wiki/Uniform_Resource_Identifier>`_ format for specifying the precise input file type. For example, if you provide a `csv` file ``./data.train.csv`` as input, XGBoost will blindly use the default LIBSVM parser to digest it and generate a parser error. Instead, users need to provide a URI in the form of ``train.csv?format=csv`` or ``train.csv?format=libsvm``. For external memory input, the URI should be of a form similar to ``train.csv?format=csv#dtrain.cache``. See :ref:`python_data_interface` and :doc:`/tutorials/external_memory` as well.
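For illustration, here is a minimal sketch of passing the format explicitly from Python (the file names are hypothetical):

.. code-block:: python

import xgboost as xgb

# The URI parameter tells XGBoost how to parse the file; the extension alone is ignored.
dtrain = xgb.DMatrix("data.train.csv?format=csv")
dtest = xgb.DMatrix("data.test.txt?format=libsvm")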
For training or predicting, XGBoost takes an instance file with the format as below:

View File

@ -108,8 +108,8 @@ virtualenv and pip:
python -m venv xgboost_env
source xgboost_env/bin/activate
pip install pyarrow pandas venv-pack xgboost
# https://rapids.ai/pip.html#install
pip install cudf-cu11 --extra-index-url=https://pypi.ngc.nvidia.com
# https://docs.rapids.ai/install#pip-install
pip install cudf-cu11 --extra-index-url=https://pypi.nvidia.com
venv-pack -o xgboost_env.tar.gz
With Conda:
@ -241,7 +241,7 @@ additional spark configurations and dependencies:
--master spark://<master-ip>:7077 \
--conf spark.executor.resource.gpu.amount=1 \
--conf spark.task.resource.gpu.amount=1 \
--packages com.nvidia:rapids-4-spark_2.12:22.08.0 \
--packages com.nvidia:rapids-4-spark_2.12:23.04.0 \
--conf spark.plugins=com.nvidia.spark.SQLPlugin \
--conf spark.sql.execution.arrow.maxRecordsPerBatch=1000000 \
--archives xgboost_env.tar.gz#environment \

View File

@ -38,7 +38,7 @@ typedef uint64_t bst_ulong; // NOLINT(*)
*/
/**
* @defgroup Library
* @defgroup Library Library
*
* These functions are used to obtain general information about XGBoost including version,
* build info and current global configuration.
@ -112,7 +112,7 @@ XGB_DLL int XGBGetGlobalConfig(char const **out_config);
/**@}*/
/**
* @defgroup DMatrix
* @defgroup DMatrix DMatrix
*
* @brief DMatrix is the basic data storage for XGBoost used by all XGBoost algorithms
* including both training, prediction and explanation. There are a few variants of
@ -138,7 +138,11 @@ XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle
/*!
* \brief load a data matrix
* \param config JSON encoded parameters for DMatrix construction. Accepted fields are:
* - uri: The URI of the input file.
* - uri: The URI of the input file. The URI parameter `format` is required when loading text data.
* \verbatim embed:rst:leading-asterisk
* See :doc:`/tutorials/input_format` for more info.
* \endverbatim
* - silent (optional): Whether to print message during loading. Default to true.
* - data_split_mode (optional): Whether to split by row or column. In distributed mode, the
* file is split accordingly; otherwise this is only an indicator on how the file was split
@ -200,7 +204,7 @@ XGB_DLL int XGDMatrixCreateFromDense(char const *data, char const *config, DMatr
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGDMatrixCreateFromCSC(char const *indptr, char const *indices, char const *data,
bst_ulong nrow, char const *c_json_config, DMatrixHandle *out);
bst_ulong nrow, char const *config, DMatrixHandle *out);
/*!
* \brief create a matrix content from CSC format
@ -281,7 +285,7 @@ XGB_DLL int XGDMatrixCreateFromCudaArrayInterface(char const *data, char const *
DMatrixHandle *out);
/**
* @defgroup Streaming
* @defgroup Streaming Streaming
* @ingroup DMatrix
*
* @brief Quantile DMatrix and external memory DMatrix can be created from batches of
@ -431,7 +435,7 @@ XGB_EXTERN_C typedef void DataIterResetCallback(DataIterHandle handle); // NOLIN
* - Step 0: Define a data iterator with 2 methods `reset`, and `next`.
* - Step 1: Create a DMatrix proxy by \ref XGProxyDMatrixCreate and hold the handle.
* - Step 2: Pass the iterator handle, proxy handle and 2 methods into
* `XGDMatrixCreateFromCallback`, along with other parameters encoded as a JSON object.
* \ref XGDMatrixCreateFromCallback, along with other parameters encoded as a JSON object.
* - Step 3: Call appropriate data setters in `next` functions.
*
* \param iter A handle to external data iterator.
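A compact sketch of that iterator protocol follows; the proxy setter used in `next` and the JSON config keys are assumptions for illustration, not taken from this header:

    #include <xgboost/c_api.h>
    #include <string>
    #include <vector>

    // Hypothetical iterator state over pre-materialized batches.
    struct BatchIter {
      std::vector<std::string> batches;  // one array-interface JSON string per batch (assumed layout)
      std::size_t cursor = 0;
      DMatrixHandle proxy = nullptr;
    };

    extern "C" void Reset(DataIterHandle handle) {      // Step 0: rewind
      static_cast<BatchIter*>(handle)->cursor = 0;
    }

    extern "C" int Next(DataIterHandle handle) {        // Step 0: advance
      auto* it = static_cast<BatchIter*>(handle);
      if (it->cursor == it->batches.size()) return 0;   // 0 signals the end of the data
      // Step 3: hand the current batch to the proxy (setter name assumed).
      XGProxyDMatrixSetDataDense(it->proxy, it->batches[it->cursor].c_str());
      ++it->cursor;
      return 1;
    }

    int BuildFromCallback(BatchIter* iter, DMatrixHandle* out) {
      XGProxyDMatrixCreate(&iter->proxy);               // Step 1
      // Step 2: the config keys ("missing", "cache_prefix") are assumed here.
      return XGDMatrixCreateFromCallback(iter, iter->proxy, Reset, Next,
                                         "{\"missing\": NaN, \"cache_prefix\": \"dtrain.cache\"}",
                                         out);
    }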
@ -830,7 +834,7 @@ XGB_DLL int XGDMatrixGetDataAsCSR(DMatrixHandle const handle, char const *config
/** @} */ // End of DMatrix
/**
* @defgroup Booster
* @defgroup Booster Booster
*
* @brief The `Booster` class is the gradient-boosted model for XGBoost.
* @{
@ -953,7 +957,7 @@ XGB_DLL int XGBoosterEvalOneIter(BoosterHandle handle, int iter, DMatrixHandle d
*/
/**
* @defgroup Prediction
* @defgroup Prediction Prediction
* @ingroup Booster
*
* @brief These functions are used for running prediction and explanation algorithms.
@ -1155,7 +1159,7 @@ XGB_DLL int XGBoosterPredictFromCudaColumnar(BoosterHandle handle, char const *v
/**
* @defgroup Serialization
* @defgroup Serialization Serialization
* @ingroup Booster
*
* @brief There are multiple ways to serialize a Booster object depending on the use case.
@ -1490,7 +1494,7 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, const char *config,
/**@}*/ // End of Booster
/**
* @defgroup Collective
* @defgroup Collective Collective
*
* @brief Experimental support for exposing internal communicator in XGBoost.
*

View File

@ -50,7 +50,19 @@ struct Context : public XGBoostParameter<Context> {
bool IsCPU() const { return gpu_id == kCpuId; }
bool IsCUDA() const { return !IsCPU(); }
CUDAContext const* CUDACtx() const;
// Make a CUDA context based on the current context.
Context MakeCUDA(std::int32_t device = 0) const {
Context ctx = *this;
ctx.gpu_id = device;
return ctx;
}
Context MakeCPU() const {
Context ctx = *this;
ctx.gpu_id = kCpuId;
return ctx;
}
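A tiny usage sketch for the two helpers above (illustrative only):

    void ContextExample() {
      Context ctx;                      // defaults to CPU, i.e. gpu_id == kCpuId
      Context cuda = ctx.MakeCUDA(0);   // copy bound to CUDA device 0
      Context cpu = cuda.MakeCPU();     // copy reverted back to CPU
      // cuda.IsCUDA() and cpu.IsCPU() both hold.
    }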
// declare parameters
DMLC_DECLARE_PARAMETER(Context) {

View File

@ -1,5 +1,5 @@
/*!
* Copyright (c) 2015-2022 by XGBoost Contributors
/**
* Copyright 2015-2023 by XGBoost Contributors
* \file data.h
* \brief The input data structure of xgboost.
* \author Tianqi Chen
@ -196,6 +196,14 @@ class MetaInfo {
*/
bool IsVerticalFederated() const;
/*!
* \brief A convenient method to check if the MetaInfo should contain labels.
*
* Normally we assume labels are available everywhere. The only exception is in vertical federated
* learning where labels are only available on worker 0.
*/
bool ShouldHaveLabels() const;
private:
void SetInfoFromHost(Context const& ctx, StringView key, Json arr);
void SetInfoFromCUDA(Context const& ctx, StringView key, Json arr);
@ -230,44 +238,72 @@ struct Entry {
}
};
/*!
* \brief Parameters for constructing batches.
/**
* \brief Parameters for constructing histogram index batches.
*/
struct BatchParam {
/*! \brief The GPU device to use. */
int gpu_id {-1};
/*! \brief Maximum number of bins per feature for histograms. */
/**
* \brief Maximum number of bins per feature for histograms.
*/
bst_bin_t max_bin{0};
/*! \brief Hessian, used for sketching with future approx implementation. */
/**
* \brief Hessian, used for sketching with future approx implementation.
*/
common::Span<float> hess;
/*! \brief Whether should DMatrix regenerate the batch. Only used for GHistIndex. */
bool regen {false};
/*! \brief Parameter used to generate column matrix for hist. */
/**
* \brief Whether we should force DMatrix to regenerate the batch. Only used for
* GHistIndex.
*/
bool regen{false};
/**
* \brief Forbid regenerating the gradient index. Used for internal validation.
*/
bool forbid_regen{false};
/**
* \brief Parameter used to generate column matrix for hist.
*/
double sparse_thresh{std::numeric_limits<double>::quiet_NaN()};
/**
* \brief Used by the exact tree method and other cases that don't need a histogram.
*/
BatchParam() = default;
// GPU Hist
BatchParam(int32_t device, bst_bin_t max_bin)
: gpu_id{device}, max_bin{max_bin} {}
// Hist
/**
* \brief Used by the hist tree method.
*/
BatchParam(bst_bin_t max_bin, double sparse_thresh)
: max_bin{max_bin}, sparse_thresh{sparse_thresh} {}
// Approx
/**
* \brief Get batch with sketch weighted by hessian. The batch will be regenerated if
* the span is changed, so caller should keep the span for each iteration.
* \brief Used by the approx tree method.
*
* Get batch with sketch weighted by hessian. The batch will be regenerated if the
* span is changed, so caller should keep the span for each iteration.
*/
BatchParam(bst_bin_t max_bin, common::Span<float> hessian, bool regenerate)
: max_bin{max_bin}, hess{hessian}, regen{regenerate} {}
bool operator!=(BatchParam const& other) const {
if (hess.empty() && other.hess.empty()) {
return gpu_id != other.gpu_id || max_bin != other.max_bin;
}
return gpu_id != other.gpu_id || max_bin != other.max_bin || hess.data() != other.hess.data();
bool ParamNotEqual(BatchParam const& other) const {
// Check non-floating parameters.
bool cond = max_bin != other.max_bin;
// Check sparse thresh.
bool l_nan = std::isnan(sparse_thresh);
bool r_nan = std::isnan(other.sparse_thresh);
bool st_chg = (l_nan != r_nan) || (!l_nan && !r_nan && (sparse_thresh != other.sparse_thresh));
cond |= st_chg;
return cond;
}
bool operator==(BatchParam const& other) const {
return !(*this != other);
bool Initialized() const { return max_bin != 0; }
/**
* \brief Make a copy of self for DMatrix to describe how its existing index was generated.
*/
BatchParam MakeCache() const {
auto p = *this;
// These parameters have nothing to do with how the gradient index was generated in the
// first place.
p.regen = false;
p.forbid_regen = false;
return p;
}
};
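A short sketch of how the constructors above are typically used; the numbers and the span construction are illustrative assumptions:

    #include <vector>

    void BatchParamExample() {
      // Hist tree method: number of bins plus the column-matrix sparse threshold.
      BatchParam hist_param{/*max_bin=*/256, /*sparse_thresh=*/0.2};
      // Approx tree method: bins plus a hessian span; regenerate when the span changes.
      std::vector<float> hess(1024, 1.0f);
      BatchParam approx_param{/*max_bin=*/256,
                              common::Span<float>{hess.data(), hess.size()},
                              /*regenerate=*/true};
      // Copy kept by the DMatrix to describe the cached index; the regen flags are cleared.
      BatchParam cached = approx_param.MakeCache();
      (void)hist_param; (void)cached;
    }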
@ -427,7 +463,7 @@ class EllpackPage {
* This is used in the in-memory case. The ELLPACK page is constructed from an existing DMatrix
* in CSR format.
*/
explicit EllpackPage(DMatrix* dmat, const BatchParam& param);
explicit EllpackPage(Context const* ctx, DMatrix* dmat, const BatchParam& param);
/*! \brief Destructor. */
~EllpackPage();
@ -543,7 +579,9 @@ class DMatrix {
template <typename T>
BatchSet<T> GetBatches();
template <typename T>
BatchSet<T> GetBatches(const BatchParam& param);
BatchSet<T> GetBatches(Context const* ctx);
template <typename T>
BatchSet<T> GetBatches(Context const* ctx, const BatchParam& param);
template <typename T>
bool PageExists() const;
@ -558,21 +596,17 @@ class DMatrix {
return Info().num_nonzero_ == Info().num_row_ * Info().num_col_;
}
/*!
/**
* \brief Load DMatrix from URI.
*
* \param uri The URI of input.
* \param silent Whether print information during loading.
* \param data_split_mode In distributed mode, split the input according this mode; otherwise,
* it's just an indicator on how the input was split beforehand.
* \param file_format The format type of the file, used for dmlc::Parser::Create.
* By default "auto" will be able to load in both local binary file.
* \param page_size Page size for external memory.
* \return The created DMatrix.
*/
static DMatrix* Load(const std::string& uri,
bool silent = true,
DataSplitMode data_split_mode = DataSplitMode::kRow,
const std::string& file_format = "auto");
static DMatrix* Load(const std::string& uri, bool silent = true,
DataSplitMode data_split_mode = DataSplitMode::kRow);
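A minimal usage sketch, assuming the xgboost namespace is in scope (the file name is a placeholder):

    #include <memory>

    void LoadExample() {
      std::unique_ptr<DMatrix> dmat{
          DMatrix::Load("train.csv?format=csv", /*silent=*/true, DataSplitMode::kRow)};
    }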
/**
* \brief Creates a new DMatrix from an external data adapter.
@ -654,18 +688,19 @@ class DMatrix {
protected:
virtual BatchSet<SparsePage> GetRowBatches() = 0;
virtual BatchSet<CSCPage> GetColumnBatches() = 0;
virtual BatchSet<SortedCSCPage> GetSortedColumnBatches() = 0;
virtual BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) = 0;
virtual BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam& param) = 0;
virtual BatchSet<ExtSparsePage> GetExtBatches(BatchParam const& param) = 0;
virtual BatchSet<CSCPage> GetColumnBatches(Context const* ctx) = 0;
virtual BatchSet<SortedCSCPage> GetSortedColumnBatches(Context const* ctx) = 0;
virtual BatchSet<EllpackPage> GetEllpackBatches(Context const* ctx, BatchParam const& param) = 0;
virtual BatchSet<GHistIndexMatrix> GetGradientIndex(Context const* ctx,
BatchParam const& param) = 0;
virtual BatchSet<ExtSparsePage> GetExtBatches(Context const* ctx, BatchParam const& param) = 0;
virtual bool EllpackExists() const = 0;
virtual bool GHistIndexExists() const = 0;
virtual bool SparsePageExists() const = 0;
};
template<>
template <>
inline BatchSet<SparsePage> DMatrix::GetBatches() {
return GetRowBatches();
}
@ -680,34 +715,39 @@ inline bool DMatrix::PageExists<GHistIndexMatrix>() const {
return this->GHistIndexExists();
}
template<>
template <>
inline bool DMatrix::PageExists<SparsePage>() const {
return this->SparsePageExists();
}
template<>
inline BatchSet<CSCPage> DMatrix::GetBatches() {
return GetColumnBatches();
}
template<>
inline BatchSet<SortedCSCPage> DMatrix::GetBatches() {
return GetSortedColumnBatches();
}
template<>
inline BatchSet<EllpackPage> DMatrix::GetBatches(const BatchParam& param) {
return GetEllpackBatches(param);
template <>
inline BatchSet<SparsePage> DMatrix::GetBatches(Context const*) {
return GetRowBatches();
}
template <>
inline BatchSet<GHistIndexMatrix> DMatrix::GetBatches(const BatchParam& param) {
return GetGradientIndex(param);
inline BatchSet<CSCPage> DMatrix::GetBatches(Context const* ctx) {
return GetColumnBatches(ctx);
}
template <>
inline BatchSet<ExtSparsePage> DMatrix::GetBatches() {
return GetExtBatches(BatchParam{});
inline BatchSet<SortedCSCPage> DMatrix::GetBatches(Context const* ctx) {
return GetSortedColumnBatches(ctx);
}
template <>
inline BatchSet<EllpackPage> DMatrix::GetBatches(Context const* ctx, BatchParam const& param) {
return GetEllpackBatches(ctx, param);
}
template <>
inline BatchSet<GHistIndexMatrix> DMatrix::GetBatches(Context const* ctx, BatchParam const& param) {
return GetGradientIndex(ctx, param);
}
template <>
inline BatchSet<ExtSparsePage> DMatrix::GetBatches(Context const* ctx, BatchParam const& param) {
return GetExtBatches(ctx, param);
}
} // namespace xgboost
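A sketch of how the specializations above are called, assuming the xgboost namespace is in scope: SparsePage needs no parameters, while the histogram-index page takes a Context and a BatchParam (values illustrative):

    void BatchesExample(DMatrix* dmat) {
      for (auto const& page : dmat->GetBatches<SparsePage>()) {
        // ... iterate over the rows of this page ...
        (void)page;
      }
      Context ctx;
      BatchParam param{/*max_bin=*/256, /*sparse_thresh=*/0.2};
      for (auto const& page : dmat->GetBatches<GHistIndexMatrix>(&ctx, param)) {
        // ... consume the gradient index ...
        (void)page;
      }
    }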

View File

@ -567,7 +567,7 @@ class RegTree : public Model {
* \brief drop the trace after fill, must be called after fill.
* \param inst The sparse instance to drop.
*/
void Drop(const SparsePage::Inst& inst);
void Drop();
/*!
* \brief returns the size of the feature vector
* \return the size of the feature vector
@ -807,13 +807,10 @@ inline void RegTree::FVec::Fill(const SparsePage::Inst& inst) {
has_missing_ = data_.size() != feature_count;
}
inline void RegTree::FVec::Drop(const SparsePage::Inst& inst) {
for (auto const& entry : inst) {
if (entry.index >= data_.size()) {
continue;
}
data_[entry.index].flag = -1;
}
inline void RegTree::FVec::Drop() {
Entry e{};
e.flag = -1;
std::fill_n(data_.data(), data_.size(), e);
has_missing_ = true;
}
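A per-row sketch of the Fill/Drop cycle under the new signature, assuming the FVec has already been initialized to the feature count:

    #include <vector>

    void PredictRows(RegTree::FVec* feats, std::vector<SparsePage::Inst> const& rows) {
      for (auto const& inst : rows) {
        feats->Fill(inst);   // load the features present in this row
        // ... traverse the tree(s) using *feats ...
        feats->Drop();       // new behaviour: mark every slot missing again for the next row
      }
    }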

View File

@ -33,16 +33,16 @@
<project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
<flink.version>1.8.3</flink.version>
<spark.version>3.1.1</spark.version>
<scala.version>2.12.8</scala.version>
<flink.version>1.17.0</flink.version>
<spark.version>3.4.0</spark.version>
<scala.version>2.12.17</scala.version>
<scala.binary.version>2.12</scala.binary.version>
<hadoop.version>3.3.5</hadoop.version>
<maven.wagon.http.retryHandler.count>5</maven.wagon.http.retryHandler.count>
<log.capi.invocation>OFF</log.capi.invocation>
<use.cuda>OFF</use.cuda>
<cudf.version>22.12.0</cudf.version>
<spark.rapids.version>22.12.0</spark.rapids.version>
<cudf.version>23.04.0</cudf.version>
<spark.rapids.version>23.04.0</spark.rapids.version>
<cudf.classifier>cuda11</cudf.classifier>
</properties>
<repositories>
@ -374,7 +374,7 @@
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-checkstyle-plugin</artifactId>
<version>3.2.1</version>
<version>3.2.2</version>
<configuration>
<configLocation>checkstyle.xml</configLocation>
<failOnViolation>true</failOnViolation>
@ -450,7 +450,7 @@
<plugins>
<plugin>
<artifactId>maven-project-info-reports-plugin</artifactId>
<version>3.4.2</version>
<version>3.4.3</version>
</plugin>
<plugin>
<groupId>net.alchim31.maven</groupId>
@ -469,7 +469,7 @@
<dependency>
<groupId>com.esotericsoftware</groupId>
<artifactId>kryo</artifactId>
<version>5.4.0</version>
<version>5.5.0</version>
</dependency>
<dependency>
<groupId>org.scala-lang</groupId>
@ -477,11 +477,6 @@
<version>${scala.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-reflect</artifactId>
<version>${scala.version}</version>
</dependency>
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-library</artifactId>
@ -495,13 +490,13 @@
<dependency>
<groupId>org.scalatest</groupId>
<artifactId>scalatest_${scala.binary.version}</artifactId>
<version>3.0.8</version>
<version>3.2.15</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.scalactic</groupId>
<artifactId>scalactic_${scala.binary.version}</artifactId>
<version>3.0.8</version>
<version>3.2.15</version>
<scope>test</scope>
</dependency>
</dependencies>

View File

@ -26,7 +26,7 @@
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j-spark_${scala.binary.version}</artifactId>
<version>2.0.0-SNAPSHOT</version>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
@ -37,12 +37,7 @@
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j-flink_${scala.binary.version}</artifactId>
<version>2.0.0-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>3.12.0</version>
<version>${project.version}</version>
</dependency>
</dependencies>
</project>

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2021 by Contributors
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -62,8 +62,8 @@ public class BasicWalkThrough {
public static void main(String[] args) throws IOException, XGBoostError {
// load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");
HashMap<String, Object> params = new HashMap<String, Object>();
params.put("eta", 1.0);
@ -112,7 +112,8 @@ public class BasicWalkThrough {
System.out.println("start build dmatrix from csr sparse data ...");
//build dmatrix from CSR Sparse Matrix
DataLoader.CSRSparseData spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train");
DataLoader.CSRSparseData spData =
DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train?format=libsvm");
DMatrix trainMat2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data,
DMatrix.SparseType.CSR, 127);

View File

@ -32,8 +32,8 @@ public class BoostFromPrediction {
System.out.println("start running example to start from a initial prediction");
// load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");
//specify parameters
HashMap<String, Object> params = new HashMap<String, Object>();

View File

@ -30,7 +30,7 @@ import ml.dmlc.xgboost4j.java.XGBoostError;
public class CrossValidation {
public static void main(String[] args) throws IOException, XGBoostError {
//load train mat
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
//set params
HashMap<String, Object> params = new HashMap<String, Object>();

View File

@ -139,9 +139,9 @@ public class CustomObjective {
public static void main(String[] args) throws XGBoostError {
//load train mat (svmlight format)
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
//load valid mat (svmlight format)
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");
HashMap<String, Object> params = new HashMap<String, Object>();
params.put("eta", 1.0);

View File

@ -29,9 +29,9 @@ import ml.dmlc.xgboost4j.java.example.util.DataLoader;
public class EarlyStopping {
public static void main(String[] args) throws IOException, XGBoostError {
DataLoader.CSRSparseData trainCSR =
DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train");
DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train?format=libsvm");
DataLoader.CSRSparseData testCSR =
DataLoader.loadSVMFile("../../demo/data/agaricus.txt.test");
DataLoader.loadSVMFile("../../demo/data/agaricus.txt.test?format=libsvm");
Map<String, Object> paramMap = new HashMap<String, Object>() {
{

View File

@ -32,8 +32,8 @@ public class ExternalMemory {
//this is the only difference: add a # followed by a cache prefix name
//several cache files with the prefix will be generated
//currently only conversion from libsvm files is supported
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train#dtrain.cache");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test#dtest.cache");
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm#dtrain.cache");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm#dtest.cache");
//specify parameters
HashMap<String, Object> params = new HashMap<String, Object>();

View File

@ -32,8 +32,8 @@ import ml.dmlc.xgboost4j.java.example.util.CustomEval;
public class GeneralizedLinearModel {
public static void main(String[] args) throws XGBoostError {
// load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");
//specify parameters
//change booster to gblinear, so that we are fitting a linear model

View File

@ -31,8 +31,8 @@ import ml.dmlc.xgboost4j.java.example.util.CustomEval;
public class PredictFirstNtree {
public static void main(String[] args) throws XGBoostError {
// load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");
//specify parameters
HashMap<String, Object> params = new HashMap<String, Object>();

View File

@ -31,8 +31,8 @@ import ml.dmlc.xgboost4j.java.XGBoostError;
public class PredictLeafIndices {
public static void main(String[] args) throws XGBoostError {
// load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");
//specify parameters
HashMap<String, Object> params = new HashMap<String, Object>();

View File

@ -0,0 +1,107 @@
/*
Copyright (c) 2014-2021 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java.example.flink;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import org.apache.flink.api.common.typeinfo.TypeHint;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.tuple.Tuple13;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.utils.DataSetUtils;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import ml.dmlc.xgboost4j.java.flink.XGBoost;
import ml.dmlc.xgboost4j.java.flink.XGBoostModel;
public class DistTrainWithFlinkExample {
static Tuple2<XGBoostModel, DataSet<Float[]>> runPrediction(
ExecutionEnvironment env,
java.nio.file.Path trainPath,
int percentage) throws Exception {
// reading data
final DataSet<Tuple2<Long, Tuple2<Vector, Double>>> data =
DataSetUtils.zipWithIndex(parseCsv(env, trainPath));
final long size = data.count();
final long trainCount = Math.round(size * 0.01 * percentage);
final DataSet<Tuple2<Vector, Double>> trainData =
data
.filter(item -> item.f0 < trainCount)
.map(t -> t.f1)
.returns(TypeInformation.of(new TypeHint<Tuple2<Vector, Double>>(){}));
final DataSet<Vector> testData =
data
.filter(tuple -> tuple.f0 >= trainCount)
.map(t -> t.f1.f0)
.returns(TypeInformation.of(new TypeHint<Vector>(){}));
// define parameters
HashMap<String, Object> paramMap = new HashMap<String, Object>(3);
paramMap.put("eta", 0.1);
paramMap.put("max_depth", 2);
paramMap.put("objective", "binary:logistic");
// number of iterations
final int round = 2;
// train the model
XGBoostModel model = XGBoost.train(trainData, paramMap, round);
DataSet<Float[]> predTest = model.predict(testData);
return new Tuple2<XGBoostModel, DataSet<Float[]>>(model, predTest);
}
private static MapOperator<Tuple13<Double, String, Double, Double, Double, Integer, Integer,
Integer, Integer, Integer, Integer, Integer, Integer>,
Tuple2<Vector, Double>> parseCsv(ExecutionEnvironment env, Path trainPath) {
return env.readCsvFile(trainPath.toString())
.ignoreFirstLine()
.types(Double.class, String.class, Double.class, Double.class, Double.class,
Integer.class, Integer.class, Integer.class, Integer.class, Integer.class,
Integer.class, Integer.class, Integer.class)
.map(DistTrainWithFlinkExample::mapFunction);
}
private static Tuple2<Vector, Double> mapFunction(Tuple13<Double, String, Double, Double, Double,
Integer, Integer, Integer, Integer, Integer, Integer, Integer, Integer> tuple) {
final DenseVector dense = Vectors.dense(tuple.f2, tuple.f3, tuple.f4, tuple.f5, tuple.f6,
tuple.f7, tuple.f8, tuple.f9, tuple.f10, tuple.f11, tuple.f12);
if (tuple.f1.contains("inf")) {
return new Tuple2<Vector, Double>(dense, 1.0);
} else {
return new Tuple2<Vector, Double>(dense, 0.0);
}
}
public static void main(String[] args) throws Exception {
final java.nio.file.Path parentPath = java.nio.file.Paths.get(Arrays.stream(args)
.findFirst().orElse("."));
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
Tuple2<XGBoostModel, DataSet<Float[]>> tuple2 = runPrediction(
env, parentPath.resolve("veterans_lung_cancer.csv"), 70
);
List<Float[]> list = tuple2.f1.collect();
System.out.println(list.size());
}
}

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -36,8 +36,8 @@ object BasicWalkThrough {
}
def main(args: Array[String]): Unit = {
val trainMax = new DMatrix("../../demo/data/agaricus.txt.train")
val testMax = new DMatrix("../../demo/data/agaricus.txt.test")
val trainMax = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val testMax = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val params = new mutable.HashMap[String, Any]()
params += "eta" -> 1.0
@ -76,7 +76,7 @@ object BasicWalkThrough {
// build dmatrix from CSR Sparse Matrix
println("start build dmatrix from csr sparse data ...")
val spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train")
val spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train?format=libsvm")
val trainMax2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data,
JDMatrix.SparseType.CSR)
trainMax2.setLabel(spData.labels)

View File

@ -24,8 +24,8 @@ object BoostFromPrediction {
def main(args: Array[String]): Unit = {
println("start running example to start from a initial prediction")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val params = new mutable.HashMap[String, Any]()
params += "eta" -> 1.0

View File

@ -21,7 +21,7 @@ import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
object CrossValidation {
def main(args: Array[String]): Unit = {
val trainMat: DMatrix = new DMatrix("../../demo/data/agaricus.txt.train")
val trainMat: DMatrix = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
// set params
val params = new mutable.HashMap[String, Any]

View File

@ -138,8 +138,8 @@ object CustomObjective {
}
def main(args: Array[String]): Unit = {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val params = new mutable.HashMap[String, Any]()
params += "eta" -> 1.0
params += "max_depth" -> 2

View File

@ -25,8 +25,8 @@ object ExternalMemory {
// this is the only difference: add a # followed by a cache prefix name
// several cache files with the prefix will be generated
// currently only conversion from libsvm files is supported
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train#dtrain.cache")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test#dtest.cache")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm#dtrain.cache")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm#dtest.cache")
val params = new mutable.HashMap[String, Any]()
params += "eta" -> 1.0

View File

@ -27,8 +27,8 @@ import ml.dmlc.xgboost4j.scala.example.util.CustomEval
*/
object GeneralizedLinearModel {
def main(args: Array[String]): Unit = {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
// specify parameters
// change booster to gblinear, so that we are fitting a linear model

View File

@ -23,8 +23,8 @@ import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
object PredictFirstNTree {
def main(args: Array[String]): Unit = {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val params = new mutable.HashMap[String, Any]()
params += "eta" -> 1.0

View File

@ -25,8 +25,8 @@ import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
object PredictLeafIndices {
def main(args: Array[String]): Unit = {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
val params = new mutable.HashMap[String, Any]()
params += "eta" -> 1.0

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014 - 2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -15,27 +15,84 @@
*/
package ml.dmlc.xgboost4j.scala.example.flink
import ml.dmlc.xgboost4j.scala.flink.XGBoost
import org.apache.flink.api.scala.{ExecutionEnvironment, _}
import org.apache.flink.ml.MLUtils
import java.lang.{Double => JDouble, Long => JLong}
import java.nio.file.{Path, Paths}
import org.apache.flink.api.java.tuple.{Tuple13, Tuple2}
import org.apache.flink.api.java.{DataSet, ExecutionEnvironment}
import org.apache.flink.ml.linalg.{Vector, Vectors}
import ml.dmlc.xgboost4j.java.flink.{XGBoost, XGBoostModel}
import org.apache.flink.api.common.typeinfo.{TypeHint, TypeInformation}
import org.apache.flink.api.java.utils.DataSetUtils
object DistTrainWithFlink {
def main(args: Array[String]) {
val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
// read training data
val trainData =
MLUtils.readLibSVM(env, "/path/to/data/agaricus.txt.train")
val testData = MLUtils.readLibSVM(env, "/path/to/data/agaricus.txt.test")
// define parameters
val paramMap = List(
"eta" -> 0.1,
"max_depth" -> 2,
"objective" -> "binary:logistic").toMap
import scala.jdk.CollectionConverters._
private val rowTypeHint = TypeInformation.of(new TypeHint[Tuple2[Vector, JDouble]]{})
private val testDataTypeHint = TypeInformation.of(classOf[Vector])
private[flink] def parseCsv(trainPath: Path)(implicit env: ExecutionEnvironment):
DataSet[Tuple2[JLong, Tuple2[Vector, JDouble]]] = {
DataSetUtils.zipWithIndex(
env
.readCsvFile(trainPath.toString)
.ignoreFirstLine
.types(
classOf[Double], classOf[String], classOf[Double], classOf[Double], classOf[Double],
classOf[Integer], classOf[Integer], classOf[Integer], classOf[Integer],
classOf[Integer], classOf[Integer], classOf[Integer], classOf[Integer]
)
.map((row: Tuple13[Double, String, Double, Double, Double,
Integer, Integer, Integer, Integer, Integer, Integer, Integer, Integer]) => {
val dense = Vectors.dense(row.f2, row.f3, row.f4,
row.f5.toDouble, row.f6.toDouble, row.f7.toDouble, row.f8.toDouble,
row.f9.toDouble, row.f10.toDouble, row.f11.toDouble, row.f12.toDouble)
val label = if (row.f1.contains("inf")) {
JDouble.valueOf(1.0)
} else {
JDouble.valueOf(0.0)
}
new Tuple2[Vector, JDouble](dense, label)
})
.returns(rowTypeHint)
)
}
private[flink] def runPrediction(trainPath: Path, percentage: Int)
(implicit env: ExecutionEnvironment):
(XGBoostModel, DataSet[Array[Float]]) = {
// read training data
val data: DataSet[Tuple2[JLong, Tuple2[Vector, JDouble]]] = parseCsv(trainPath)
val trainSize = Math.round(0.01 * percentage * data.count())
val trainData: DataSet[Tuple2[Vector, JDouble]] =
data.filter(d => d.f0 < trainSize).map(_.f1).returns(rowTypeHint)
val testData: DataSet[Vector] =
data
.filter(d => d.f0 >= trainSize)
.map(_.f1.f0)
.returns(testDataTypeHint)
val paramMap = mapAsJavaMap(Map(
("eta", "0.1".asInstanceOf[AnyRef]),
("max_depth", "2"),
("objective", "binary:logistic"),
("verbosity", "1")
))
// number of iterations
val round = 2
// train the model
val model = XGBoost.train(trainData, paramMap, round)
val predTest = model.predict(testData.map{x => x.vector})
model.saveModelAsHadoopFile("file:///path/to/xgboost.model")
val result = model.predict(testData).map(prediction => prediction.map(Float.unbox))
(model, result)
}
def main(args: Array[String]): Unit = {
implicit val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
val parentPath = Paths.get(args.headOption.getOrElse("."))
val (_, predTest) = runPrediction(parentPath.resolve("veterans_lung_cancer.csv"), 70)
val list = predTest.collect().asScala
println(list.length)
}
}

View File

@ -0,0 +1,36 @@
/*
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java.example.flink
import org.apache.flink.api.java.ExecutionEnvironment
import org.scalatest.Inspectors._
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers._
import java.nio.file.Paths
class DistTrainWithFlinkExampleTest extends AnyFunSuite {
private val parentPath = Paths.get("../../").resolve("demo").resolve("data")
private val data = parentPath.resolve("veterans_lung_cancer.csv")
test("Smoke test for scala flink example") {
val env = ExecutionEnvironment.createLocalEnvironment(1)
val tuple2 = DistTrainWithFlinkExample.runPrediction(env, data, 70)
val results = tuple2.f1.collect()
results should have size 41
forEvery(results)(item => item should have size 1)
}
}

View File

@ -0,0 +1,37 @@
/*
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.example.flink
import org.apache.flink.api.java.ExecutionEnvironment
import org.scalatest.Inspectors._
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers._
import java.nio.file.Paths
import scala.jdk.CollectionConverters._
class DistTrainWithFlinkSuite extends AnyFunSuite {
private val parentPath = Paths.get("../../").resolve("demo").resolve("data")
private val data = parentPath.resolve("veterans_lung_cancer.csv")
test("Smoke test for scala flink example") {
implicit val env: ExecutionEnvironment = ExecutionEnvironment.createLocalEnvironment(1)
val (_, result) = DistTrainWithFlink.runPrediction(data, 70)
val results = result.collect().asScala
results should have size 41
forEvery(results)(item => item should have size 1)
}
}

View File

@ -8,8 +8,11 @@
<artifactId>xgboost-jvm_2.12</artifactId>
<version>2.0.0-SNAPSHOT</version>
</parent>
<artifactId>xgboost4j-flink_2.12</artifactId>
<artifactId>xgboost4j-flink_${scala.binary.version}</artifactId>
<version>2.0.0-SNAPSHOT</version>
<properties>
<flink-ml.version>2.2.0</flink-ml.version>
</properties>
<build>
<plugins>
<plugin>
@ -26,32 +29,22 @@
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j_${scala.binary.version}</artifactId>
<version>2.0.0-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>3.12.0</version>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-scala_${scala.binary.version}</artifactId>
<artifactId>flink-clients</artifactId>
<version>${flink.version}</version>
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-clients_${scala.binary.version}</artifactId>
<version>${flink.version}</version>
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-ml_${scala.binary.version}</artifactId>
<version>${flink.version}</version>
<artifactId>flink-ml-servable-core</artifactId>
<version>${flink-ml.version}</version>
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-common</artifactId>
<version>3.3.5</version>
<version>${hadoop.version}</version>
</dependency>
</dependencies>

View File

@ -0,0 +1,187 @@
/*
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java.flink;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.util.Collector;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.Communicator;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.RabitTracker;
import ml.dmlc.xgboost4j.java.XGBoostError;
public class XGBoost {
private static final Logger logger = LoggerFactory.getLogger(XGBoost.class);
private static class MapFunction
extends RichMapPartitionFunction<Tuple2<Vector, Double>, XGBoostModel> {
private final Map<String, Object> params;
private final int round;
private final Map<String, String> workerEnvs;
public MapFunction(Map<String, Object> params, int round, Map<String, String> workerEnvs) {
this.params = params;
this.round = round;
this.workerEnvs = workerEnvs;
}
public void mapPartition(java.lang.Iterable<Tuple2<Vector, Double>> it,
Collector<XGBoostModel> collector) throws XGBoostError {
workerEnvs.put(
"DMLC_TASK_ID",
String.valueOf(this.getRuntimeContext().getIndexOfThisSubtask())
);
if (logger.isInfoEnabled()) {
logger.info("start with env: {}", workerEnvs.entrySet().stream()
.map(e -> String.format("\"%s\": \"%s\"", e.getKey(), e.getValue()))
.collect(Collectors.joining(", "))
);
}
final Iterator<LabeledPoint> dataIter =
StreamSupport
.stream(it.spliterator(), false)
.map(VectorToPointMapper.INSTANCE)
.iterator();
if (dataIter.hasNext()) {
final DMatrix trainMat = new DMatrix(dataIter, null);
int numEarlyStoppingRounds =
Optional.ofNullable(params.get("numEarlyStoppingRounds"))
.map(x -> Integer.parseInt(x.toString()))
.orElse(0);
final Booster booster = trainBooster(trainMat, numEarlyStoppingRounds);
collector.collect(new XGBoostModel(booster));
} else {
logger.warn("Nothing to train with.");
}
}
private Booster trainBooster(DMatrix trainMat,
int numEarlyStoppingRounds) throws XGBoostError {
Booster booster;
final Map<String, DMatrix> watches =
new HashMap<String, DMatrix>() {{ put("train", trainMat); }};
try {
Communicator.init(workerEnvs);
booster = ml.dmlc.xgboost4j.java.XGBoost
.train(
trainMat,
params,
round,
watches,
null,
null,
null,
numEarlyStoppingRounds);
} catch (XGBoostError xgbException) {
final String identifier = String.valueOf(this.getRuntimeContext().getIndexOfThisSubtask());
logger.warn(
String.format("XGBooster worker %s has failed due to", identifier),
xgbException
);
throw xgbException;
} finally {
Communicator.shutdown();
}
return booster;
}
private static class VectorToPointMapper
implements Function<Tuple2<Vector, Double>, LabeledPoint> {
public static VectorToPointMapper INSTANCE = new VectorToPointMapper();
@Override
public LabeledPoint apply(Tuple2<Vector, Double> tuple) {
final SparseVector vector = tuple.f0.toSparse();
final double[] values = vector.values;
final int size = values.length;
final float[] array = new float[size];
for (int i = 0; i < size; i++) {
array[i] = (float) values[i];
}
return new LabeledPoint(
tuple.f1.floatValue(),
vector.size(),
vector.indices,
array);
}
}
}
/**
* Load XGBoost model from path, using Hadoop Filesystem API.
*
* @param modelPath The path that is accessible by hadoop filesystem API.
* @return The loaded model
*/
public static XGBoostModel loadModelFromHadoopFile(final String modelPath) throws Exception {
final FileSystem fileSystem = FileSystem.get(new Configuration());
final Path f = new Path(modelPath);
try (FSDataInputStream opened = fileSystem.open(f)) {
return new XGBoostModel(ml.dmlc.xgboost4j.java.XGBoost.loadModel(opened));
}
}
/**
* Train an XGBoost model with Flink.
*
* @param dtrain The training data.
* @param params XGBoost parameters.
* @param numBoostRound Number of rounds to train.
*/
public static XGBoostModel train(DataSet<Tuple2<Vector, Double>> dtrain,
Map<String, Object> params,
int numBoostRound) throws Exception {
final RabitTracker tracker =
new RabitTracker(dtrain.getExecutionEnvironment().getParallelism());
if (tracker.start(0L)) {
return dtrain
.mapPartition(new MapFunction(params, numBoostRound, tracker.getWorkerEnvs()))
.reduce((x, y) -> x)
.collect()
.get(0);
} else {
throw new Error("Tracker cannot be started");
}
}
}

View File

@ -0,0 +1,136 @@
/*
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java.flink;
import java.io.IOException;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Iterator;
import java.util.stream.StreamSupport;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.util.Collector;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoostError;
public class XGBoostModel implements Serializable {
private static final org.slf4j.Logger logger =
org.slf4j.LoggerFactory.getLogger(XGBoostModel.class);
private final Booster booster;
private final PredictorFunction predictorFunction;
public XGBoostModel(Booster booster) {
this.booster = booster;
this.predictorFunction = new PredictorFunction(booster);
}
/**
* Save the model as a Hadoop filesystem file.
*
* @param modelPath The model path as in Hadoop path.
*/
public void saveModelAsHadoopFile(String modelPath) throws IOException, XGBoostError {
booster.saveModel(FileSystem.get(new Configuration()).create(new Path(modelPath)));
}
public byte[] toByteArray(String format) throws XGBoostError {
return booster.toByteArray(format);
}
/**
* Save the model as a Hadoop filesystem file.
*
* @param modelPath The model path as in Hadoop path.
* @param format The model format (ubj, json, deprecated)
* @throws XGBoostError internal error
* @throws IOException save error
*/
public void saveModelAsHadoopFile(String modelPath, String format)
throws IOException, XGBoostError {
booster.saveModel(FileSystem.get(new Configuration()).create(new Path(modelPath)), format);
}
/**
* predict with the given DMatrix
*
* @param testSet the local test set represented as DMatrix
* @return prediction result
*/
public float[][] predict(DMatrix testSet) throws XGBoostError {
return booster.predict(testSet, true, 0);
}
/**
* Predict given vector dataset.
*
* @param data The dataset to be predicted.
* @return The prediction result.
*/
public DataSet<Float[]> predict(DataSet<Vector> data) {
return data.mapPartition(predictorFunction);
}
private static class PredictorFunction implements MapPartitionFunction<Vector, Float[]> {
private final Booster booster;
public PredictorFunction(Booster booster) {
this.booster = booster;
}
@Override
public void mapPartition(Iterable<Vector> it, Collector<Float[]> out) throws Exception {
final Iterator<LabeledPoint> dataIter =
StreamSupport.stream(it.spliterator(), false)
.map(Vector::toSparse)
.map(PredictorFunction::fromVector)
.iterator();
if (dataIter.hasNext()) {
final DMatrix data = new DMatrix(dataIter, null);
float[][] predictions = booster.predict(data, true, 2);
Arrays.stream(predictions).map(ArrayUtils::toObject).forEach(out::collect);
} else {
logger.debug("Empty partition");
}
}
private static LabeledPoint fromVector(SparseVector vector) {
final int[] index = vector.indices;
final double[] value = vector.values;
int size = value.length;
final float[] values = new float[size];
for (int i = 0; i < size; i++) {
values[i] = (float) value[i];
}
return new LabeledPoint(0.0f, vector.size(), index, values);
}
}
}

View File

@ -1,99 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.flink
import scala.collection.JavaConverters.asScalaIteratorConverter
import ml.dmlc.xgboost4j.LabeledPoint
import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker}
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => XGBoostScala}
import org.apache.commons.logging.LogFactory
import org.apache.flink.api.common.functions.RichMapPartitionFunction
import org.apache.flink.api.scala.{DataSet, _}
import org.apache.flink.ml.common.LabeledVector
import org.apache.flink.util.Collector
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
object XGBoost {
/**
* Helper map function to start the job.
*
* @param workerEnvs
*/
private class MapFunction(paramMap: Map[String, Any],
round: Int,
workerEnvs: java.util.Map[String, String])
extends RichMapPartitionFunction[LabeledVector, XGBoostModel] {
val logger = LogFactory.getLog(this.getClass)
def mapPartition(it: java.lang.Iterable[LabeledVector],
collector: Collector[XGBoostModel]): Unit = {
workerEnvs.put("DMLC_TASK_ID", String.valueOf(this.getRuntimeContext.getIndexOfThisSubtask))
logger.info("start with env" + workerEnvs.toString)
Communicator.init(workerEnvs)
val mapper = (x: LabeledVector) => {
val (index, value) = x.vector.toSeq.unzip
LabeledPoint(x.label.toFloat, x.vector.size, index.toArray, value.map(_.toFloat).toArray)
}
val dataIter = for (x <- it.iterator().asScala) yield mapper(x)
val trainMat = new DMatrix(dataIter, null)
val watches = List("train" -> trainMat).toMap
val round = 2
val numEarlyStoppingRounds = paramMap.get("numEarlyStoppingRounds")
.map(_.toString.toInt).getOrElse(0)
val booster = XGBoostScala.train(trainMat, paramMap, round, watches,
earlyStoppingRound = numEarlyStoppingRounds)
Communicator.shutdown()
collector.collect(new XGBoostModel(booster))
}
}
val logger = LogFactory.getLog(this.getClass)
/**
* Load XGBoost model from path, using Hadoop Filesystem API.
*
* @param modelPath The path that is accessible by hadoop filesystem API.
* @return The loaded model
*/
def loadModelFromHadoopFile(modelPath: String) : XGBoostModel = {
new XGBoostModel(
XGBoostScala.loadModel(FileSystem.get(new Configuration).open(new Path(modelPath))))
}
/**
* Train an XGBoost model with Flink.
*
* @param dtrain The training data.
* @param params The parameters to XGBoost.
* @param round Number of rounds to train.
*/
def train(dtrain: DataSet[LabeledVector], params: Map[String, Any], round: Int):
XGBoostModel = {
val tracker = new RabitTracker(dtrain.getExecutionEnvironment.getParallelism)
if (tracker.start(0L)) {
dtrain
.mapPartition(new MapFunction(params, round, tracker.getWorkerEnvs))
.reduce((x, y) => x).collect().head
} else {
throw new Error("Tracker cannot be started")
null
}
}
}

View File

@ -1,67 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.flink
import ml.dmlc.xgboost4j.LabeledPoint
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
import org.apache.flink.api.scala.{DataSet, _}
import org.apache.flink.ml.math.Vector
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
class XGBoostModel (booster: Booster) extends Serializable {
/**
* Save the model as a Hadoop filesystem file.
*
* @param modelPath The model path as in Hadoop path.
*/
def saveModelAsHadoopFile(modelPath: String): Unit = {
booster.saveModel(FileSystem
.get(new Configuration)
.create(new Path(modelPath)))
}
/**
* predict with the given DMatrix
* @param testSet the local test set represented as DMatrix
* @return prediction result
*/
def predict(testSet: DMatrix): Array[Array[Float]] = {
booster.predict(testSet, true, 0)
}
/**
* Predict given vector dataset.
*
* @param data The dataset to be predicted.
* @return The prediction result.
*/
def predict(data: DataSet[Vector]) : DataSet[Array[Float]] = {
val predictMap: Iterator[Vector] => Traversable[Array[Float]] =
(it: Iterator[Vector]) => {
val mapper = (x: Vector) => {
val (index, value) = x.toSeq.unzip
LabeledPoint(0.0f, x.size, index.toArray, value.map(_.toFloat).toArray)
}
val dataIter = for (x <- it) yield mapper(x)
val dmat = new DMatrix(dataIter, null)
this.booster.predict(dmat)
}
data.mapPartition(predictMap)
}
}

View File

@ -38,22 +38,10 @@
<version>4.13.2</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.typesafe.akka</groupId>
<artifactId>akka-actor_${scala.binary.version}</artifactId>
<version>2.6.20</version>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>com.typesafe.akka</groupId>
<artifactId>akka-testkit_${scala.binary.version}</artifactId>
<version>2.6.20</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.scalatest</groupId>
<artifactId>scalatest_${scala.binary.version}</artifactId>
<version>3.0.5</version>
<version>3.2.15</version>
<scope>provided</scope>
</dependency>
<dependency>

View File

@ -19,10 +19,10 @@ package ml.dmlc.xgboost4j.scala
import scala.collection.mutable.ArrayBuffer
import ai.rapids.cudf.Table
import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite
import ml.dmlc.xgboost4j.gpu.java.CudfColumnBatch
class QuantileDMatrixSuite extends FunSuite {
class QuantileDMatrixSuite extends AnyFunSuite {
test("QuantileDMatrix test") {

View File

@ -44,13 +44,6 @@
<version>${spark.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>ai.rapids</groupId>
<artifactId>cudf</artifactId>
<version>${cudf.version}</version>
<classifier>${cudf.classifier}</classifier>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>com.nvidia</groupId>
<artifactId>rapids-4-spark_${scala.binary.version}</artifactId>

View File

@ -20,14 +20,15 @@ import java.nio.file.{Files, Path}
import java.sql.{Date, Timestamp}
import java.util.{Locale, TimeZone}
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.scalatest.BeforeAndAfterAll
import org.scalatest.funsuite.AnyFunSuite
import org.apache.spark.{GpuTestUtils, SparkConf}
import org.apache.spark.internal.Logging
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.sql.{Row, SparkSession}
trait GpuTestSuite extends FunSuite with TmpFolderSuite {
trait GpuTestSuite extends AnyFunSuite with TmpFolderSuite {
import SparkSessionHolder.withSparkSession
protected def getResourcePath(resource: String): String = {
@ -200,7 +201,7 @@ trait GpuTestSuite extends FunSuite with TmpFolderSuite {
}
trait TmpFolderSuite extends BeforeAndAfterAll { self: FunSuite =>
trait TmpFolderSuite extends BeforeAndAfterAll { self: AnyFunSuite =>
protected var tempDir: Path = _
override def beforeAll(): Unit = {

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2021-2022 by Contributors
Copyright (c) 2021-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -22,7 +22,6 @@ import java.util.ServiceLoader
import scala.collection.JavaConverters._
import scala.collection.{AbstractIterator, Iterator, mutable}
import ml.dmlc.xgboost4j.java.Communicator
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.PackedParams
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
@ -35,7 +34,6 @@ import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.commons.logging.LogFactory
import org.apache.spark.TaskContext
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}
@ -263,12 +261,6 @@ object PreXGBoost extends PreXGBoostProvider {
private var batchCnt = 0
private val batchIterImpl = rowIterator.grouped(inferBatchSize).flatMap { batchRow =>
if (batchCnt == 0) {
val rabitEnv = Array(
"DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
Communicator.init(rabitEnv.asJava)
}
val features = batchRow.iterator.map(row => row.getAs[Vector](featuresCol))
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
@ -295,13 +287,8 @@ object PreXGBoost extends PreXGBoostProvider {
override def hasNext: Boolean = batchIterImpl.hasNext
override def next(): Row = {
val ret = batchIterImpl.next()
if (!batchIterImpl.hasNext) {
Communicator.shutdown()
}
ret
}
override def next(): Row = batchIterImpl.next()
}
}

View File

@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -23,7 +23,6 @@ import scala.util.Random
import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.java.{Communicator, IRabitTracker, XGBoostError, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
@ -44,21 +43,16 @@ import org.apache.spark.sql.SparkSession
* Use a finite, non-zero timeout value to prevent tracker from
* hanging indefinitely (in milliseconds)
* (supported by "scala" implementation only.)
* @param trackerImpl Choice between "python" or "scala". The former utilizes the Java wrapper of
* the Python Rabit tracker (in dmlc_core), whereas the latter is implemented
* in Scala without Python components, and with full support of timeouts.
* The Scala implementation is currently experimental, use at your own risk.
*
* @param hostIp The Rabit Tracker host IP address, which is only used by the Python implementation.
* This is only needed if the host IP cannot be automatically guessed.
* @param pythonExec The Python executable path for the Rabit Tracker,
* which is only used by the Python implementation.
*/
case class TrackerConf(workerConnectionTimeout: Long, trackerImpl: String,
case class TrackerConf(workerConnectionTimeout: Long,
hostIp: String = "", pythonExec: String = "")
object TrackerConf {
def apply(): TrackerConf = TrackerConf(0L, "python")
def apply(): TrackerConf = TrackerConf(0L)
}
private[scala] case class XGBoostExecutionEarlyStoppingParams(numEarlyStoppingRounds: Int,
@@ -349,11 +343,9 @@ object XGBoost extends Serializable {
/** visible for testing */
private[scala] def getTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = {
val tracker: IRabitTracker = trackerConf.trackerImpl match {
case "scala" => new RabitTracker(nWorkers)
case "python" => new PyRabitTracker(nWorkers, trackerConf.hostIp, trackerConf.pythonExec)
case _ => new PyRabitTracker(nWorkers)
}
val tracker: IRabitTracker = new PyRabitTracker(
nWorkers, trackerConf.hostIp, trackerConf.pythonExec
)
tracker
}
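Because TrackerConf has lost its trackerImpl field, callers that previously chose between the "python" and "scala" trackers now pass only the connection timeout, host IP, and Python executable, and getTracker always builds the Python-based tracker. A minimal sketch of the updated configuration (all values are hypothetical):

import ml.dmlc.xgboost4j.scala.spark.TrackerConf

object TrackerConfSketch {
  // Before this change: "tracker_conf" -> TrackerConf(0L, "python", hostIp)
  // After it, the implementation selector is gone.
  val trackerConf: TrackerConf = TrackerConf(
    workerConnectionTimeout = 60L * 60 * 1000,  // one hour, in milliseconds
    hostIp = "10.0.0.1",                        // hypothetical tracker host
    pythonExec = "/usr/bin/python3")            // hypothetical Python executable

  val paramMap: Map[String, Any] = Map(
    "num_workers" -> 4,
    "tracker_conf" -> trackerConf)
}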

View File

@@ -22,11 +22,10 @@ import scala.util.Random
import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus
import ml.dmlc.xgboost4j.scala.rabit.{RabitTracker => ScalaRabitTracker}
import ml.dmlc.xgboost4j.scala.DMatrix
import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite
class CommunicatorRobustnessSuite extends FunSuite with PerTest {
class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
private def getXGBoostExecutionParams(paramMap: Map[String, Any]): XGBoostExecutionParams = {
val classifier = new XGBoostClassifier(paramMap)
@@ -40,7 +39,7 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
val paramMap = Map(
"num_workers" -> numWorkers,
"tracker_conf" -> TrackerConf(0L, "python", hostIp))
"tracker_conf" -> TrackerConf(0L, hostIp))
val xgbExecParams = getXGBoostExecutionParams(paramMap)
val tracker = XGBoost.getTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
tracker match {
@@ -53,7 +52,7 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
val paramMap1 = Map(
"num_workers" -> numWorkers,
"tracker_conf" -> TrackerConf(0L, "python", "", pythonExec))
"tracker_conf" -> TrackerConf(0L, "", pythonExec))
val xgbExecParams1 = getXGBoostExecutionParams(paramMap1)
val tracker1 = XGBoost.getTracker(xgbExecParams1.numWorkers, xgbExecParams1.trackerConf)
tracker1 match {
@@ -66,7 +65,7 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
val paramMap2 = Map(
"num_workers" -> numWorkers,
"tracker_conf" -> TrackerConf(0L, "python", hostIp, pythonExec))
"tracker_conf" -> TrackerConf(0L, hostIp, pythonExec))
val xgbExecParams2 = getXGBoostExecutionParams(paramMap2)
val tracker2 = XGBoost.getTracker(xgbExecParams2.numWorkers, xgbExecParams2.trackerConf)
tracker2 match {
@@ -78,58 +77,6 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
}
}
test("training with Scala-implemented Rabit tracker") {
val eval = new EvalError()
val training = buildDataFrame(Classification.train)
val testDM = new DMatrix(Classification.test.iterator)
val paramMap = Map("eta" -> "1", "max_depth" -> "6",
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
"tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala"))
val model = new XGBoostClassifier(paramMap).fit(training)
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
}
test("test Communicator allreduce to validate Scala-implemented Rabit tracker") {
val vectorLength = 100
val rdd = sc.parallelize(
(1 to numWorkers * vectorLength).toArray.map { _ => Random.nextFloat() }, numWorkers).cache()
val tracker = new ScalaRabitTracker(numWorkers)
tracker.start(0)
val trackerEnvs = tracker.getWorkerEnvs
val collectedAllReduceResults = new LinkedBlockingDeque[Array[Float]]()
val rawData = rdd.mapPartitions { iter =>
Iterator(iter.toArray)
}.collect()
val maxVec = (0 until vectorLength).toArray.map { j =>
(0 until numWorkers).toArray.map { i => rawData(i)(j) }.max
}
val allReduceResults = rdd.mapPartitions { iter =>
Communicator.init(trackerEnvs)
val arr = iter.toArray
val results = Communicator.allReduce(arr, Communicator.OpType.MAX)
Communicator.shutdown()
Iterator(results)
}.cache()
val sparkThread = new Thread() {
override def run(): Unit = {
allReduceResults.foreachPartition(() => _)
val byPartitionResults = allReduceResults.collect()
assert(byPartitionResults(0).length == vectorLength)
collectedAllReduceResults.put(byPartitionResults(0))
}
}
sparkThread.start()
assert(tracker.waitFor(0L) == 0)
sparkThread.join()
assert(collectedAllReduceResults.poll().sameElements(maxVec))
}
test("test Java RabitTracker wrapper's exception handling: it should not hang forever.") {
/*
Deliberately create new instances of SparkContext in each unit test to avoid reusing the
@@ -193,68 +140,6 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
assert(tracker.waitFor(0) != 0)
}
test("test Scala RabitTracker's exception handling: it should not hang forever.") {
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
val tracker = new ScalaRabitTracker(numWorkers)
tracker.start(0)
val trackerEnvs = tracker.getWorkerEnvs
val workerCount: Int = numWorkers
val dummyTasks = rdd.mapPartitions { iter =>
Communicator.init(trackerEnvs)
val index = iter.next()
Thread.sleep(100 + index * 10)
if (index == workerCount) {
// kill the worker by throwing an exception
throw new RuntimeException("Worker exception.")
}
Communicator.shutdown()
Iterator(index)
}.cache()
val sparkThread = new Thread() {
override def run(): Unit = {
// forces a Spark job.
dummyTasks.foreachPartition(() => _)
}
}
sparkThread.setUncaughtExceptionHandler(tracker)
sparkThread.start()
assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
}
test("test Scala RabitTracker's workerConnectionTimeout") {
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
val tracker = new ScalaRabitTracker(numWorkers)
tracker.start(500)
val trackerEnvs = tracker.getWorkerEnvs
val dummyTasks = rdd.mapPartitions { iter =>
val index = iter.next()
// simulate that the first worker cannot connect to tracker due to network issues.
if (index != 1) {
Communicator.init(trackerEnvs)
Thread.sleep(1000)
Communicator.shutdown()
}
Iterator(index)
}.cache()
val sparkThread = new Thread() {
override def run(): Unit = {
// forces a Spark job.
dummyTasks.foreachPartition(() => _)
}
}
sparkThread.setUncaughtExceptionHandler(tracker)
sparkThread.start()
// should fail due to connection timeout
assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
}
test("should allow the dataframe containing communicator calls to be partially evaluated for" +
" multiple times (ISSUE-4406)") {
val paramMap = Map(

View File

@@ -17,13 +17,13 @@
package ml.dmlc.xgboost4j.scala.spark
import org.apache.spark.ml.linalg.Vectors
import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.PackedParams
import org.apache.spark.sql.functions._
class DeterministicPartitioningSuite extends FunSuite with TmpFolderPerSuite with PerTest {
class DeterministicPartitioningSuite extends AnyFunSuite with TmpFolderPerSuite with PerTest {
test("perform deterministic partitioning when checkpointInternal and" +
" checkpointPath is set (Classifier)") {

View File

@@ -19,10 +19,10 @@ package ml.dmlc.xgboost4j.scala.spark
import java.io.File
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, ExternalCheckpointManager, XGBoost => SXGBoost}
import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite
import org.apache.hadoop.fs.{FileSystem, Path}
class ExternalCheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTest {
class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite with PerTest {
private def produceParamMap(checkpointPath: String, checkpointInterval: Int):
Map[String, Any] = {

View File

@@ -18,12 +18,12 @@ package ml.dmlc.xgboost4j.scala.spark
import org.apache.spark.Partitioner
import org.apache.spark.ml.feature.VectorAssembler
import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite
import org.apache.spark.sql.functions._
import scala.util.Random
class FeatureSizeValidatingSuite extends FunSuite with PerTest {
class FeatureSizeValidatingSuite extends AnyFunSuite with PerTest {
test("transform throwing exception if feature size of dataset is greater than model's") {
val modelPath = getClass.getResource("/model/0.82/model").getPath

View File

@@ -19,12 +19,12 @@ package ml.dmlc.xgboost4j.scala.spark
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.DataFrame
import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite
import scala.util.Random
import org.apache.spark.SparkException
class MissingValueHandlingSuite extends FunSuite with PerTest {
class MissingValueHandlingSuite extends AnyFunSuite with PerTest {
test("dense vectors containing missing value") {
def buildDenseDataFrame(): DataFrame = {
val numRows = 100

View File

@@ -16,12 +16,13 @@
package ml.dmlc.xgboost4j.scala.spark
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.scalatest.BeforeAndAfterAll
import org.scalatest.funsuite.AnyFunSuite
import org.apache.spark.SparkException
import org.apache.spark.ml.param.ParamMap
class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll {
class ParameterSuite extends AnyFunSuite with PerTest with BeforeAndAfterAll {
test("XGBoost and Spark parameters synchronize correctly") {
val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic",

View File

@@ -22,13 +22,14 @@ import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.spark.SparkContext
import org.apache.spark.sql._
import org.scalatest.{BeforeAndAfterEach, FunSuite}
import org.scalatest.BeforeAndAfterEach
import org.scalatest.funsuite.AnyFunSuite
import scala.math.min
import scala.util.Random
import org.apache.commons.io.IOUtils
trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
trait PerTest extends BeforeAndAfterEach { self: AnyFunSuite =>
protected val numWorkers: Int = min(Runtime.getRuntime.availableProcessors(), 4)

View File

@@ -25,9 +25,9 @@ import scala.util.Random
import org.apache.spark.ml.feature._
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql.functions._
import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite
class PersistenceSuite extends FunSuite with TmpFolderPerSuite with PerTest {
class PersistenceSuite extends AnyFunSuite with TmpFolderPerSuite with PerTest {
test("test persistence of XGBoostClassifier and XGBoostClassificationModel") {
val eval = new EvalError()

View File

@@ -19,9 +19,10 @@ package ml.dmlc.xgboost4j.scala.spark
import java.nio.file.{Files, Path}
import org.apache.spark.network.util.JavaUtils
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.scalatest.BeforeAndAfterAll
import org.scalatest.funsuite.AnyFunSuite
trait TmpFolderPerSuite extends BeforeAndAfterAll { self: FunSuite =>
trait TmpFolderPerSuite extends BeforeAndAfterAll { self: AnyFunSuite =>
protected var tempDir: Path = _
override def beforeAll(): Unit = {

View File

@@ -22,13 +22,13 @@ import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
import org.apache.spark.ml.linalg._
import org.apache.spark.sql._
import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite
import org.apache.commons.io.IOUtils
import org.apache.spark.Partitioner
import org.apache.spark.ml.feature.VectorAssembler
class XGBoostClassifierSuite extends FunSuite with PerTest with TmpFolderPerSuite {
class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite {
protected val treeMethod: String = "auto"

View File

@@ -21,11 +21,11 @@ import ml.dmlc.xgboost4j.scala.Booster
import scala.collection.JavaConverters._
import org.apache.spark.sql._
import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite
import org.apache.spark.SparkException
class XGBoostCommunicatorRegressionSuite extends FunSuite with PerTest {
class XGBoostCommunicatorRegressionSuite extends AnyFunSuite with PerTest {
val predictionErrorMin = 0.00001f
val maxFailure = 2;

View File

@@ -19,9 +19,9 @@ package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
import org.apache.spark.sql._
import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite
class XGBoostConfigureSuite extends FunSuite with PerTest {
class XGBoostConfigureSuite extends AnyFunSuite with PerTest {
override def sparkSessionBuilder: SparkSession.Builder = super.sparkSessionBuilder
.config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")

View File

@@ -22,12 +22,12 @@ import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import ml.dmlc.xgboost4j.scala.DMatrix
import org.apache.spark.{SparkException, TaskContext}
import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.functions.lit
class XGBoostGeneralSuite extends FunSuite with TmpFolderPerSuite with PerTest {
class XGBoostGeneralSuite extends AnyFunSuite with TmpFolderPerSuite with PerTest {
test("distributed training with the specified worker number") {
val trainingRDD = sc.parallelize(Classification.train)

View File

@@ -23,11 +23,11 @@ import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{DataFrame, Row}
import org.scalatest.FunSuite
import org.scalatest.funsuite.AnyFunSuite
import org.apache.spark.ml.feature.VectorAssembler
class XGBoostRegressorSuite extends FunSuite with PerTest with TmpFolderPerSuite {
class XGBoostRegressorSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite {
protected val treeMethod: String = "auto"
test("XGBoost-Spark XGBoostRegressor output should match XGBoost4j") {

View File

@@ -69,7 +69,7 @@ pom_template = """
<dependency>
<groupId>org.scalactic</groupId>
<artifactId>scalactic_${{scala.binary.version}}</artifactId>
<version>3.0.8</version>
<version>3.2.15</version>
<scope>test</scope>
</dependency>
<dependency>

View File

@@ -31,22 +31,10 @@
<version>4.13.2</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.typesafe.akka</groupId>
<artifactId>akka-actor_${scala.binary.version}</artifactId>
<version>2.6.20</version>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>com.typesafe.akka</groupId>
<artifactId>akka-testkit_${scala.binary.version}</artifactId>
<version>2.6.20</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.scalatest</groupId>
<artifactId>scalatest_${scala.binary.version}</artifactId>
<version>3.0.5</version>
<version>3.2.15</version>
<scope>provided</scope>
</dependency>
</dependencies>

View File

@@ -1,195 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.rabit
import java.net.{InetAddress, InetSocketAddress}
import akka.actor.ActorSystem
import akka.pattern.ask
import ml.dmlc.xgboost4j.java.{IRabitTracker, TrackerProperties}
import ml.dmlc.xgboost4j.scala.rabit.handler.RabitTrackerHandler
import scala.concurrent.duration._
import scala.concurrent.{Await, Future}
import scala.util.{Failure, Success, Try}
/**
* Scala implementation of the Rabit tracker interface without Python dependency.
* The Scala Rabit tracker fully implements the timeout logic, effectively preventing the tracker
* (and thus any distributed tasks) from hanging indefinitely due to network issues or worker node
* failures.
*
* Note that this implementation is currently experimental, and should be used at your own risk.
*
* Example usage:
* {{{
* import scala.concurrent.duration._
*
* val tracker = new RabitTracker(32)
* // allow up to 10 minutes for all workers to connect to the tracker.
* tracker.start(10 minutes)
*
* /* ...
* launching workers in parallel
* ...
* */
*
* // wait for worker execution up to 6 hours.
* // providing a finite timeout prevents a long-running task from hanging forever in
* // catastrophic events, like the loss of an executor during model training.
* tracker.waitFor(6 hours)
* }}}
*
* @param numWorkers Number of distributed workers from which the tracker expects connections.
* @param port The minimum port number that the tracker binds to.
* If port is omitted, or given as None, a random ephemeral port is chosen at runtime.
* @param maxPortTrials The maximum number of trials of socket binding, by sequentially
* increasing the port number.
*/
private[scala] class RabitTracker(numWorkers: Int, port: Option[Int] = None,
maxPortTrials: Int = 1000)
extends IRabitTracker {
import scala.collection.JavaConverters._
require(numWorkers >=1, "numWorkers must be greater than or equal to one (1).")
val system = ActorSystem.create("RabitTracker")
val handler = system.actorOf(RabitTrackerHandler.props(numWorkers), "Handler")
implicit val askTimeout: akka.util.Timeout = akka.util.Timeout(30 seconds)
private[this] val tcpBindingTimeout: Duration = 1 minute
var workerEnvs: Map[String, String] = Map.empty
override def uncaughtException(t: Thread, e: Throwable): Unit = {
handler ? RabitTrackerHandler.InterruptTracker(e)
}
/**
* Start the Rabit tracker.
*
* @param timeout The timeout for awaiting connections from worker nodes.
* Note that when used in Spark applications, because all Spark transformations are
* lazily executed, the I/O time for loading RDDs/DataFrames from external sources
* (local disk, HDFS, S3 etc.) must be taken into account for the timeout value.
* If the timeout value is too small, the Rabit tracker will likely time out before workers
* establish connections to the tracker, due to the overhead of loading data.
* Using a finite timeout is encouraged, as it prevents the tracker (thus the Spark driver
* running it) from hanging indefinitely due to worker connection issues (e.g. firewall.)
* @return Boolean flag indicating if the Rabit tracker starts successfully.
*/
private def start(timeout: Duration): Boolean = {
val hostAddress = Option(TrackerProperties.getInstance().getHostIp)
.map(InetAddress.getByName).getOrElse(InetAddress.getLocalHost)
handler ? RabitTrackerHandler.StartTracker(
new InetSocketAddress(hostAddress, port.getOrElse(0)), maxPortTrials, timeout)
// block by waiting for the actor to bind to a port
Try(Await.result(handler ? RabitTrackerHandler.RequestBoundFuture, askTimeout.duration)
.asInstanceOf[Future[Map[String, String]]]) match {
case Success(futurePortBound) =>
// The success of the Future is contingent on binding to an InetSocketAddress.
val isBound = Try(Await.ready(futurePortBound, tcpBindingTimeout)).isSuccess
if (isBound) {
workerEnvs = Await.result(futurePortBound, 0 nano)
}
isBound
case Failure(ex: Throwable) =>
false
}
}
/**
* Start the Rabit tracker.
*
* @param connectionTimeoutMillis Timeout, in milliseconds, for the tracker to wait for worker
* connections. If a non-positive value is provided, the tracker
* waits for incoming worker connections indefinitely.
* @return Boolean flag indicating if the Rabit tracker starts successfully.
*/
def start(connectionTimeoutMillis: Long): Boolean = {
if (connectionTimeoutMillis <= 0) {
start(Duration.Inf)
} else {
start(Duration.fromNanos(connectionTimeoutMillis * 1e6))
}
}
def stop(): Unit = {
system.terminate()
}
/**
* Get a Map of necessary environment variables to initiate Rabit workers.
*
* @return HashMap containing tracker information.
*/
def getWorkerEnvs: java.util.Map[String, String] = {
new java.util.HashMap((workerEnvs ++ Map(
"DMLC_NUM_WORKER" -> numWorkers.toString,
"DMLC_NUM_SERVER" -> "0"
)).asJava)
}
/**
* Await workers to complete assigned tasks for at most 'atMostMillis' milliseconds.
* This method blocks until timeout or task completion.
*
* @param atMost the maximum execution time for the workers. By default,
* the tracker waits for the workers indefinitely.
* @return 0 if the tasks complete successfully, and non-zero otherwise.
*/
private def waitFor(atMost: Duration): Int = {
// request the completion Future from the tracker actor
Try(Await.result(handler ? RabitTrackerHandler.RequestCompletionFuture, askTimeout.duration)
.asInstanceOf[Future[Int]]) match {
case Success(futureCompleted) =>
// wait for all workers to complete synchronously.
val statusCode = Try(Await.result(futureCompleted, atMost)) match {
case Success(n) if n == numWorkers =>
IRabitTracker.TrackerStatus.SUCCESS.getStatusCode
case Success(n) if n < numWorkers =>
IRabitTracker.TrackerStatus.TIMEOUT.getStatusCode
case Failure(e) =>
IRabitTracker.TrackerStatus.FAILURE.getStatusCode
}
system.terminate()
statusCode
case Failure(ex: Throwable) =>
system.terminate()
IRabitTracker.TrackerStatus.FAILURE.getStatusCode
}
}
/**
* Await workers to complete assigned tasks for at most 'atMostMillis' milliseconds.
* This method blocks until timeout or task completion.
*
* @param atMostMillis Number of milliseconds for the tracker to wait for workers. If a
* non-positive number is given, the tracker waits indefinitely.
* @return 0 if the tasks complete successfully, and non-zero otherwise
*/
def waitFor(atMostMillis: Long): Int = {
if (atMostMillis <= 0) {
waitFor(Duration.Inf)
} else {
waitFor(Duration.fromNanos(atMostMillis * 1e6))
}
}
}
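This Akka-based Scala tracker is removed outright, so the Java wrapper of the Python tracker is the only implementation left. A minimal sketch of driving the remaining tracker directly, restricted to calls that appear elsewhere in this diff (the worker count and timeouts are hypothetical):

import ml.dmlc.xgboost4j.java.{RabitTracker => PyRabitTracker}

object TrackerSketch {
  def main(args: Array[String]): Unit = {
    // Start the Python-based tracker for two workers, hand its environment to the workers
    // (e.g. via Communicator.init on each executor), then block until they report completion.
    val tracker = new PyRabitTracker(2)
    tracker.start(0)                         // the suites in this diff pass 0 for the connection timeout
    val workerEnvs = tracker.getWorkerEnvs   // DMLC_TRACKER_URI, DMLC_TRACKER_PORT, ...
    val exitCode = tracker.waitFor(0L)       // 0 on success, non-zero otherwise
    println(s"tracker exited with $exitCode, envs: $workerEnvs")
  }
}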

View File

@@ -1,361 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.rabit.handler
import java.net.InetSocketAddress
import java.util.UUID
import scala.concurrent.duration._
import scala.collection.mutable
import scala.concurrent.{Promise, TimeoutException}
import akka.io.{IO, Tcp}
import akka.actor._
import ml.dmlc.xgboost4j.java.XGBoostError
import ml.dmlc.xgboost4j.scala.rabit.util.{AssignedRank, LinkMap}
import scala.util.{Failure, Random, Success, Try}
/** The Akka actor for handling and coordinating Rabit worker connections.
* This is the main actor for handling socket connections, interacting with the synchronous
* tracker interface, and resolving tree/ring/parent dependencies between workers.
*
* @param numWorkers Number of workers to track.
*/
private[scala] class RabitTrackerHandler(numWorkers: Int)
extends Actor with ActorLogging {
import context.system
import RabitWorkerHandler._
import RabitTrackerHandler._
private[this] val promisedWorkerEnvs = Promise[Map[String, String]]()
private[this] val promisedShutdownWorkers = Promise[Int]()
private[this] val tcpManager = IO(Tcp)
// resolves worker connection dependency.
val resolver = context.actorOf(Props(classOf[WorkerDependencyResolver], self), "Resolver")
// workers that have sent "shutdown" signal
private[this] val shutdownWorkers = mutable.Set.empty[Int]
private[this] val jobToRankMap = mutable.HashMap.empty[String, Int]
private[this] val actorRefToHost = mutable.HashMap.empty[ActorRef, String]
private[this] val ranksToAssign = mutable.ListBuffer(0 until numWorkers: _*)
private[this] var maxPortTrials = 0
private[this] var workerConnectionTimeout: Duration = Duration.Inf
private[this] var portTrials = 0
private[this] val startedWorkers = mutable.Set.empty[Int]
val linkMap = new LinkMap(numWorkers)
def decideRank(rank: Int, jobId: String = "NULL"): Option[Int] = {
rank match {
case r if r >= 0 => Some(r)
case _ =>
jobId match {
case "NULL" => None
case jid => jobToRankMap.get(jid)
}
}
}
/**
* Handler for all Akka Tcp connection/binding events. Read/write over the socket is handled
* by the RabitWorkerHandler.
*
* @param event Generic Tcp.Event
*/
private def handleTcpEvents(event: Tcp.Event): Unit = event match {
case Tcp.Bound(local) =>
// expect all workers to connect within timeout
log.info(s"Tracker listening @ ${local.getAddress.getHostAddress}:${local.getPort}")
log.info(s"Worker connection timeout is $workerConnectionTimeout.")
context.setReceiveTimeout(workerConnectionTimeout)
promisedWorkerEnvs.success(Map(
"DMLC_TRACKER_URI" -> local.getAddress.getHostAddress,
"DMLC_TRACKER_PORT" -> local.getPort.toString,
// not required because the world size will be communicated to the
// worker node after the rank is assigned.
"rabit_world_size" -> numWorkers.toString
))
case Tcp.CommandFailed(cmd: Tcp.Bind) =>
if (portTrials < maxPortTrials) {
portTrials += 1
tcpManager ! Tcp.Bind(self,
new InetSocketAddress(cmd.localAddress.getAddress, cmd.localAddress.getPort + 1),
backlog = 256)
}
case Tcp.Connected(remote, local) =>
log.debug(s"Incoming connection from worker @ ${remote.getAddress.getHostAddress}")
// revoke timeout if all workers have connected.
val workerHandler = context.actorOf(RabitWorkerHandler.props(
remote.getAddress.getHostAddress, numWorkers, self, sender()
), s"ConnectionHandler-${UUID.randomUUID().toString}")
val connection = sender()
connection ! Tcp.Register(workerHandler, keepOpenOnPeerClosed = true)
actorRefToHost.put(workerHandler, remote.getAddress.getHostName)
}
/**
* Handles external tracker control messages sent by RabitTracker (usually in ask patterns)
* to interact with the tracker interface.
*
* @param trackerMsg control messages sent by RabitTracker class.
*/
private def handleTrackerControlMessage(trackerMsg: TrackerControlMessage): Unit =
trackerMsg match {
case msg: StartTracker =>
maxPortTrials = msg.maxPortTrials
workerConnectionTimeout = msg.connectionTimeout
// if the port number is missing, try binding to a random ephemeral port.
if (msg.addr.getPort == 0) {
tcpManager ! Tcp.Bind(self,
new InetSocketAddress(msg.addr.getAddress, new Random().nextInt(61000 - 32768) + 32768),
backlog = 256)
} else {
tcpManager ! Tcp.Bind(self, msg.addr, backlog = 256)
}
sender() ! true
case RequestBoundFuture =>
sender() ! promisedWorkerEnvs.future
case RequestCompletionFuture =>
sender() ! promisedShutdownWorkers.future
case InterruptTracker(e) =>
log.error(e, "Uncaught exception thrown by worker.")
// make sure that waitFor() does not hang indefinitely.
promisedShutdownWorkers.failure(e)
context.stop(self)
}
/**
* Handles messages sent by child actors representing connecting Rabit workers, by brokering
* messages to the dependency resolver, and processing worker commands.
*
* @param workerMsg Message sent by RabitWorkerHandler actors.
*/
private def handleRabitWorkerMessage(workerMsg: RabitWorkerRequest): Unit = workerMsg match {
case req @ RequestAwaitConnWorkers(_, _) =>
// since the requester may request to connect to other workers
// that have not fully set up, delegate this request to the
// dependency resolver which handles the dependencies properly.
resolver forward req
// ---- Rabit worker commands: start/recover/shutdown/print ----
case WorkerTrackerPrint(_, _, _, msg) =>
log.info(msg.trim)
case WorkerShutdown(rank, _, _) =>
assert(rank >= 0, "Invalid rank.")
assert(!shutdownWorkers.contains(rank))
shutdownWorkers.add(rank)
log.info(s"Received shutdown signal from $rank")
if (shutdownWorkers.size == numWorkers) {
promisedShutdownWorkers.success(shutdownWorkers.size)
}
case WorkerRecover(prevRank, worldSize, jobId) =>
assert(prevRank >= 0)
sender() ! linkMap.assignRank(prevRank)
case WorkerStart(rank, worldSize, jobId) =>
assert(worldSize == numWorkers || worldSize == -1,
s"Purported worldSize ($worldSize) does not match worker count ($numWorkers)."
)
Try(decideRank(rank, jobId).getOrElse(ranksToAssign.remove(0))) match {
case Success(wkRank) =>
if (jobId != "NULL") {
jobToRankMap.put(jobId, wkRank)
}
val assignedRank = linkMap.assignRank(wkRank)
sender() ! assignedRank
resolver ! assignedRank
log.info("Received start signal from " +
s"${actorRefToHost.getOrElse(sender(), "")} [rank: $wkRank]")
case Failure(ex: IndexOutOfBoundsException) =>
// More than worldSize workers have connected, likely due to executor loss.
// Since Rabit currently does not support crash recovery (because the Allreduce results
// are not cached by the tracker, and because existing workers cannot reestablish
// connections to newly spawned executor/worker), the most reasonable action here is to
// interrupt the tracker immediately with a failure state.
log.error("Received invalid start signal from " +
s"${actorRefToHost.getOrElse(sender(), "")}: all $worldSize workers have started."
)
promisedShutdownWorkers.failure(new XGBoostError("Invalid start signal" +
" received from worker, likely due to executor loss."))
case Failure(ex) =>
log.error(ex, "Unexpected error")
promisedShutdownWorkers.failure(ex)
}
// ---- Dependency resolving related messages ----
case msg @ WorkerStarted(host, rank, awaitingAcceptance) =>
log.info(s"Worker $host (rank: $rank) has started.")
resolver forward msg
startedWorkers.add(rank)
if (startedWorkers.size == numWorkers) {
log.info("All workers have started.")
}
case req @ DropFromWaitingList(_) =>
// all peer workers in dependency link map have connected;
// forward message to resolver to update dependencies.
resolver forward req
case _ =>
}
def receive: Actor.Receive = {
case tcpEvent: Tcp.Event => handleTcpEvents(tcpEvent)
case trackerMsg: TrackerControlMessage => handleTrackerControlMessage(trackerMsg)
case workerMsg: RabitWorkerRequest => handleRabitWorkerMessage(workerMsg)
case akka.actor.ReceiveTimeout =>
if (startedWorkers.size < numWorkers) {
promisedShutdownWorkers.failure(
new TimeoutException("Timed out waiting for workers to connect: " +
s"${numWorkers - startedWorkers.size} of $numWorkers did not start/connect.")
)
context.stop(self)
}
context.setReceiveTimeout(Duration.Undefined)
}
}
/**
* Resolve the dependency between nodes as they connect to the tracker.
* The dependency is enforced so that a worker of rank K depends on its neighbors (from the treeMap
* and ringMap) whose ranks are smaller than K. Since ranks are assigned in the order of
* worker connections, this dependency constraint assumes that a worker node that connects first
* is likely to finish its setup first.
*/
private[rabit] class WorkerDependencyResolver(handler: ActorRef) extends Actor with ActorLogging {
import RabitWorkerHandler._
context.watch(handler)
case class Fulfillment(toConnectSet: Set[Int], promise: Promise[AwaitingConnections])
// worker nodes that have connected, but have not send WorkerStarted message.
private val dependencyMap = mutable.Map.empty[Int, Set[Int]]
private val startedWorkers = mutable.Set.empty[Int]
// worker nodes that have started, and await for connections.
private val awaitConnWorkers = mutable.Map.empty[Int, ActorRef]
private val pendingFulfillment = mutable.Map.empty[Int, Fulfillment]
def awaitingWorkers(linkSet: Set[Int]): AwaitingConnections = {
val connSet = awaitConnWorkers.toMap
.filterKeys(k => linkSet.contains(k))
AwaitingConnections(connSet, linkSet.size - connSet.size)
}
def receive: Actor.Receive = {
// a copy of the AssignedRank message that is also sent to the worker
case AssignedRank(rank, tree_neighbors, ring, parent) =>
// the workers that the worker of given `rank` depends on:
// worker of rank K only depends on workers with rank smaller than K.
val dependentWorkers = (tree_neighbors.toSet ++ Set(ring._1, ring._2))
.filter{ r => r != -1 && r < rank}
log.debug(s"Rank $rank connected, dependencies: $dependentWorkers")
dependencyMap.put(rank, dependentWorkers)
case RequestAwaitConnWorkers(rank, toConnectSet) =>
val promise = Promise[AwaitingConnections]()
assert(dependencyMap.contains(rank))
val updatedDependency = dependencyMap(rank) diff startedWorkers
if (updatedDependency.isEmpty) {
// all dependencies are satisfied
log.debug(s"Rank $rank has all dependencies satisfied.")
promise.success(awaitingWorkers(toConnectSet))
} else {
log.debug(s"Rank $rank's request for AwaitConnWorkers is pending fulfillment.")
// promise is pending fulfillment due to unresolved dependency
pendingFulfillment.put(rank, Fulfillment(toConnectSet, promise))
}
sender() ! promise.future
case WorkerStarted(_, started, awaitingAcceptance) =>
startedWorkers.add(started)
if (awaitingAcceptance > 0) {
awaitConnWorkers.put(started, sender())
}
// remove the started rank from all dependencies.
dependencyMap.remove(started)
dependencyMap.foreach { case (r, dset) =>
val updatedDependency = dset diff startedWorkers
// fulfill the future if all dependencies are met (started.)
if (updatedDependency.isEmpty) {
log.debug(s"Rank $r has all dependencies satisfied.")
pendingFulfillment.remove(r).map{
case Fulfillment(toConnectSet, promise) =>
promise.success(awaitingWorkers(toConnectSet))
}
}
dependencyMap.update(r, updatedDependency)
}
case DropFromWaitingList(rank) =>
assert(awaitConnWorkers.remove(rank).isDefined)
case Terminated(ref) =>
if (ref.equals(handler)) {
context.stop(self)
}
}
}
private[scala] object RabitTrackerHandler {
// Messages sent by RabitTracker to this RabitTrackerHandler actor
trait TrackerControlMessage
case object RequestCompletionFuture extends TrackerControlMessage
case object RequestBoundFuture extends TrackerControlMessage
// Start the Rabit tracker at given socket address awaiting worker connections.
// All workers must connect to the tracker before connectionTimeout, otherwise the tracker will
// shut down due to timeout.
case class StartTracker(addr: InetSocketAddress,
maxPortTrials: Int,
connectionTimeout: Duration) extends TrackerControlMessage
// To interrupt the tracker handler due to uncaught exception thrown by the thread acting as
// driver for the distributed training.
case class InterruptTracker(e: Throwable) extends TrackerControlMessage
def props(numWorkers: Int): Props =
Props(new RabitTrackerHandler(numWorkers))
}

View File

@@ -1,467 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.rabit.handler
import java.nio.{ByteBuffer, ByteOrder}
import akka.io.Tcp
import akka.actor._
import akka.util.ByteString
import ml.dmlc.xgboost4j.scala.rabit.util.{AssignedRank, RabitTrackerHelpers}
import scala.concurrent.{Await, Future}
import scala.concurrent.duration._
import scala.util.Try
/**
* Actor to handle socket communication from worker node.
* To handle fragmentation in received data, this class acts like a FSM
* (finite-state machine) to keep track of the internal states.
*
* @param host IP address of the remote worker
* @param worldSize number of total workers
* @param tracker the RabitTrackerHandler actor reference
*/
private[scala] class RabitWorkerHandler(host: String, worldSize: Int, tracker: ActorRef,
connection: ActorRef)
extends FSM[RabitWorkerHandler.State, RabitWorkerHandler.DataStruct]
with ActorLogging with Stash {
import RabitWorkerHandler._
import RabitTrackerHelpers._
private[this] var rank: Int = 0
private[this] var port: Int = 0
// indicate if the connection is transient (like "print" or "shutdown")
private[this] var transient: Boolean = false
private[this] var peerClosed: Boolean = false
// number of workers pending acceptance of current worker
private[this] var awaitingAcceptance: Int = 0
private[this] var neighboringWorkers = Set.empty[Int]
// TODO: use a single memory allocation to host all buffers,
// including the transient ones for writing.
private[this] val readBuffer = ByteBuffer.allocate(4096)
.order(ByteOrder.nativeOrder())
// in case the received message is longer than needed,
// stash the spilled over part in this buffer, and send
// to self when transition occurs.
private[this] val spillOverBuffer = ByteBuffer.allocate(4096)
.order(ByteOrder.nativeOrder())
// when setup is complete, need to notify peer handlers
// to reduce the awaiting-connection counter.
private[this] var pendingAcknowledgement: Option[AcknowledgeAcceptance] = None
private def resetBuffers(): Unit = {
readBuffer.clear()
if (spillOverBuffer.position() > 0) {
spillOverBuffer.flip()
self ! Tcp.Received(ByteString.fromByteBuffer(spillOverBuffer))
spillOverBuffer.clear()
}
}
private def stashSpillOver(buf: ByteBuffer): Unit = {
if (buf.remaining() > 0) spillOverBuffer.put(buf)
}
def getNeighboringWorkers: Set[Int] = neighboringWorkers
def decodeCommand(buffer: ByteBuffer): TrackerCommand = {
val readBuffer = buffer.duplicate().order(ByteOrder.nativeOrder())
readBuffer.flip()
val rank = readBuffer.getInt()
val worldSize = readBuffer.getInt()
val jobId = readBuffer.getString
val command = readBuffer.getString
val trackerCommand = command match {
case "start" => WorkerStart(rank, worldSize, jobId)
case "shutdown" =>
transient = true
WorkerShutdown(rank, worldSize, jobId)
case "recover" =>
require(rank >= 0, "Invalid rank for recovering worker.")
WorkerRecover(rank, worldSize, jobId)
case "print" =>
transient = true
WorkerTrackerPrint(rank, worldSize, jobId, readBuffer.getString)
}
stashSpillOver(readBuffer)
trackerCommand
}
startWith(AwaitingHandshake, DataStruct())
when(AwaitingHandshake) {
case Event(Tcp.Received(magic), _) =>
assert(magic.length == 4)
val purportedMagic = magic.asNativeOrderByteBuffer.getInt
assert(purportedMagic == MAGIC_NUMBER, s"invalid magic number $purportedMagic from $host")
// echo back the magic number
connection ! Tcp.Write(magic)
goto(AwaitingCommand) using StructTrackerCommand
}
when(AwaitingCommand) {
case Event(Tcp.Received(bytes), validator) =>
bytes.asByteBuffers.foreach { buf => readBuffer.put(buf) }
if (validator.verify(readBuffer)) {
Try(decodeCommand(readBuffer)) match {
case scala.util.Success(decodedCommand) =>
tracker ! decodedCommand
case scala.util.Failure(th: java.nio.BufferUnderflowException) =>
// BufferUnderflowException would occur if the message to print has not arrived yet.
// Do nothing, wait for next Tcp.Received event
case scala.util.Failure(th: Throwable) => throw th
}
}
stay
// when rank for a worker is assigned, send encoded rank information
// back to worker over Tcp socket.
case Event(aRank @ AssignedRank(assignedRank, neighbors, ring, parent), _) =>
log.debug(s"Assigned rank [$assignedRank] for $host, T: $neighbors, R: $ring, P: $parent")
rank = assignedRank
// ranks from the ring
val ringRanks = List(
// ringPrev
if (ring._1 != -1 && ring._1 != rank) ring._1 else -1,
// ringNext
if (ring._2 != -1 && ring._2 != rank) ring._2 else -1
)
// update the set of all linked workers to current worker.
neighboringWorkers = neighbors.toSet ++ ringRanks.filterNot(_ == -1).toSet
connection ! Tcp.Write(ByteString.fromByteBuffer(aRank.toByteBuffer(worldSize)))
// to prevent reading before state transition
connection ! Tcp.SuspendReading
goto(BuildingLinkMap) using StructNodes
}
when(BuildingLinkMap) {
case Event(Tcp.Received(bytes), validator) =>
bytes.asByteBuffers.foreach { buf =>
readBuffer.put(buf)
}
if (validator.verify(readBuffer)) {
readBuffer.flip()
// for a freshly started worker, numConnected should be 0.
val numConnected = readBuffer.getInt()
val toConnectSet = neighboringWorkers.diff(
(0 until numConnected).map { index => readBuffer.getInt() }.toSet)
// check which workers are currently awaiting connections
tracker ! RequestAwaitConnWorkers(rank, toConnectSet)
}
stay
// got a Future from the tracker (resolver) about workers that are
// currently awaiting connections (particularly from this node.)
case Event(future: Future[_], _) =>
// blocks execution until all dependencies for current worker is resolved.
Await.result(future, 1 minute).asInstanceOf[AwaitingConnections] match {
// numNotReachable is the number of workers that currently
// cannot be connected to (pending connection or setup). Instead, this worker will AWAIT
// connections from those currently non-reachable nodes in the future.
case AwaitingConnections(waitConnNodes, numNotReachable) =>
log.debug(s"Rank $rank needs to connect to: $waitConnNodes, # bad: $numNotReachable")
val buf = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder())
buf.putInt(waitConnNodes.size).putInt(numNotReachable)
buf.flip()
// cache this message until the final state (SetupComplete)
pendingAcknowledgement = Some(AcknowledgeAcceptance(
waitConnNodes, numNotReachable))
connection ! Tcp.Write(ByteString.fromByteBuffer(buf))
if (waitConnNodes.isEmpty) {
connection ! Tcp.SuspendReading
goto(AwaitingErrorCount)
}
else {
waitConnNodes.foreach { case (peerRank, peerRef) =>
peerRef ! RequestWorkerHostPort
}
// a countdown for DivulgedHostPort messages.
stay using DataStruct(Seq.empty[DataField], waitConnNodes.size - 1)
}
}
case Event(DivulgedWorkerHostPort(peerRank, peerHost, peerPort), data) =>
val hostBytes = peerHost.getBytes()
val buffer = ByteBuffer.allocate(4 * 3 + hostBytes.length)
.order(ByteOrder.nativeOrder())
buffer.putInt(peerHost.length).put(hostBytes)
.putInt(peerPort).putInt(peerRank)
buffer.flip()
connection ! Tcp.Write(ByteString.fromByteBuffer(buffer))
if (data.counter == 0) {
// to prevent reading before state transition
connection ! Tcp.SuspendReading
goto(AwaitingErrorCount)
}
else {
stay using data.decrement()
}
}
when(AwaitingErrorCount) {
case Event(Tcp.Received(numErrors), _) =>
val buf = numErrors.asNativeOrderByteBuffer
buf.getInt match {
case 0 =>
stashSpillOver(buf)
goto(AwaitingPortNumber)
case _ =>
stashSpillOver(buf)
goto(BuildingLinkMap) using StructNodes
}
}
when(AwaitingPortNumber) {
case Event(Tcp.Received(assignedPort), _) =>
assert(assignedPort.length == 4)
port = assignedPort.asNativeOrderByteBuffer.getInt
log.debug(s"Rank $rank listening @ $host:$port")
// wait until the worker closes connection.
if (peerClosed) goto(SetupComplete) else stay
case Event(Tcp.PeerClosed, _) =>
peerClosed = true
if (port == 0) stay else goto(SetupComplete)
}
when(SetupComplete) {
case Event(ReduceWaitCount(count: Int), _) =>
awaitingAcceptance -= count
// check peerClosed to avoid prematurely stopping this actor (which sends RST to worker)
if (awaitingAcceptance == 0 && peerClosed) {
tracker ! DropFromWaitingList(rank)
// no longer needed.
context.stop(self)
}
stay
case Event(AcknowledgeAcceptance(peers, numBad), _) =>
awaitingAcceptance = numBad
tracker ! WorkerStarted(host, rank, awaitingAcceptance)
peers.values.foreach { peer =>
peer ! ReduceWaitCount(1)
}
if (awaitingAcceptance == 0 && peerClosed) self ! PoisonPill
stay
// can only divulge the complete host and port information
// when this worker is declared fully connected (otherwise
// port information is still missing.)
case Event(RequestWorkerHostPort, _) =>
sender() ! DivulgedWorkerHostPort(rank, host, port)
stay
}
onTransition {
// reset buffer when state transitions as data becomes stale
case _ -> SetupComplete =>
connection ! Tcp.ResumeReading
resetBuffers()
if (pendingAcknowledgement.isDefined) {
self ! pendingAcknowledgement.get
}
case _ =>
connection ! Tcp.ResumeReading
resetBuffers()
}
// default message handler
whenUnhandled {
case Event(Tcp.PeerClosed, _) =>
peerClosed = true
if (transient) context.stop(self)
stay
}
}
private[scala] object RabitWorkerHandler {
val MAGIC_NUMBER = 0xff99
// Finite states of this actor, which acts like a FSM.
// The following states are defined in order as the FSM progresses.
sealed trait State
// [1] Initial state, awaiting worker to send magic number per protocol.
case object AwaitingHandshake extends State
// [2] Awaiting worker to send command (start/print/recover/shutdown etc.)
case object AwaitingCommand extends State
// [3] Brokers connections between workers per ring/tree/parent link map.
case object BuildingLinkMap extends State
// [4] A transient state in which the worker reports the number of errors in establishing
// connections to other peer workers. If no errors, transition to next state.
case object AwaitingErrorCount extends State
// [5] Awaiting the worker to report its port number for accepting connections from peer workers.
// This port number information is later forwarded to linked workers.
case object AwaitingPortNumber extends State
// [6] Final state after completing the setup with the connecting worker. At this stage, the
// worker will have closed the Tcp connection. The actor remains alive to handle messages from
// peer actors representing workers with pending setups.
case object SetupComplete extends State
sealed trait DataField
case object IntField extends DataField
// an integer preceding the actual string
case object StringField extends DataField
case object IntSeqField extends DataField
object DataStruct {
def apply(): DataStruct = DataStruct(Seq.empty[DataField], 0)
}
// Internal data pertaining to individual state, used to verify the validity of packets sent by
// workers.
case class DataStruct(fields: Seq[DataField], counter: Int) {
/**
* Validate whether the provided buffer is complete (i.e., contains
* all data fields specified for this DataStruct.)
*
* @param buf a byte buffer containing received data.
*/
def verify(buf: ByteBuffer): Boolean = {
if (fields.isEmpty) return true
val dupBuf = buf.duplicate().order(ByteOrder.nativeOrder())
dupBuf.flip()
Try(fields.foldLeft(true) {
case (complete, field) =>
val remBytes = dupBuf.remaining()
complete && (remBytes > 0) && (remBytes >= (field match {
case IntField =>
dupBuf.position(dupBuf.position() + 4)
4
case StringField =>
val strLen = dupBuf.getInt
dupBuf.position(dupBuf.position() + strLen)
4 + strLen
case IntSeqField =>
val seqLen = dupBuf.getInt
dupBuf.position(dupBuf.position() + seqLen * 4)
4 + seqLen * 4
}))
}).getOrElse(false)
}
def increment(): DataStruct = DataStruct(fields, counter + 1)
def decrement(): DataStruct = DataStruct(fields, counter - 1)
}
val StructNodes = DataStruct(List(IntSeqField), 0)
val StructTrackerCommand = DataStruct(List(
IntField, IntField, StringField, StringField
), 0)
// ---- Messages between RabitTrackerHandler and RabitTrackerConnectionHandler ----
// RabitWorkerHandler --> RabitTrackerHandler
sealed trait RabitWorkerRequest
// RabitWorkerHandler <-- RabitTrackerHandler
sealed trait RabitWorkerResponse
// Representations of decoded worker commands.
abstract class TrackerCommand(val command: String) extends RabitWorkerRequest {
def rank: Int
def worldSize: Int
def jobId: String
def encode: ByteString = {
val buf = ByteBuffer.allocate(4 * 4 + jobId.length + command.length)
.order(ByteOrder.nativeOrder())
buf.putInt(rank).putInt(worldSize).putInt(jobId.length).put(jobId.getBytes())
.putInt(command.length).put(command.getBytes()).flip()
ByteString.fromByteBuffer(buf)
}
}
case class WorkerStart(rank: Int, worldSize: Int, jobId: String)
extends TrackerCommand("start")
case class WorkerShutdown(rank: Int, worldSize: Int, jobId: String)
extends TrackerCommand("shutdown")
case class WorkerRecover(rank: Int, worldSize: Int, jobId: String)
extends TrackerCommand("recover")
case class WorkerTrackerPrint(rank: Int, worldSize: Int, jobId: String, msg: String)
extends TrackerCommand("print") {
override def encode: ByteString = {
val buf = ByteBuffer.allocate(4 * 5 + jobId.length + command.length + msg.length)
.order(ByteOrder.nativeOrder())
buf.putInt(rank).putInt(worldSize).putInt(jobId.length).put(jobId.getBytes())
.putInt(command.length).put(command.getBytes())
.putInt(msg.length).put(msg.getBytes()).flip()
ByteString.fromByteBuffer(buf)
}
}
// Request to remove the worker of given rank from the list of workers awaiting peer connections.
case class DropFromWaitingList(rank: Int) extends RabitWorkerRequest
// Notify the tracker that the worker of given rank has finished setup and started.
case class WorkerStarted(host: String, rank: Int, awaitingAcceptance: Int)
extends RabitWorkerRequest
// Request the set of workers to connect to, according to the LinkMap structure.
case class RequestAwaitConnWorkers(rank: Int, toConnectSet: Set[Int])
extends RabitWorkerRequest
// Request, from the tracker, the set of nodes to connect.
case class AwaitingConnections(workers: Map[Int, ActorRef], numBad: Int)
extends RabitWorkerResponse
// ---- Messages between ConnectionHandler actors ----
sealed trait IntraWorkerMessage
// Notify neighboring workers to decrease the counter of awaiting workers by `count`.
case class ReduceWaitCount(count: Int) extends IntraWorkerMessage
// Request host and port information from peer ConnectionHandler actors (acting on behalf of
// connecting workers.) This message will be brokered by RabitTrackerHandler.
case object RequestWorkerHostPort extends IntraWorkerMessage
// Response to the above request
case class DivulgedWorkerHostPort(rank: Int, host: String, port: Int) extends IntraWorkerMessage
// A reminder to send ReduceWaitCount messages once the actor is in state "SetupComplete".
case class AcknowledgeAcceptance(peers: Map[Int, ActorRef], numBad: Int)
extends IntraWorkerMessage
// ---- End of message definitions ----
def props(host: String, worldSize: Int, tracker: ActorRef, connection: ActorRef): Props = {
Props(new RabitWorkerHandler(host, worldSize, tracker, connection))
}
}
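The deleted FSM above documents the wire protocol it expected from workers: in its AwaitingHandshake state it read a 4-byte magic number (0xff99) in native byte order and echoed it back. A minimal client-side sketch of that handshake, with a hypothetical host and port, kept only as an illustration of the removed protocol:

import java.net.Socket
import java.nio.{ByteBuffer, ByteOrder}

object HandshakeSketch {
  def main(args: Array[String]): Unit = {
    val socket = new Socket("127.0.0.1", 9091)  // hypothetical tracker address
    try {
      // send the magic number in native byte order, as RabitWorkerHandler expected
      val magic = ByteBuffer.allocate(4).order(ByteOrder.nativeOrder()).putInt(0xff99).array()
      socket.getOutputStream.write(magic)
      socket.getOutputStream.flush()

      // the tracker echoed the same 4 bytes back before moving to AwaitingCommand
      val echo = new Array[Byte](4)
      socket.getInputStream.read(echo)
      val echoed = ByteBuffer.wrap(echo).order(ByteOrder.nativeOrder()).getInt
      assert(echoed == 0xff99, s"unexpected echo: $echoed")
    } finally socket.close()
  }
}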

View File

@@ -1,136 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.rabit.util
import java.nio.{ByteBuffer, ByteOrder}
/**
* The assigned rank to a connecting Rabit worker, along with the information of the ranks of
* its linked peer workers, which are critical to perform Allreduce.
* When RabitWorkerHandler delegates "start" or "recover" commands from the connecting worker
* client, RabitTrackerHandler utilizes LinkMap to figure out linkage relationships, and respond
* with this class as a message, which is later encoded to byte string, and sent over socket
* connection to the worker client.
*
* @param rank assigned rank (ranked by worker connection order: first worker connecting to the
* tracker is assigned rank 0, second with rank 1, etc.)
* @param neighbors ranks of neighboring workers in a tree map.
* @param ring ranks of neighboring workers in a ring map.
* @param parent rank of the parent worker.
*/
private[rabit] case class AssignedRank(rank: Int, neighbors: Seq[Int],
ring: (Int, Int), parent: Int) {
/**
* Encode the AssignedRank message into byte sequence for socket communication with Rabit worker
* client.
* @param worldSize the number of total distributed workers. Must match `numWorkers` used in
* LinkMap.
* @return a ByteBuffer containing encoded data.
*/
def toByteBuffer(worldSize: Int): ByteBuffer = {
val buffer = ByteBuffer.allocate(4 * (neighbors.length + 6)).order(ByteOrder.nativeOrder())
buffer.putInt(rank).putInt(parent).putInt(worldSize).putInt(neighbors.length)
// neighbors in tree structure
neighbors.foreach { n => buffer.putInt(n) }
buffer.putInt(if (ring._1 != -1 && ring._1 != rank) ring._1 else -1)
buffer.putInt(if (ring._2 != -1 && ring._2 != rank) ring._2 else -1)
buffer.flip()
buffer
}
}
private[rabit] class LinkMap(numWorkers: Int) {
private def getNeighbors(rank: Int): Seq[Int] = {
val rank1 = rank + 1
Vector(rank1 / 2 - 1, rank1 * 2 - 1, rank1 * 2).filter { r =>
r >= 0 && r < numWorkers
}
}
/**
* Construct a ring structure that tends to share nodes with the tree.
*
* @param treeMap
* @param parentMap
* @param rank
* @return Seq[Int] instance starting from rank.
*/
private def constructShareRing(treeMap: Map[Int, Seq[Int]],
parentMap: Map[Int, Int],
rank: Int = 0): Seq[Int] = {
treeMap(rank).toSet - parentMap(rank) match {
case emptySet if emptySet.isEmpty =>
List(rank)
case connectionSet =>
connectionSet.zipWithIndex.foldLeft(List(rank)) {
case (ringSeq, (v, cnt)) =>
val vConnSeq = constructShareRing(treeMap, parentMap, v)
vConnSeq match {
case vconn if vconn.size == cnt + 1 =>
ringSeq ++ vconn.reverse
case vconn =>
ringSeq ++ vconn
}
}
}
}
/**
* Construct a ring connection used to recover local data.
*
* @param treeMap
* @param parentMap
*/
private def constructRingMap(treeMap: Map[Int, Seq[Int]], parentMap: Map[Int, Int]) = {
assert(parentMap(0) == -1)
val sharedRing = constructShareRing(treeMap, parentMap, 0).toVector
assert(sharedRing.length == treeMap.size)
(0 until numWorkers).map { r =>
val rPrev = (r + numWorkers - 1) % numWorkers
val rNext = (r + 1) % numWorkers
sharedRing(r) -> (sharedRing(rPrev), sharedRing(rNext))
}.toMap
}
private[this] val treeMap_ = (0 until numWorkers).map { r => r -> getNeighbors(r) }.toMap
private[this] val parentMap_ = (0 until numWorkers).map{ r => r -> ((r + 1) / 2 - 1) }.toMap
private[this] val ringMap_ = constructRingMap(treeMap_, parentMap_)
val rMap_ = (0 until (numWorkers - 1)).foldLeft((Map(0 -> 0), 0)) {
case ((rmap, k), i) =>
val kNext = ringMap_(k)._2
(rmap ++ Map(kNext -> (i + 1)), kNext)
}._1
val ringMap = ringMap_.map {
case (k, (v0, v1)) => rMap_(k) -> (rMap_(v0), rMap_(v1))
}
val treeMap = treeMap_.map {
case (k, vSeq) => rMap_(k) -> vSeq.map{ v => rMap_(v) }
}
val parentMap = parentMap_.map {
case (k, v) if k == 0 =>
rMap_(k) -> -1
case (k, v) =>
rMap_(k) -> rMap_(v)
}
def assignRank(rank: Int): AssignedRank = {
AssignedRank(rank, treeMap(rank), ringMap(rank), parentMap(rank))
}
}
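The deleted LinkMap derives each worker's tree neighbors purely from its rank: with rank1 = rank + 1, the candidates are rank1 / 2 - 1 (the parent) plus rank1 * 2 - 1 and rank1 * 2 (the children), filtered to the valid range. A small worked sketch that reproduces that rule for a hypothetical four-worker job:

object LinkMapSketch {
  // Same neighbor rule as the deleted LinkMap.getNeighbors, reproduced for illustration only.
  def neighbors(rank: Int, numWorkers: Int): Seq[Int] = {
    val rank1 = rank + 1
    Vector(rank1 / 2 - 1, rank1 * 2 - 1, rank1 * 2).filter(r => r >= 0 && r < numWorkers)
  }

  def main(args: Array[String]): Unit = {
    // For numWorkers = 4: rank 0 -> (1, 2), rank 1 -> (0, 3), rank 2 -> (0), rank 3 -> (1)
    (0 until 4).foreach(r => println(s"rank $r -> ${neighbors(r, 4)}"))
  }
}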

View File

@@ -1,39 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.rabit.util
import java.nio.{ByteOrder, ByteBuffer}
import akka.util.ByteString
private[rabit] object RabitTrackerHelpers {
implicit class ByteStringHelplers(bs: ByteString) {
// Java by default uses big endian. Enforce native endian so that
// the byte order is consistent with the workers.
def asNativeOrderByteBuffer: ByteBuffer = {
bs.asByteBuffer.order(ByteOrder.nativeOrder())
}
}
implicit class ByteBufferHelpers(buf: ByteBuffer) {
def getString: String = {
val len = buf.getInt()
val stringBuffer = ByteBuffer.allocate(len).order(ByteOrder.nativeOrder())
buf.get(stringBuffer.array(), 0, len)
new String(stringBuffer.array(), "utf-8")
}
}
}

Some files were not shown because too many files have changed in this diff.