merge 23Mar01
commit 5446c501af
.github/workflows/jvm_tests.yml (vendored, 4 changes)

@@ -40,7 +40,7 @@ jobs:
key: ${{ runner.os }}-m2-${{ hashFiles('./jvm-packages/pom.xml') }}
restore-keys: ${{ runner.os }}-m2-${{ hashFiles('./jvm-packages/pom.xml') }}

- name: Test XGBoost4J
- name: Test XGBoost4J (Core)
run: |
cd jvm-packages
mvn test -B -pl :xgboost4j_2.12
@@ -67,7 +67,7 @@ jobs:
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY_IAM_S3_UPLOADER }}

- name: Test XGBoost4J-Spark
- name: Test XGBoost4J (Core, Spark, Examples)
run: |
rm -rfv build/
cd jvm-packages
.github/workflows/python_tests.yml (vendored, 13 changes)

@@ -65,7 +65,7 @@ jobs:
run: |
cd python-package
python --version
python setup.py sdist
python -m build --sdist
pip install -v ./dist/xgboost-*.tar.gz
cd ..
python -c 'import xgboost'
@@ -92,6 +92,9 @@ jobs:
auto-update-conda: true
python-version: ${{ matrix.python-version }}
activate-environment: test
- name: Install build
run: |
conda install -c conda-forge python-build
- name: Display Conda env
run: |
conda info
@@ -100,7 +103,7 @@ jobs:
run: |
cd python-package
python --version
python setup.py sdist
python -m build --sdist
pip install -v ./dist/xgboost-*.tar.gz
cd ..
python -c 'import xgboost'
@@ -147,7 +150,7 @@ jobs:
run: |
cd python-package
python --version
python setup.py install
pip install -v .

- name: Test Python package
run: |
@@ -194,7 +197,7 @@ jobs:
run: |
cd python-package
python --version
python setup.py bdist_wheel --universal
pip wheel -v . --wheel-dir dist/
pip install ./dist/*.whl

- name: Test Python package
@@ -238,7 +241,7 @@ jobs:
run: |
cd python-package
python --version
python setup.py install
pip install -v .

- name: Test Python package
run: |
.github/workflows/r_tests.yml (vendored, 2 changes)

@@ -54,7 +54,7 @@ jobs:
matrix:
config:
- {os: windows-latest, r: 'release', compiler: 'mingw', build: 'autotools'}
- {os: windows-latest, r: 'release', compiler: 'msvc', build: 'cmake'}
- {os: windows-latest, r: '4.2.0', compiler: 'msvc', build: 'cmake'}
env:
R_REMOTES_NO_ERRORS_FROM_WARNINGS: true
RSPM: ${{ matrix.config.rspm }}
@@ -47,6 +47,7 @@ option(USE_NVTX "Build with cuda profiling annotations. Developers only." OFF)
set(NVTX_HEADER_DIR "" CACHE PATH "Path to the stand-alone nvtx header")
option(RABIT_MOCK "Build rabit with mock" OFF)
option(HIDE_CXX_SYMBOLS "Build shared library and hide all C++ symbols" OFF)
option(KEEP_BUILD_ARTIFACTS_IN_BINARY_DIR "Output build artifacts in CMake binary dir" OFF)
## CUDA
option(USE_CUDA "Build with GPU acceleration" OFF)
option(USE_NCCL "Build with NCCL to enable distributed GPU support." OFF)
@@ -312,8 +313,13 @@ if (JVM_BINDINGS)
xgboost_target_defs(xgboost4j)
endif (JVM_BINDINGS)

if (KEEP_BUILD_ARTIFACTS_IN_BINARY_DIR)
  set_output_directory(runxgboost ${xgboost_BINARY_DIR})
  set_output_directory(xgboost ${xgboost_BINARY_DIR}/lib)
else ()
  set_output_directory(runxgboost ${xgboost_SOURCE_DIR})
  set_output_directory(xgboost ${xgboost_SOURCE_DIR}/lib)
endif ()
# Ensure these two targets do not build simultaneously, as they produce outputs with conflicting names
add_dependencies(xgboost runxgboost)
@@ -32,7 +32,7 @@ OBJECTS= \
$(PKGROOT)/src/objective/objective.o \
$(PKGROOT)/src/objective/regression_obj.o \
$(PKGROOT)/src/objective/multiclass_obj.o \
$(PKGROOT)/src/objective/rank_obj.o \
$(PKGROOT)/src/objective/lambdarank_obj.o \
$(PKGROOT)/src/objective/hinge.o \
$(PKGROOT)/src/objective/aft_obj.o \
$(PKGROOT)/src/objective/adaptive.o \
@@ -32,7 +32,7 @@ OBJECTS= \
$(PKGROOT)/src/objective/objective.o \
$(PKGROOT)/src/objective/regression_obj.o \
$(PKGROOT)/src/objective/multiclass_obj.o \
$(PKGROOT)/src/objective/rank_obj.o \
$(PKGROOT)/src/objective/lambdarank_obj.o \
$(PKGROOT)/src/objective/hinge.o \
$(PKGROOT)/src/objective/aft_obj.o \
$(PKGROOT)/src/objective/adaptive.o \
@@ -72,7 +72,7 @@ test_that("xgb.DMatrix: saving, loading", {
tmp <- c("0 1:1 2:1", "1 3:1", "0 1:1")
tmp_file <- tempfile(fileext = ".libsvm")
writeLines(tmp, tmp_file)
dtest4 <- xgb.DMatrix(tmp_file, silent = TRUE)
dtest4 <- xgb.DMatrix(paste(tmp_file, "?format=libsvm", sep = ""), silent = TRUE)
expect_equal(dim(dtest4), c(3, 4))
expect_equal(getinfo(dtest4, 'label'), c(0, 1, 0))
@@ -20,10 +20,10 @@ num_round = 2
# 0 means do not save any model except the final round model
save_period = 2
# The path of training data
data = "agaricus.txt.train"
data = "agaricus.txt.train?format=libsvm"
# The path of validation data, used to monitor training process, here [test] sets name of the validation set
eval[test] = "agaricus.txt.test"
eval[test] = "agaricus.txt.test?format=libsvm"
# evaluate on training data as well each round
eval_train = 1
# The path of test data
test:data = "agaricus.txt.test"
test:data = "agaricus.txt.test?format=libsvm"
@@ -21,8 +21,8 @@ num_round = 2
# 0 means do not save any model except the final round model
save_period = 0
# The path of training data
data = "machine.txt.train"
data = "machine.txt.train?format=libsvm"
# The path of validation data, used to monitor training process, here [test] sets name of the validation set
eval[test] = "machine.txt.test"
eval[test] = "machine.txt.test?format=libsvm"
# The path of test data
test:data = "machine.txt.test"
test:data = "machine.txt.test?format=libsvm"
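Both configuration files reflect the same underlying change: text inputs are now loaded through a URI that names the format explicitly (`?format=libsvm`) instead of relying on auto-detection. A minimal Python sketch of the same idea, assuming a local LIBSVM file named agaricus.txt.train (illustrative path, not part of the commit):

```python
import xgboost as xgb

# Older calls left out the query string and relied on format auto-detection:
# dtrain = xgb.DMatrix("agaricus.txt.train")

# The format is now stated explicitly in the URI:
dtrain = xgb.DMatrix("agaricus.txt.train?format=libsvm")
```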
@@ -42,8 +42,8 @@ int main() {

// load the data
DMatrixHandle dtrain, dtest;
safe_xgboost(XGDMatrixCreateFromFile("../../data/agaricus.txt.train", silent, &dtrain));
safe_xgboost(XGDMatrixCreateFromFile("../../data/agaricus.txt.test", silent, &dtest));
safe_xgboost(XGDMatrixCreateFromFile("../../data/agaricus.txt.train?format=libsvm", silent, &dtrain));
safe_xgboost(XGDMatrixCreateFromFile("../../data/agaricus.txt.test?format=libsvm", silent, &dtest));

// create the booster
BoosterHandle booster;
@@ -7,15 +7,19 @@ import os
import xgboost as xgb

CURRENT_DIR = os.path.dirname(__file__)
dtrain = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.train'))
dtest = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.test'))
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
dtrain = xgb.DMatrix(
    os.path.join(CURRENT_DIR, "../data/agaricus.txt.train?format=libsvm")
)
dtest = xgb.DMatrix(
    os.path.join(CURRENT_DIR, "../data/agaricus.txt.test?format=libsvm")
)
watchlist = [(dtest, "eval"), (dtrain, "train")]
###
# advanced: start from a initial base prediction
#
print('start running example to start from a initial prediction')
print("start running example to start from a initial prediction")
# specify parameters via map, definition are same as c++ version
param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}
param = {"max_depth": 2, "eta": 1, "objective": "binary:logistic"}
# train xgboost for 1 round
bst = xgb.train(param, dtrain, 1, watchlist)
# Note: we need the margin value instead of transformed prediction in
@@ -27,5 +31,5 @@ ptest = bst.predict(dtest, output_margin=True)
dtrain.set_base_margin(ptrain)
dtest.set_base_margin(ptest)

print('this is result of running from initial prediction')
print("this is result of running from initial prediction")
bst = xgb.train(param, dtrain, 1, watchlist)
@@ -10,27 +10,45 @@ import xgboost as xgb

# load data in do training
CURRENT_DIR = os.path.dirname(__file__)
dtrain = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.train'))
param = {'max_depth':2, 'eta':1, 'objective':'binary:logistic'}
dtrain = xgb.DMatrix(
    os.path.join(CURRENT_DIR, "../data/agaricus.txt.train?format=libsvm")
)
param = {"max_depth": 2, "eta": 1, "objective": "binary:logistic"}
num_round = 2

print('running cross validation')
print("running cross validation")
# do cross validation, this will print result out as
# [iteration] metric_name:mean_value+std_value
# std_value is standard deviation of the metric
xgb.cv(param, dtrain, num_round, nfold=5,
       metrics={'error'}, seed=0,
       callbacks=[xgb.callback.EvaluationMonitor(show_stdv=True)])
xgb.cv(
    param,
    dtrain,
    num_round,
    nfold=5,
    metrics={"error"},
    seed=0,
    callbacks=[xgb.callback.EvaluationMonitor(show_stdv=True)],
)

print('running cross validation, disable standard deviation display')
print("running cross validation, disable standard deviation display")
# do cross validation, this will print result out as
# [iteration] metric_name:mean_value
res = xgb.cv(param, dtrain, num_boost_round=10, nfold=5,
             metrics={'error'}, seed=0,
             callbacks=[xgb.callback.EvaluationMonitor(show_stdv=False),
                        xgb.callback.EarlyStopping(3)])
res = xgb.cv(
    param,
    dtrain,
    num_boost_round=10,
    nfold=5,
    metrics={"error"},
    seed=0,
    callbacks=[
        xgb.callback.EvaluationMonitor(show_stdv=False),
        xgb.callback.EarlyStopping(3),
    ],
)
print(res)
print('running cross validation, with preprocessing function')
print("running cross validation, with preprocessing function")

# define the preprocessing function
# used to return the preprocessed training, test data, and parameter
# we can use this to do weight rescale, etc.
@@ -38,32 +56,36 @@ print('running cross validation, with preprocessing function')
def fpreproc(dtrain, dtest, param):
    label = dtrain.get_label()
    ratio = float(np.sum(label == 0)) / np.sum(label == 1)
    param['scale_pos_weight'] = ratio
    param["scale_pos_weight"] = ratio
    return (dtrain, dtest, param)

# do cross validation, for each fold
# the dtrain, dtest, param will be passed into fpreproc
# then the return value of fpreproc will be used to generate
# results of that fold
xgb.cv(param, dtrain, num_round, nfold=5,
       metrics={'auc'}, seed=0, fpreproc=fpreproc)
xgb.cv(param, dtrain, num_round, nfold=5, metrics={"auc"}, seed=0, fpreproc=fpreproc)

###
# you can also do cross validation with customized loss function
# See custom_objective.py
##
print('running cross validation, with customized loss function')
print("running cross validation, with customized loss function")

def logregobj(preds, dtrain):
    labels = dtrain.get_label()
    preds = 1.0 / (1.0 + np.exp(-preds))
    grad = preds - labels
    hess = preds * (1.0 - preds)
    return grad, hess

def evalerror(preds, dtrain):
    labels = dtrain.get_label()
    return 'error', float(sum(labels != (preds > 0.0))) / len(labels)
    return "error", float(sum(labels != (preds > 0.0))) / len(labels)

param = {'max_depth':2, 'eta':1}

param = {"max_depth": 2, "eta": 1}
# train with customized objective
xgb.cv(param, dtrain, num_round, nfold=5, seed=0,
       obj=logregobj, feval=evalerror)
xgb.cv(param, dtrain, num_round, nfold=5, seed=0, obj=logregobj, feval=evalerror)
@@ -7,28 +7,37 @@ import os
import xgboost as xgb

CURRENT_DIR = os.path.dirname(__file__)
dtrain = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.train'))
dtest = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.test'))
dtrain = xgb.DMatrix(
    os.path.join(CURRENT_DIR, "../data/agaricus.txt.train?format=libsvm")
)
dtest = xgb.DMatrix(
    os.path.join(CURRENT_DIR, "../data/agaricus.txt.test?format=libsvm")
)

param = [('max_depth', 2), ('objective', 'binary:logistic'), ('eval_metric', 'logloss'), ('eval_metric', 'error')]
param = [
    ("max_depth", 2),
    ("objective", "binary:logistic"),
    ("eval_metric", "logloss"),
    ("eval_metric", "error"),
]

num_round = 2
watchlist = [(dtest,'eval'), (dtrain,'train')]
watchlist = [(dtest, "eval"), (dtrain, "train")]

evals_result = {}
bst = xgb.train(param, dtrain, num_round, watchlist, evals_result=evals_result)

print('Access logloss metric directly from evals_result:')
print(evals_result['eval']['logloss'])
print("Access logloss metric directly from evals_result:")
print(evals_result["eval"]["logloss"])

print('')
print('Access metrics through a loop:')
print("")
print("Access metrics through a loop:")
for e_name, e_mtrs in evals_result.items():
    print('- {}'.format(e_name))
    print("- {}".format(e_name))
    for e_mtr_name, e_mtr_vals in e_mtrs.items():
        print(' - {}'.format(e_mtr_name))
        print(' - {}'.format(e_mtr_vals))
        print(" - {}".format(e_mtr_name))
        print(" - {}".format(e_mtr_vals))

print('')
print('Access complete dictionary:')
print("")
print("Access complete dictionary:")
print(evals_result)
@@ -11,14 +11,22 @@ import xgboost as xgb
# basically, we are using linear model, instead of tree for our boosters
##
CURRENT_DIR = os.path.dirname(__file__)
dtrain = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.train'))
dtest = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.test'))
dtrain = xgb.DMatrix(
    os.path.join(CURRENT_DIR, "../data/agaricus.txt.train?format=libsvm")
)
dtest = xgb.DMatrix(
    os.path.join(CURRENT_DIR, "../data/agaricus.txt.test?format=libsvm")
)
# change booster to gblinear, so that we are fitting a linear model
# alpha is the L1 regularizer
# lambda is the L2 regularizer
# you can also set lambda_bias which is L2 regularizer on the bias term
param = {'objective':'binary:logistic', 'booster':'gblinear',
         'alpha': 0.0001, 'lambda': 1}
param = {
    "objective": "binary:logistic",
    "booster": "gblinear",
    "alpha": 0.0001,
    "lambda": 1,
}

# normally, you do not need to set eta (step_size)
# XGBoost uses a parallel coordinate descent algorithm (shotgun),
@@ -29,9 +37,15 @@ param = {'objective':'binary:logistic', 'booster':'gblinear',
##
# the rest of settings are the same
##
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
watchlist = [(dtest, "eval"), (dtrain, "train")]
num_round = 4
bst = xgb.train(param, dtrain, num_round, watchlist)
preds = bst.predict(dtest)
labels = dtest.get_label()
print('error=%f' % (sum(1 for i in range(len(preds)) if int(preds[i] > 0.5) != labels[i]) / float(len(preds))))
print(
    "error=%f"
    % (
        sum(1 for i in range(len(preds)) if int(preds[i] > 0.5) != labels[i])
        / float(len(preds))
    )
)
@@ -16,8 +16,8 @@ test = os.path.join(CURRENT_DIR, "../data/agaricus.txt.test")

def native_interface():
    # load data in do training
    dtrain = xgb.DMatrix(train)
    dtest = xgb.DMatrix(test)
    dtrain = xgb.DMatrix(train + "?format=libsvm")
    dtest = xgb.DMatrix(test + "?format=libsvm")
    param = {"max_depth": 2, "eta": 1, "objective": "binary:logistic"}
    watchlist = [(dtest, "eval"), (dtrain, "train")]
    num_round = 3
@@ -8,14 +8,18 @@ import xgboost as xgb

# load data in do training
CURRENT_DIR = os.path.dirname(__file__)
dtrain = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.train'))
dtest = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.test'))
param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
dtrain = xgb.DMatrix(
    os.path.join(CURRENT_DIR, "../data/agaricus.txt.train?format=libsvm")
)
dtest = xgb.DMatrix(
    os.path.join(CURRENT_DIR, "../data/agaricus.txt.test?format=libsvm")
)
param = {"max_depth": 2, "eta": 1, "objective": "binary:logistic"}
watchlist = [(dtest, "eval"), (dtrain, "train")]
num_round = 3
bst = xgb.train(param, dtrain, num_round, watchlist)

print('start testing predict the leaf indices')
print("start testing predict the leaf indices")
# predict using first 2 tree
leafindex = bst.predict(
    dtest, iteration_range=(0, 2), pred_leaf=True, strict_shape=True
@@ -3,61 +3,12 @@
This directory contains a demo of Federated Learning using
[NVFlare](https://nvidia.github.io/NVFlare/).

## Training with CPU only
## Horizontal Federated XGBoost

To run the demo, first build XGBoost with the federated learning plugin enabled (see the
[README](../../plugin/federated/README.md)).
For horizontal federated learning using XGBoost (data is split row-wise), check out the `horizontal` directory
(see the [README](horizontal/README.md)).

Install NVFlare (note that currently NVFlare only supports Python 3.8):
```shell
pip install nvflare
```
## Vertical Federated XGBoost

Prepare the data:
```shell
./prepare_data.sh
```

Start the NVFlare federated server:
```shell
/tmp/nvflare/poc/server/startup/start.sh
```

In another terminal, start the first worker:
```shell
/tmp/nvflare/poc/site-1/startup/start.sh
```

And the second worker:
```shell
/tmp/nvflare/poc/site-2/startup/start.sh
```

Then start the admin CLI:
```shell
/tmp/nvflare/poc/admin/startup/fl_admin.sh
```

In the admin CLI, run the following command:
```shell
submit_job hello-xgboost
```

Once the training finishes, the model file should be written into
`/tmp/nvlfare/poc/site-1/run_1/test.model.json` and `/tmp/nvflare/poc/site-2/run_1/test.model.json`
respectively.

Finally, shutdown everything from the admin CLI, using `admin` as password:
```shell
shutdown client
shutdown server
```

## Training with GPUs

To demo with Federated Learning using GPUs, make sure your machine has at least 2 GPUs.
Build XGBoost with the federated learning plugin enabled along with CUDA, but with NCCL
turned off (see the [README](../../plugin/federated/README.md)).

Modify `config/config_fed_client.json` and set `use_gpus` to `true`, then repeat the steps
above.
For vertical federated learning using XGBoost (data is split column-wise), check out the `vertical` directory
(see the [README](vertical/README.md)).
@@ -1,23 +0,0 @@
{
    "format_version": 2,
    "executors": [
        {
            "tasks": [
                "train"
            ],
            "executor": {
                "path": "trainer.XGBoostTrainer",
                "args": {
                    "server_address": "localhost:9091",
                    "world_size": 2,
                    "server_cert_path": "server-cert.pem",
                    "client_key_path": "client-key.pem",
                    "client_cert_path": "client-cert.pem",
                    "use_gpus": "false"
                }
            }
        }
    ],
    "task_result_filters": [],
    "task_data_filters": []
}
@@ -1,22 +0,0 @@
{
    "format_version": 2,
    "server": {
        "heart_beat_timeout": 600
    },
    "task_data_filters": [],
    "task_result_filters": [],
    "workflows": [
        {
            "id": "server_workflow",
            "path": "controller.XGBoostController",
            "args": {
                "port": 9091,
                "world_size": 2,
                "server_key_path": "server-key.pem",
                "server_cert_path": "server-cert.pem",
                "client_cert_path": "client-cert.pem"
            }
        }
    ],
    "components": []
}
demo/nvflare/horizontal/README.md (new file, 63 lines)

@@ -0,0 +1,63 @@
# Experimental Support of Horizontal Federated XGBoost using NVFlare

This directory contains a demo of Horizontal Federated Learning using
[NVFlare](https://nvidia.github.io/NVFlare/).

## Training with CPU only

To run the demo, first build XGBoost with the federated learning plugin enabled (see the
[README](../../plugin/federated/README.md)).

Install NVFlare (note that currently NVFlare only supports Python 3.8):
```shell
pip install nvflare
```

Prepare the data:
```shell
./prepare_data.sh
```

Start the NVFlare federated server:
```shell
/tmp/nvflare/poc/server/startup/start.sh
```

In another terminal, start the first worker:
```shell
/tmp/nvflare/poc/site-1/startup/start.sh
```

And the second worker:
```shell
/tmp/nvflare/poc/site-2/startup/start.sh
```

Then start the admin CLI:
```shell
/tmp/nvflare/poc/admin/startup/fl_admin.sh
```

In the admin CLI, run the following command:
```shell
submit_job horizontal-xgboost
```

Once the training finishes, the model file should be written into
`/tmp/nvlfare/poc/site-1/run_1/test.model.json` and `/tmp/nvflare/poc/site-2/run_1/test.model.json`
respectively.

Finally, shutdown everything from the admin CLI, using `admin` as password:
```shell
shutdown client
shutdown server
```

## Training with GPUs

To demo with Federated Learning using GPUs, make sure your machine has at least 2 GPUs.
Build XGBoost with the federated learning plugin enabled along with CUDA, but with NCCL
turned off (see the [README](../../plugin/federated/README.md)).

Modify `config/config_fed_client.json` and set `use_gpus` to `true`, then repeat the steps
above.
@@ -15,8 +15,8 @@ split -n l/${world_size} --numeric-suffixes=1 -a 1 ../data/agaricus.txt.train ag
split -n l/${world_size} --numeric-suffixes=1 -a 1 ../data/agaricus.txt.test agaricus.txt.test-site-

nvflare poc -n 2 --prepare
mkdir -p /tmp/nvflare/poc/admin/transfer/hello-xgboost
cp -fr config custom /tmp/nvflare/poc/admin/transfer/hello-xgboost
mkdir -p /tmp/nvflare/poc/admin/transfer/horizontal-xgboost
cp -fr config custom /tmp/nvflare/poc/admin/transfer/horizontal-xgboost
cp server-*.pem client-cert.pem /tmp/nvflare/poc/server/
for id in $(eval echo "{1..$world_size}"); do
cp server-cert.pem client-*.pem /tmp/nvflare/poc/site-"$id"/
demo/nvflare/vertical/README.md (new file, 59 lines)

@@ -0,0 +1,59 @@
# Experimental Support of Vertical Federated XGBoost using NVFlare

This directory contains a demo of Vertical Federated Learning using
[NVFlare](https://nvidia.github.io/NVFlare/).

## Training with CPU only

To run the demo, first build XGBoost with the federated learning plugin enabled (see the
[README](../../plugin/federated/README.md)).

Install NVFlare (note that currently NVFlare only supports Python 3.8):
```shell
pip install nvflare
```

Prepare the data (note that this step will download the HIGGS dataset, which is 2.6GB compressed, and 7.5GB
uncompressed, so make sure you have enough disk space and are on a fast internet connection):
```shell
./prepare_data.sh
```

Start the NVFlare federated server:
```shell
/tmp/nvflare/poc/server/startup/start.sh
```

In another terminal, start the first worker:
```shell
/tmp/nvflare/poc/site-1/startup/start.sh
```

And the second worker:
```shell
/tmp/nvflare/poc/site-2/startup/start.sh
```

Then start the admin CLI:
```shell
/tmp/nvflare/poc/admin/startup/fl_admin.sh
```

In the admin CLI, run the following command:
```shell
submit_job vertical-xgboost
```

Once the training finishes, the model file should be written into
`/tmp/nvlfare/poc/site-1/run_1/test.model.json` and `/tmp/nvflare/poc/site-2/run_1/test.model.json`
respectively.

Finally, shutdown everything from the admin CLI, using `admin` as password:
```shell
shutdown client
shutdown server
```

## Training with GPUs

Currently GPUs are not yet supported by vertical federated XGBoost.
demo/nvflare/vertical/custom/controller.py (new file, 68 lines)

@@ -0,0 +1,68 @@
"""
Example of training controller with NVFlare
===========================================
"""
import multiprocessing

from nvflare.apis.client import Client
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import Controller, Task
from nvflare.apis.shareable import Shareable
from nvflare.apis.signal import Signal
from trainer import SupportedTasks

import xgboost.federated


class XGBoostController(Controller):
    def __init__(self, port: int, world_size: int, server_key_path: str,
                 server_cert_path: str, client_cert_path: str):
        """Controller for federated XGBoost.

        Args:
            port: the port for the gRPC server to listen on.
            world_size: the number of sites.
            server_key_path: the path to the server key file.
            server_cert_path: the path to the server certificate file.
            client_cert_path: the path to the client certificate file.
        """
        super().__init__()
        self._port = port
        self._world_size = world_size
        self._server_key_path = server_key_path
        self._server_cert_path = server_cert_path
        self._client_cert_path = client_cert_path
        self._server = None

    def start_controller(self, fl_ctx: FLContext):
        self._server = multiprocessing.Process(
            target=xgboost.federated.run_federated_server,
            args=(self._port, self._world_size, self._server_key_path,
                  self._server_cert_path, self._client_cert_path))
        self._server.start()

    def stop_controller(self, fl_ctx: FLContext):
        if self._server:
            self._server.terminate()

    def process_result_of_unknown_task(self, client: Client, task_name: str,
                                       client_task_id: str, result: Shareable,
                                       fl_ctx: FLContext):
        self.log_warning(fl_ctx, f"Unknown task: {task_name} from client {client.name}.")

    def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
        self.log_info(fl_ctx, "XGBoost training control flow started.")
        if abort_signal.triggered:
            return
        task = Task(name=SupportedTasks.TRAIN, data=Shareable())
        self.broadcast_and_wait(
            task=task,
            min_responses=self._world_size,
            fl_ctx=fl_ctx,
            wait_time_after_min_received=1,
            abort_signal=abort_signal,
        )
        if abort_signal.triggered:
            return

        self.log_info(fl_ctx, "XGBoost training control flow finished.")
demo/nvflare/vertical/custom/trainer.py (new file, 97 lines)

@@ -0,0 +1,97 @@
import os

from nvflare.apis.executor import Executor
from nvflare.apis.fl_constant import FLContextKey, ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.signal import Signal

import xgboost as xgb
from xgboost import callback


class SupportedTasks(object):
    TRAIN = "train"


class XGBoostTrainer(Executor):
    def __init__(self, server_address: str, world_size: int, server_cert_path: str,
                 client_key_path: str, client_cert_path: str):
        """Trainer for federated XGBoost.

        Args:
            server_address: address for the gRPC server to connect to.
            world_size: the number of sites.
            server_cert_path: the path to the server certificate file.
            client_key_path: the path to the client key file.
            client_cert_path: the path to the client certificate file.
        """
        super().__init__()
        self._server_address = server_address
        self._world_size = world_size
        self._server_cert_path = server_cert_path
        self._client_key_path = client_key_path
        self._client_cert_path = client_cert_path

    def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext,
                abort_signal: Signal) -> Shareable:
        self.log_info(fl_ctx, f"Executing {task_name}")
        try:
            if task_name == SupportedTasks.TRAIN:
                self._do_training(fl_ctx)
                return make_reply(ReturnCode.OK)
            else:
                self.log_error(fl_ctx, f"{task_name} is not a supported task.")
                return make_reply(ReturnCode.TASK_UNKNOWN)
        except BaseException as e:
            self.log_exception(fl_ctx,
                               f"Task {task_name} failed. Exception: {e.__str__()}")
            return make_reply(ReturnCode.EXECUTION_EXCEPTION)

    def _do_training(self, fl_ctx: FLContext):
        client_name = fl_ctx.get_prop(FLContextKey.CLIENT_NAME)
        rank = int(client_name.split('-')[1]) - 1
        communicator_env = {
            'xgboost_communicator': 'federated',
            'federated_server_address': self._server_address,
            'federated_world_size': self._world_size,
            'federated_rank': rank,
            'federated_server_cert': self._server_cert_path,
            'federated_client_key': self._client_key_path,
            'federated_client_cert': self._client_cert_path
        }
        with xgb.collective.CommunicatorContext(**communicator_env):
            # Load file, file will not be sharded in federated mode.
            if rank == 0:
                label = '&label_column=0'
            else:
                label = ''
            dtrain = xgb.DMatrix(f'higgs.train.csv?format=csv{label}', data_split_mode=1)
            dtest = xgb.DMatrix(f'higgs.test.csv?format=csv{label}', data_split_mode=1)

            # specify parameters via map
            param = {
                'validate_parameters': True,
                'eta': 0.1,
                'gamma': 1.0,
                'max_depth': 8,
                'min_child_weight': 100,
                'tree_method': 'approx',
                'grow_policy': 'depthwise',
                'objective': 'binary:logistic',
                'eval_metric': 'auc',
            }

            # specify validations set to watch performance
            watchlist = [(dtest, "eval"), (dtrain, "train")]
            # number of boosting rounds
            num_round = 10

            bst = xgb.train(param, dtrain, num_round, evals=watchlist, early_stopping_rounds=2)

            # Save the model.
            workspace = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT)
            run_number = fl_ctx.get_prop(FLContextKey.CURRENT_RUN)
            run_dir = workspace.get_run_dir(run_number)
            bst.save_model(os.path.join(run_dir, "higgs.model.federated.vertical.json"))
            xgb.collective.communicator_print("Finished training\n")
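To see how the controller and trainer above fit together without NVFlare in the picture, here is a minimal, hedged sketch that runs the controller's job (starting the federated gRPC server) and the trainer's job (joining it through a communicator context) directly from one script. It assumes XGBoost was built with the federated plugin and that the certificates and per-site CSV files produced by prepare_data.sh (shown below) are present in the working directory; the parameter values are illustrative only.

```python
"""Minimal standalone sketch of the vertical federated flow (illustration only)."""
import multiprocessing

import xgboost as xgb
import xgboost.federated

PORT = 9091
WORLD_SIZE = 2


def run_worker(rank: int) -> None:
    # Same communicator settings that trainer.XGBoostTrainer._do_training uses.
    communicator_env = {
        "xgboost_communicator": "federated",
        "federated_server_address": f"localhost:{PORT}",
        "federated_world_size": WORLD_SIZE,
        "federated_rank": rank,
        "federated_server_cert": "server-cert.pem",
        "federated_client_key": "client-key.pem",
        "federated_client_cert": "client-cert.pem",
    }
    with xgb.collective.CommunicatorContext(**communicator_env):
        # Only the site holding the label column points the URI at it.
        label = "&label_column=0" if rank == 0 else ""
        dtrain = xgb.DMatrix(
            f"higgs.train.csv-site-{rank + 1}?format=csv{label}", data_split_mode=1
        )
        bst = xgb.train(
            {"objective": "binary:logistic", "tree_method": "approx"},
            dtrain,
            num_boost_round=2,
        )
        if rank == 0:
            bst.save_model("higgs.model.federated.vertical.json")


if __name__ == "__main__":
    # The controller normally launches this server process; here we do it directly.
    server = multiprocessing.Process(
        target=xgboost.federated.run_federated_server,
        args=(PORT, WORLD_SIZE, "server-key.pem", "server-cert.pem", "client-cert.pem"),
    )
    server.start()
    workers = [
        multiprocessing.Process(target=run_worker, args=(rank,))
        for rank in range(WORLD_SIZE)
    ]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
    server.terminate()
```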
demo/nvflare/vertical/prepare_data.sh (new executable file, 65 lines)

@@ -0,0 +1,65 @@
#!/bin/bash

set -e

rm -fr ./*.pem /tmp/nvflare/poc

world_size=2

# Generate server and client certificates.
openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout server-key.pem -out server-cert.pem -subj "/C=US/CN=localhost"
openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout client-key.pem -out client-cert.pem -subj "/C=US/CN=localhost"

# Download HIGGS dataset.
if [ -f "HIGGS.csv" ]; then
  echo "HIGGS.csv exists, skipping download."
else
  echo "Downloading HIGGS dataset."
  wget https://archive.ics.uci.edu/ml/machine-learning-databases/00280/HIGGS.csv.gz
  gunzip HIGGS.csv.gz
fi

# Split into train/test.
if [[ -f higgs.train.csv && -f higgs.test.csv ]]; then
  echo "higgs.train.csv and higgs.test.csv exist, skipping split."
else
  echo "Splitting HIGGS dataset into train/test."
  head -n 10450000 HIGGS.csv > higgs.train.csv
  tail -n 550000 HIGGS.csv > higgs.test.csv
fi

# Split train and test files by column to simulate a federated environment.
site_files=(higgs.{train,test}.csv-site-*)
if [ ${#site_files[@]} -eq $((world_size*2)) ]; then
  echo "Site files exist, skipping split."
else
  echo "Splitting train/test into site files."
  total_cols=28  # plus label
  cols=$((total_cols/world_size))
  echo "Columns per site: $cols"
  for (( site=1; site<=world_size; site++ )); do
    if (( site == 1 )); then
      start=$((cols*(site-1)+1))
    else
      start=$((cols*(site-1)+2))
    fi
    if (( site == world_size )); then
      end=$((total_cols+1))
    else
      end=$((cols*site+1))
    fi
    echo "Site $site, columns $start-$end"
    cut -d, -f${start}-${end} higgs.train.csv > higgs.train.csv-site-"${site}"
    cut -d, -f${start}-${end} higgs.test.csv > higgs.test.csv-site-"${site}"
  done
fi

nvflare poc -n 2 --prepare
mkdir -p /tmp/nvflare/poc/admin/transfer/vertical-xgboost
cp -fr config custom /tmp/nvflare/poc/admin/transfer/vertical-xgboost
cp server-*.pem client-cert.pem /tmp/nvflare/poc/server/
for (( site=1; site<=world_size; site++ )); do
  cp server-cert.pem client-*.pem /tmp/nvflare/poc/site-"${site}"/
  ln -s "${PWD}"/higgs.train.csv-site-"${site}" /tmp/nvflare/poc/site-"${site}"/higgs.train.csv
  ln -s "${PWD}"/higgs.test.csv-site-"${site}" /tmp/nvflare/poc/site-"${site}"/higgs.test.csv
done
@@ -105,7 +105,7 @@ def make_pysrc_wheel(release: str, outdir: str) -> None:
    os.mkdir(dist)

    with DirectoryExcursion(os.path.join(ROOT, "python-package")):
        subprocess.check_call(["python", "setup.py", "sdist"])
        subprocess.check_call(["python", "-m", "build", "--sdist"])
        src = os.path.join(DIST, f"xgboost-{release}.tar.gz")
        subprocess.check_call(["twine", "check", src])
        shutil.move(src, os.path.join(dist, f"xgboost-{release}.tar.gz"))
doc/Doxyfile.in (732 changes; file diff suppressed because it is too large)
doc/build.rst (172 changes)
@@ -12,6 +12,7 @@ systems. If the instructions do not work for you, please feel free to ask quest
Consider installing XGBoost from a pre-built binary, to avoid the trouble of building XGBoost from the source. Checkout :doc:`Installation Guide </install>`.

.. contents:: Contents
  :local:

.. _get_source:

@@ -152,11 +153,11 @@ On Windows, run CMake as follows:

  mkdir build
  cd build
  cmake .. -G"Visual Studio 14 2015 Win64" -DUSE_CUDA=ON
  cmake .. -G"Visual Studio 17 2022" -A x64 -DUSE_CUDA=ON

(Change the ``-G`` option appropriately if you have a different version of Visual Studio installed.)

The above cmake configuration run will create an ``xgboost.sln`` solution file in the build directory. Build this solution in release mode as a x64 build, either from Visual studio or from command line:
The above cmake configuration run will create an ``xgboost.sln`` solution file in the build directory. Build this solution in Release mode, either from Visual studio or from command line:

.. code-block:: bash

@@ -176,111 +177,104 @@ Building Python Package with Default Toolchains
===============================================
There are several ways to build and install the package from source:

1. Use Python setuptools directly
1. Build C++ core with CMake first

The XGBoost Python package supports most of the setuptools commands, here is a list of tested commands:
You can first build C++ library using CMake as described in :ref:`build_shared_lib`.
After compilation, a shared library will appear in ``lib/`` directory.
On Linux distributions, the shared library is ``lib/libxgboost.so``.
The install script ``pip install .`` will reuse the shared library instead of compiling
it from scratch, making it quite fast to run.

.. code-block:: console

  $ cd python-package/
  $ pip install . # Will re-use lib/libxgboost.so

2. Install the Python package directly

You can navigate to ``python-package/`` directory and install the Python package directly
by running

.. code-block:: console

  $ cd python-package/
  $ pip install -v .

which will compile XGBoost's native (C++) code using default CMake flags.
To enable additional compilation options, pass corresponding ``--config-settings``:

.. code-block:: console

  $ pip install -v . --config-settings use_cuda=True --config-settings use_nccl=True

Use Pip 22.1 or later to use ``--config-settings`` option.

Here are the available options for ``--config-settings``:

.. literalinclude:: ../python-package/packager/build_config.py
  :language: python
  :start-at: @dataclasses.dataclass
  :end-before: def _set_config_setting(

``use_system_libxgboost`` is a special option. See Item 4 below for
detailed description.

.. note:: Verbose flag recommended

  As ``pip install .`` will build C++ code, it will take a while to complete.
  To ensure that the build is progressing successfully, we suggest that
  you add the verbose flag (``-v``) when invoking ``pip install``.

3. Editable installation

To further enable rapid development and iteration, we provide an **editable installation**.
In an editable installation, the installed package is simply a symbolic link to your
working copy of the XGBoost source code. So every changes you make to your source
directory will be immediately visible to the Python interpreter. Here is how to
install XGBoost as editable installation:

.. code-block:: bash

  python setup.py install # Install the XGBoost to your current Python environment.
  python setup.py build # Build the Python package.
  python setup.py build_ext # Build only the C++ core.
  python setup.py sdist # Create a source distribution
  python setup.py bdist # Create a binary distribution
  python setup.py bdist_wheel # Create a binary distribution with wheel format

Running ``python setup.py install`` will compile XGBoost using default CMake flags. For
passing additional compilation options, append the flags to the command. For example,
to enable CUDA acceleration and NCCL (distributed GPU) support:

.. code-block:: bash

  python setup.py install --use-cuda --use-nccl

Please refer to ``setup.py`` for a complete list of available options. Some other
options used for development are only available for using CMake directly. See next
section on how to use CMake with setuptools manually.

You can install the created distribution packages using pip. For example, after running
``sdist`` setuptools command, a tar ball similar to ``xgboost-1.0.0.tar.gz`` will be
created under the ``dist`` directory. Then you can install it by invoking the following
command under ``dist`` directory:

.. code-block:: bash

  # under python-package directory
  cd dist
  pip install ./xgboost-1.0.0.tar.gz

For details about these commands, please refer to the official document of `setuptools
<https://setuptools.readthedocs.io/en/latest/>`_, or just Google "how to install Python
package from source". XGBoost Python package follows the general convention.
Setuptools is usually available with your Python distribution, if not you can install it
via system command. For example on Debian or Ubuntu:

.. code-block:: bash

  sudo apt-get install python-setuptools

For cleaning up the directory after running above commands, ``python setup.py clean`` is
not sufficient. After copying out the build result, simply running ``git clean -xdf``
under ``python-package`` is an efficient way to remove generated cache files. If you
find weird behaviors in Python build or running linter, it might be caused by those
cached files.

For using develop command (editable installation), see next section.

.. code-block::

  python setup.py develop # Create a editable installation.
  pip install -e . # Same as above, but carried out by pip.

2. Build C++ core with CMake first

This is mostly for C++ developers who don't want to go through the hooks in Python
setuptools. You can build C++ library directly using CMake as described in above
sections. After compilation, a shared object (or called dynamic linked library, jargon
depending on your platform) will appear in XGBoost's source tree under ``lib/``
directory. On Linux distributions it's ``lib/libxgboost.so``. From there all Python
setuptools commands will reuse that shared object instead of compiling it again. This
is especially convenient if you are using the editable installation, where the installed
package is simply a link to the source tree. We can perform rapid testing during
development. Here is a simple bash script does that:

.. code-block:: bash

  # Under xgboost source tree.
  # Under xgboost source directory
  mkdir build
  cd build
  cmake ..
  make -j$(nproc)
  # Build shared library libxgboost.so
  cmake .. -GNinja
  ninja
  # Install as editable installation
  cd ../python-package
  pip install -e . # or equivalently python setup.py develop
  pip install -e .

3. Use ``libxgboost.so`` on system path.
4. Use ``libxgboost.so`` on system path.

This is for distributing xgboost in a language independent manner, where
``libxgboost.so`` is separately packaged with Python package. Assuming `libxgboost.so`
is already presented in system library path, which can be queried via:
This option is useful for package managers that wish to separately package
``libxgboost.so`` and the XGBoost Python package. For example, Conda
publishes ``libxgboost`` (for the shared library) and ``py-xgboost``
(for the Python package).

To use this option, first make sure that ``libxgboost.so`` exists in the system library path:

.. code-block:: python

  import sys
  import os
  os.path.join(sys.prefix, 'lib')
  import pathlib
  libpath = pathlib.Path(sys.prefix).joinpath("lib", "libxgboost.so")
  assert libpath.exists()

Then one only needs to provide an user option when installing Python package to reuse the
shared object in system path:
Then pass ``use_system_libxgboost=True`` option to ``pip install``:

.. code-block:: bash

  cd xgboost/python-package
  python setup.py install --use-system-libxgboost
  cd python-package
  pip install . --config-settings use_system_libxgboost=True

.. note::

  See :doc:`contrib/python_packaging` for instructions on packaging
  and distributing XGBoost as Python distributions.

.. _python_mingw:

Building Python Package for Windows with MinGW-w64 (Advanced)
@@ -297,7 +291,7 @@ So you may want to build XGBoost with GCC own your own risk. This presents some
2. ``-O3`` is OK.
3. ``-mtune=native`` is also OK.
4. Don't use ``-march=native`` gcc flag. Using it causes the Python interpreter to crash if the DLL was actually used.
5. You may need to provide the lib with the runtime libs. If ``mingw32/bin`` is not in ``PATH``, build a wheel (``python setup.py bdist_wheel``), open it with an archiver and put the needed dlls to the directory where ``xgboost.dll`` is situated. Then you can install the wheel with ``pip``.
5. You may need to provide the lib with the runtime libs. If ``mingw32/bin`` is not in ``PATH``, build a wheel (``pip wheel``), open it with an archiver and put the needed dlls to the directory where ``xgboost.dll`` is situated. Then you can install the wheel with ``pip``.

******************************
Building R Package From Source
@@ -35,8 +35,9 @@ calls ``cibuildwheel`` to build the wheel. The ``cibuildwheel`` is a library tha
suitable Python environment for each OS and processor target. Since we don't have Apple Silion
machine in GitHub Actions, cross-compilation is needed; ``cibuildwheel`` takes care of the complex
task of cross-compiling a Python wheel. (Note that ``cibuildwheel`` will call
``setup.py bdist_wheel``. Since XGBoost has a native library component, ``setup.py`` contains
a glue code to call CMake and a C++ compiler to build the native library on the fly.)
``pip wheel``. Since XGBoost has a native library component, we created a customized build
backend that hooks into ``pip``. The customized backend contains the glue code to compile the native
library on the fly.)

*********************************************************
Reproduce CI testing environments using Docker containers
@@ -23,6 +23,7 @@ Here are guidelines for contributing to various aspect of the XGBoost project:
Community Guideline <community>
donate
coding_guide
python_packaging
unit_tests
Docs and Examples <docs>
git_guide
doc/contrib/python_packaging.rst (new file, 83 lines)

@@ -0,0 +1,83 @@
###########################################
Notes on packaging XGBoost's Python package
###########################################

.. contents:: Contents
  :local:

.. _packaging_python_xgboost:

***************************************************
How to build binary wheels and source distributions
***************************************************

Wheels and source distributions (sdist for short) are the two main
mechanisms for packaging and distributing Python packages.

* A **source distribution** (sdist) is a tarball (``.tar.gz`` extension) that
  contains the source code.
* A **wheel** is a ZIP-compressed archive (with ``.whl`` extension)
  representing a *built* distribution. Unlike an sdist, a wheel can contain
  compiled components. The compiled components are compiled prior to distribution,
  making it more convenient for end-users to install a wheel. Wheels containing
  compiled components are referred to as **binary wheels**.

See `Python Packaging User Guide <https://packaging.python.org/en/latest/>`_
to learn more about how Python packages in general are packaged and
distributed.

For the remainder of this document, we will focus on packaging and
distributing XGBoost.

Building sdists
===============

In the case of XGBoost, an sdist contains both the Python code as well as
the C++ code, so that the core part of XGBoost can be compiled into the
shared libary ``libxgboost.so`` [#shared_lib_name]_.

You can obtain an sdist as follows:

.. code-block:: console

  $ python -m build --sdist .

(You'll need to install the ``build`` package first:
``pip install build`` or ``conda install python-build``.)

Running ``pip install`` with an sdist will launch CMake and a C++ compiler
to compile the bundled C++ code into ``libxgboost.so``:

.. code-block:: console

  $ pip install -v xgboost-2.0.0.tar.gz # Add -v to show build progress

Building binary wheels
======================

You can also build a wheel as follows:

.. code-block:: console

  $ pip wheel --no-deps -v .

Notably, the resulting wheel contains a copy of the shared library
``libxgboost.so`` [#shared_lib_name]_. The wheel is a **binary wheel**,
since it contains a compiled binary.

Running ``pip install`` with the binary wheel will extract the content of
the wheel into the current Python environment. Since the wheel already
contains a pre-built copy of ``libxgboost.so``, it does not have to be
built at the time of install. So ``pip install`` with the binary wheel
completes quickly:

.. code-block:: console

  $ pip install xgboost-2.0.0-py3-none-linux_x86_64.whl # Completes quickly

.. rubric:: Footnotes

.. [#shared_lib_name] The name of the shared library file will differ
  depending on the operating system in use. See :ref:`build_shared_lib`.
@@ -16,15 +16,28 @@ Stable Release
Python
------

Pre-built binary are uploaded to PyPI (Python Package Index) for each release. Supported platforms are Linux (x86_64, aarch64), Windows (x86_64) and MacOS (x86_64, Apple Silicon).
Pre-built binary wheels are uploaded to PyPI (Python Package Index) for each release. Supported platforms are Linux (x86_64, aarch64), Windows (x86_64) and MacOS (x86_64, Apple Silicon).

.. code-block:: bash

  # Pip 21.3+ is required
  pip install xgboost

You might need to run the command with ``--user`` flag or use ``virtualenv`` if you run
into permission errors. Python pre-built binary capability for each platform:
into permission errors.

.. note:: Windows users need to install Visual C++ Redistributable

  XGBoost requires DLLs from `Visual C++ Redistributable
  <https://www.microsoft.com/en-us/download/details.aspx?id=48145>`_
  in order to function, so make sure to install it. Exception: If
  you have Visual Studio installed, you already have access to
  necessary libraries and thus don't need to install Visual C++
  Redistributable.

Capabilities of binary wheels for each platform:

.. |tick| unicode:: U+2714
.. |cross| unicode:: U+2718
@@ -41,3 +41,7 @@ Contents
XGBoost4J Scala API <scaladocs/xgboost4j/index>
XGBoost4J-Spark Scala API <scaladocs/xgboost4j-spark/index>
XGBoost4J-Flink Scala API <scaladocs/xgboost4j-flink/index>

.. note::

  Please note that the flink interface is still under construction.
@@ -219,6 +219,16 @@
"num_pairsample": { "type": "string" },
"fix_list_weight": { "type": "string" }
}
},
"lambdarank_param": {
"type": "object",
"properties": {
"lambdarank_num_pair_per_sample": { "type": "string" },
"lambdarank_pair_method": { "type": "string" },
"lambdarank_unbiased": {"type": "string" },
"lambdarank_bias_norm": {"type": "string" },
"ndcg_exp_gain": {"type": "string"}
}
}
},
"type": "object",
@@ -477,22 +487,22 @@
"type": "object",
"properties": {
"name": { "const": "rank:pairwise" },
"lambda_rank_param": { "$ref": "#/definitions/lambda_rank_param"}
"lambda_rank_param": { "$ref": "#/definitions/lambdarank_param"}
},
"required": [
"name",
"lambda_rank_param"
"lambdarank_param"
]
},
{
"type": "object",
"properties": {
"name": { "const": "rank:ndcg" },
"lambda_rank_param": { "$ref": "#/definitions/lambda_rank_param"}
"lambda_rank_param": { "$ref": "#/definitions/lambdarank_param"}
},
"required": [
"name",
"lambda_rank_param"
"lambdarank_param"
]
},
{
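The schema change above registers a `lambdarank_param` block for the ranking objectives. As a quick, hedged illustration of where that block ends up, assuming the usual `learner`/`objective` layout of XGBoost's JSON model files and a hypothetical file name:

```python
import json

# "ranker.json" is a hypothetical path to a model saved with a rank:ndcg objective.
with open("ranker.json") as fd:
    model = json.load(fd)

objective = model["learner"]["objective"]
print(objective["name"])                      # e.g. "rank:ndcg"
print(objective.get("lambdarank_param", {}))  # block described by the new schema entry
```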
@ -380,9 +380,9 @@ Specify the learning task and the corresponding learning objective. The objectiv
  See :doc:`/tutorials/aft_survival_analysis` for details.
- ``multi:softmax``: set XGBoost to do multiclass classification using the softmax objective, you also need to set num_class(number of classes)
- ``multi:softprob``: same as softmax, but output a vector of ``ndata * nclass``, which can be further reshaped to ``ndata * nclass`` matrix. The result contains predicted probability of each data point belonging to each class.
- ``rank:pairwise``: Use LambdaMART to perform pairwise ranking where the pairwise loss is minimized
- ``rank:ndcg``: Use LambdaMART to perform list-wise ranking where `Normalized Discounted Cumulative Gain (NDCG) <http://en.wikipedia.org/wiki/NDCG>`_ is maximized
- ``rank:map``: Use LambdaMART to perform list-wise ranking where `Mean Average Precision (MAP) <http://en.wikipedia.org/wiki/Mean_average_precision#Mean_average_precision>`_ is maximized
- ``rank:ndcg``: Use LambdaMART to perform pair-wise ranking where `Normalized Discounted Cumulative Gain (NDCG) <http://en.wikipedia.org/wiki/NDCG>`_ is maximized. This objective supports position debiasing for click data.
- ``rank:map``: Use LambdaMART to perform pair-wise ranking where `Mean Average Precision (MAP) <http://en.wikipedia.org/wiki/Mean_average_precision#Mean_average_precision>`_ is maximized
- ``rank:pairwise``: Use LambdaRank to perform pair-wise ranking using the `ranknet` objective.
- ``reg:gamma``: gamma regression with log-link. Output is a mean of gamma distribution. It might be useful, e.g., for modeling insurance claims severity, or for any outcome that might be `gamma-distributed <https://en.wikipedia.org/wiki/Gamma_distribution#Occurrence_and_applications>`_.
- ``reg:tweedie``: Tweedie regression with log-link. It might be useful, e.g., for modeling total loss in insurance, or for any outcome that might be `Tweedie-distributed <https://en.wikipedia.org/wiki/Tweedie_distribution#Occurrence_and_applications>`_.
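For quick orientation on the multi-class objectives listed above, here is a minimal Python sketch (illustrative only, not part of this changeset; the data is random): ``multi:softmax`` needs ``num_class`` set explicitly, while ``multi:softprob`` yields per-class probabilities.

.. code-block:: python

  import numpy as np
  import xgboost as xgb

  # Toy data, purely for illustration.
  X = np.random.rand(30, 4)
  y = np.random.randint(0, 3, size=30)
  dtrain = xgb.DMatrix(X, label=y)

  # multi:softmax requires num_class alongside the objective.
  params = {"objective": "multi:softmax", "num_class": 3, "max_depth": 2}
  booster = xgb.train(params, dtrain, num_boost_round=5)
  pred_classes = booster.predict(dtrain)  # predicted class indices

  # multi:softprob outputs ndata * nclass probabilities instead.
  params["objective"] = "multi:softprob"
  booster = xgb.train(params, dtrain, num_boost_round=5)
  pred_proba = booster.predict(dtrain).reshape(len(y), 3)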
@ -395,8 +395,9 @@ Specify the learning task and the corresponding learning objective. The objectiv

* ``eval_metric`` [default according to objective]

  - Evaluation metrics for validation data, a default metric will be assigned according to objective (rmse for regression, and logloss for classification, mean average precision for ranking)
  - User can add multiple evaluation metrics. Python users: remember to pass the metrics in as list of parameters pairs instead of map, so that latter ``eval_metric`` won't override previous one
  - Evaluation metrics for validation data, a default metric will be assigned according to objective (rmse for regression, and logloss for classification, `mean average precision` for ``rank:map``, etc.)
  - User can add multiple evaluation metrics. Python users: remember to pass the metrics in as list of parameters pairs instead of map, so that latter ``eval_metric`` won't override previous ones

  - The choices are listed below:

    - ``rmse``: `root mean square error <http://en.wikipedia.org/wiki/Root_mean_square_error>`_
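To illustrate the note above about multiple metrics in Python (a sketch only, not part of this changeset; the toy data is random), parameters are passed as a list of pairs so that a repeated ``eval_metric`` key is kept rather than overridden:

.. code-block:: python

  import numpy as np
  import xgboost as xgb

  X = np.random.rand(50, 5)
  y = np.random.rand(50)
  dtrain = xgb.DMatrix(X, label=y)

  # A list of (key, value) pairs allows eval_metric to appear twice;
  # a plain dict would keep only the last value.
  params = [
      ("objective", "reg:squarederror"),
      ("eval_metric", "rmse"),
      ("eval_metric", "mae"),
  ]
  booster = xgb.train(params, dtrain, num_boost_round=10,
                      evals=[(dtrain, "train")])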
@ -480,6 +481,36 @@ Parameter for using AFT Survival Loss (``survival:aft``) and Negative Log Likeli

* ``aft_loss_distribution``: Probability Density Function, ``normal``, ``logistic``, or ``extreme``.

.. _ltr-param:

Parameters for learning to rank (``rank:ndcg``, ``rank:map``, ``rank:pairwise``)
================================================================================

These are parameters specific to the learning-to-rank task. See :doc:`Learning to Rank </tutorials/learning_to_rank>` for an in-depth explanation.

* ``lambdarank_pair_method`` [default = ``mean``]

  How to construct pairs for pair-wise learning.

  - ``mean``: Sample ``lambdarank_num_pair_per_sample`` pairs for each document in the query list.
  - ``topk``: Focus on top-``lambdarank_num_pair_per_sample`` documents. Construct :math:`|query|` pairs for each document at the top-``lambdarank_num_pair_per_sample`` ranked by the model.

* ``lambdarank_num_pair_per_sample`` [range = :math:`[1, \infty]`]

  It specifies the number of pairs sampled for each document when pair method is ``mean``, or the truncation level for queries when the pair method is ``topk``. For example, to train with ``ndcg@6``, set ``lambdarank_num_pair_per_sample`` to :math:`6` and ``lambdarank_pair_method`` to ``topk``.

* ``lambdarank_unbiased`` [default = ``false``]

  Specify whether we need to debias the input click data.

* ``lambdarank_bias_norm`` [default = 2.0]

  :math:`L_p` normalization for position debiasing, default is :math:`L_2`. Only relevant when ``lambdarank_unbiased`` is set to true.

* ``ndcg_exp_gain`` [default = ``true``]

  Whether we should use exponential gain function for ``NDCG``. There are two forms of gain function for ``NDCG``, one is using relevance value directly while the other is using :math:`2^{rel} - 1` to emphasize on retrieving relevant documents. When ``ndcg_exp_gain`` is true (the default), relevance degree cannot be greater than 31.
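As a brief usage sketch for the learning-to-rank parameters above (illustrative only, not part of this changeset; the ranking data is random), the scikit-learn style :py:class:`xgboost.XGBRanker` accepts them directly:

.. code-block:: python

  import numpy as np
  import xgboost as xgb

  # Two toy queries with 8 documents each and relevance degrees 0..3.
  rng = np.random.default_rng(0)
  X = rng.random((16, 10))
  y = rng.integers(0, 4, size=16)
  qid = np.repeat([0, 1], 8)  # query ids, sorted

  ranker = xgb.XGBRanker(
      objective="rank:ndcg",
      lambdarank_pair_method="topk",       # build pairs from the top of the list
      lambdarank_num_pair_per_sample=6,    # truncation level, i.e. optimize ndcg@6
      ndcg_exp_gain=True,
      n_estimators=10,
  )
  ranker.fit(X, y, qid=qid)
  scores = ranker.predict(X)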
***********************
Command Line Parameters
***********************
@ -23,7 +23,7 @@ Requirements

Dask can be installed using either pip or conda (see the dask `installation
documentation <https://docs.dask.org/en/latest/install.html>`_ for more information). For
accelerating XGBoost with GPUs, `dask-cuda <https://github.com/rapidsai/dask-cuda>`_ is
accelerating XGBoost with GPUs, `dask-cuda <https://github.com/rapidsai/dask-cuda>`__ is
recommended for creating GPU clusters.
@ -77,7 +77,7 @@ The external memory version takes in the following `URI <https://en.wikipedia.or

.. code-block:: none

  filename#cacheprefix
  filename?format=libsvm#cacheprefix

The ``filename`` is the normal path to LIBSVM format file you want to load in, and
``cacheprefix`` is a path to a cache file that XGBoost will use for caching preprocessed
@ -97,13 +97,13 @@ you have a dataset stored in a file similar to ``agaricus.txt.train`` with LIBSV

.. code-block:: python

  dtrain = DMatrix('../data/agaricus.txt.train#dtrain.cache')
  dtrain = DMatrix('../data/agaricus.txt.train?format=libsvm#dtrain.cache')

XGBoost will first load ``agaricus.txt.train`` in, preprocess it, then write to a new file named
``dtrain.cache`` as an on disk cache for storing preprocessed data in an internal binary format. For
more notes about text input formats, see :doc:`/tutorials/input_format`.

For CLI version, simply add the cache suffix, e.g. ``"../data/agaricus.txt.train#dtrain.cache"``.
For CLI version, simply add the cache suffix, e.g. ``"../data/agaricus.txt.train?format=libsvm#dtrain.cache"``.


**********************************
@ -2,10 +2,15 @@
Text Input Format of DMatrix
############################

.. _basic_input_format:

Here we will briefly describe the text input formats for XGBoost. However, for users with access to a supported language environment like Python or R, it's recommended to use data parsers from that ecosystem instead. For instance, :py:func:`sklearn.datasets.load_svmlight_file`.

******************
Basic Input Format
******************
XGBoost currently supports two text formats for ingesting data: LIBSVM and CSV. The rest of this document will describe the LIBSVM format. (See `this Wikipedia article <https://en.wikipedia.org/wiki/Comma-separated_values>`_ for a description of the CSV format.) Please be careful: XGBoost does **not** understand file extensions, nor does it try to guess the file format, as there is no universal agreement upon the file extension of LIBSVM or CSV. Instead it employs the `URI <https://en.wikipedia.org/wiki/Uniform_Resource_Identifier>`_ format for specifying the precise input file type. For example, if you provide a `csv` file ``./data.train.csv`` as input, XGBoost will blindly use the default LIBSVM parser to digest it and generate a parser error. Instead, users need to provide a URI in the form of ``train.csv?format=csv``. For external memory input, the URI should be of a form similar to ``train.csv?format=csv#dtrain.cache``. See :ref:`python_data_interface` and :doc:`/tutorials/external_memory` also.

XGBoost currently supports two text formats for ingesting data: LIBSVM and CSV. The rest of this document will describe the LIBSVM format. (See `this Wikipedia article <https://en.wikipedia.org/wiki/Comma-separated_values>`_ for a description of the CSV format.) Please be careful: XGBoost does **not** understand file extensions, nor does it try to guess the file format, as there is no universal agreement upon the file extension of LIBSVM or CSV. Instead it employs the `URI <https://en.wikipedia.org/wiki/Uniform_Resource_Identifier>`_ format for specifying the precise input file type. For example, if you provide a `csv` file ``./data.train.csv`` as input, XGBoost will blindly use the default LIBSVM parser to digest it and generate a parser error. Instead, users need to provide a URI in the form of ``train.csv?format=csv`` or ``train.csv?format=libsvm``. For external memory input, the URI should be of a form similar to ``train.csv?format=csv#dtrain.cache``. See :ref:`python_data_interface` and :doc:`/tutorials/external_memory` also.
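A minimal Python sketch of the URI-based loading described above (illustrative only, not part of this changeset; the file names are placeholders):

.. code-block:: python

  import xgboost as xgb

  # Without "?format=csv" XGBoost would fall back to the LIBSVM parser and fail.
  dtrain_csv = xgb.DMatrix("./data.train.csv?format=csv")
  dtrain_svm = xgb.DMatrix("./data.train.txt?format=libsvm")

  # For external memory, append "#<cacheprefix>" after the format parameter.
  dtrain_ext = xgb.DMatrix("./data.train.txt?format=libsvm#dtrain.cache")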
For training or predicting, XGBoost takes an instance file with the format as below:
@ -108,8 +108,8 @@ virtualenv and pip:

  python -m venv xgboost_env
  source xgboost_env/bin/activate
  pip install pyarrow pandas venv-pack xgboost
  # https://rapids.ai/pip.html#install
  pip install cudf-cu11 --extra-index-url=https://pypi.ngc.nvidia.com
  # https://docs.rapids.ai/install#pip-install
  pip install cudf-cu11 --extra-index-url=https://pypi.nvidia.com
  venv-pack -o xgboost_env.tar.gz

With Conda:
@ -241,7 +241,7 @@ additional spark configurations and dependencies:

  --master spark://<master-ip>:7077 \
  --conf spark.executor.resource.gpu.amount=1 \
  --conf spark.task.resource.gpu.amount=1 \
  --packages com.nvidia:rapids-4-spark_2.12:22.08.0 \
  --packages com.nvidia:rapids-4-spark_2.12:23.04.0 \
  --conf spark.plugins=com.nvidia.spark.SQLPlugin \
  --conf spark.sql.execution.arrow.maxRecordsPerBatch=1000000 \
  --archives xgboost_env.tar.gz#environment \
@ -38,7 +38,7 @@ typedef uint64_t bst_ulong; // NOLINT(*)
 */

/**
 * @defgroup Library
 * @defgroup Library Library
 *
 * These functions are used to obtain general information about XGBoost including version,
 * build info and current global configuration.
@ -112,7 +112,7 @@ XGB_DLL int XGBGetGlobalConfig(char const **out_config);
/**@}*/

/**
 * @defgroup DMatrix
 * @defgroup DMatrix DMatrix
 *
 * @brief DMatrix is the baisc data storage for XGBoost used by all XGBoost algorithms
 * including both training, prediction and explanation. There are a few variants of
@ -138,7 +138,11 @@ XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle
/*!
 * \brief load a data matrix
 * \param config JSON encoded parameters for DMatrix construction. Accepted fields are:
 *   - uri: The URI of the input file.
 *   - uri: The URI of the input file. The URI parameter `format` is required when loading text data.
 * \verbatim embed:rst:leading-asterisk
 *    See :doc:`/tutorials/input_format` for more info.
 * \endverbatim
 *   - silent (optional): Whether to print message during loading. Default to true.
 *   - data_split_mode (optional): Whether to split by row or column. In distributed mode, the
 *     file is split accordingly; otherwise this is only an indicator on how the file was split
@ -200,7 +204,7 @@ XGB_DLL int XGDMatrixCreateFromDense(char const *data, char const *config, DMatr
 * \return 0 when success, -1 when failure happens
 */
XGB_DLL int XGDMatrixCreateFromCSC(char const *indptr, char const *indices, char const *data,
                                   bst_ulong nrow, char const *c_json_config, DMatrixHandle *out);
                                   bst_ulong nrow, char const *config, DMatrixHandle *out);

/*!
 * \brief create a matrix content from CSC format
@ -281,7 +285,7 @@ XGB_DLL int XGDMatrixCreateFromCudaArrayInterface(char const *data, char const *
                                                  DMatrixHandle *out);

/**
 * @defgroup Streaming
 * @defgroup Streaming Streaming
 * @ingroup DMatrix
 *
 * @brief Quantile DMatrix and external memory DMatrix can be created from batches of
@ -431,7 +435,7 @@ XGB_EXTERN_C typedef void DataIterResetCallback(DataIterHandle handle); // NOLIN
 * - Step 0: Define a data iterator with 2 methods `reset`, and `next`.
 * - Step 1: Create a DMatrix proxy by \ref XGProxyDMatrixCreate and hold the handle.
 * - Step 2: Pass the iterator handle, proxy handle and 2 methods into
 *   `XGDMatrixCreateFromCallback`, along with other parameters encoded as a JSON object.
 *   \ref XGDMatrixCreateFromCallback, along with other parameters encoded as a JSON object.
 * - Step 3: Call appropriate data setters in `next` functions.
 *
 * \param iter A handle to external data iterator.
@ -830,7 +834,7 @@ XGB_DLL int XGDMatrixGetDataAsCSR(DMatrixHandle const handle, char const *config
/** @} */ // End of DMatrix

/**
 * @defgroup Booster
 * @defgroup Booster Booster
 *
 * @brief The `Booster` class is the gradient-boosted model for XGBoost.
 * @{
@ -953,7 +957,7 @@ XGB_DLL int XGBoosterEvalOneIter(BoosterHandle handle, int iter, DMatrixHandle d
 */

/**
 * @defgroup Prediction
 * @defgroup Prediction Prediction
 * @ingroup Booster
 *
 * @brief These functions are used for running prediction and explanation algorithms.
@ -1155,7 +1159,7 @@ XGB_DLL int XGBoosterPredictFromCudaColumnar(BoosterHandle handle, char const *v


/**
 * @defgroup Serialization
 * @defgroup Serialization Serialization
 * @ingroup Booster
 *
 * @brief There are multiple ways to serialize a Booster object depending on the use case.
@ -1490,7 +1494,7 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, const char *config,
/**@}*/ // End of Booster

/**
 * @defgroup Collective
 * @defgroup Collective Collective
 *
 * @brief Experimental support for exposing internal communicator in XGBoost.
 *
@ -50,7 +50,19 @@ struct Context : public XGBoostParameter<Context> {
|
||||
|
||||
bool IsCPU() const { return gpu_id == kCpuId; }
|
||||
bool IsCUDA() const { return !IsCPU(); }
|
||||
|
||||
CUDAContext const* CUDACtx() const;
|
||||
// Make a CUDA context based on the current context.
|
||||
Context MakeCUDA(std::int32_t device = 0) const {
|
||||
Context ctx = *this;
|
||||
ctx.gpu_id = device;
|
||||
return ctx;
|
||||
}
|
||||
Context MakeCPU() const {
|
||||
Context ctx = *this;
|
||||
ctx.gpu_id = kCpuId;
|
||||
return ctx;
|
||||
}
|
||||
|
||||
// declare parameters
|
||||
DMLC_DECLARE_PARAMETER(Context) {
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright (c) 2015-2022 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2015-2023 by XGBoost Contributors
|
||||
* \file data.h
|
||||
* \brief The input data structure of xgboost.
|
||||
* \author Tianqi Chen
|
||||
@ -196,6 +196,14 @@ class MetaInfo {
|
||||
*/
|
||||
bool IsVerticalFederated() const;
|
||||
|
||||
/*!
|
||||
* \brief A convenient method to check if the MetaInfo should contain labels.
|
||||
*
|
||||
* Normally we assume labels are available everywhere. The only exception is in vertical federated
|
||||
* learning where labels are only available on worker 0.
|
||||
*/
|
||||
bool ShouldHaveLabels() const;
|
||||
|
||||
private:
|
||||
void SetInfoFromHost(Context const& ctx, StringView key, Json arr);
|
||||
void SetInfoFromCUDA(Context const& ctx, StringView key, Json arr);
|
||||
@ -230,44 +238,72 @@ struct Entry {
|
||||
}
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief Parameters for constructing batches.
|
||||
/**
|
||||
* \brief Parameters for constructing histogram index batches.
|
||||
*/
|
||||
struct BatchParam {
|
||||
/*! \brief The GPU device to use. */
|
||||
int gpu_id {-1};
|
||||
/*! \brief Maximum number of bins per feature for histograms. */
|
||||
/**
|
||||
* \brief Maximum number of bins per feature for histograms.
|
||||
*/
|
||||
bst_bin_t max_bin{0};
|
||||
/*! \brief Hessian, used for sketching with future approx implementation. */
|
||||
/**
|
||||
* \brief Hessian, used for sketching with future approx implementation.
|
||||
*/
|
||||
common::Span<float> hess;
|
||||
/*! \brief Whether should DMatrix regenerate the batch. Only used for GHistIndex. */
|
||||
/**
|
||||
* \brief Whether should we force DMatrix to regenerate the batch. Only used for
|
||||
* GHistIndex.
|
||||
*/
|
||||
bool regen{false};
|
||||
/*! \brief Parameter used to generate column matrix for hist. */
|
||||
/**
|
||||
* \brief Forbid regenerating the gradient index. Used for internal validation.
|
||||
*/
|
||||
bool forbid_regen{false};
|
||||
/**
|
||||
* \brief Parameter used to generate column matrix for hist.
|
||||
*/
|
||||
double sparse_thresh{std::numeric_limits<double>::quiet_NaN()};
|
||||
|
||||
/**
|
||||
* \brief Exact or others that don't need histogram.
|
||||
*/
|
||||
BatchParam() = default;
|
||||
// GPU Hist
|
||||
BatchParam(int32_t device, bst_bin_t max_bin)
|
||||
: gpu_id{device}, max_bin{max_bin} {}
|
||||
// Hist
|
||||
/**
|
||||
* \brief Used by the hist tree method.
|
||||
*/
|
||||
BatchParam(bst_bin_t max_bin, double sparse_thresh)
|
||||
: max_bin{max_bin}, sparse_thresh{sparse_thresh} {}
|
||||
// Approx
|
||||
/**
|
||||
* \brief Get batch with sketch weighted by hessian. The batch will be regenerated if
|
||||
* the span is changed, so caller should keep the span for each iteration.
|
||||
* \brief Used by the approx tree method.
|
||||
*
|
||||
* Get batch with sketch weighted by hessian. The batch will be regenerated if the
|
||||
* span is changed, so caller should keep the span for each iteration.
|
||||
*/
|
||||
BatchParam(bst_bin_t max_bin, common::Span<float> hessian, bool regenerate)
|
||||
: max_bin{max_bin}, hess{hessian}, regen{regenerate} {}
|
||||
|
||||
bool operator!=(BatchParam const& other) const {
|
||||
if (hess.empty() && other.hess.empty()) {
|
||||
return gpu_id != other.gpu_id || max_bin != other.max_bin;
|
||||
bool ParamNotEqual(BatchParam const& other) const {
|
||||
// Check non-floating parameters.
|
||||
bool cond = max_bin != other.max_bin;
|
||||
// Check sparse thresh.
|
||||
bool l_nan = std::isnan(sparse_thresh);
|
||||
bool r_nan = std::isnan(other.sparse_thresh);
|
||||
bool st_chg = (l_nan != r_nan) || (!l_nan && !r_nan && (sparse_thresh != other.sparse_thresh));
|
||||
cond |= st_chg;
|
||||
|
||||
return cond;
|
||||
}
|
||||
return gpu_id != other.gpu_id || max_bin != other.max_bin || hess.data() != other.hess.data();
|
||||
}
|
||||
bool operator==(BatchParam const& other) const {
|
||||
return !(*this != other);
|
||||
bool Initialized() const { return max_bin != 0; }
|
||||
/**
|
||||
* \brief Make a copy of self for DMatrix to describe how its existing index was generated.
|
||||
*/
|
||||
BatchParam MakeCache() const {
|
||||
auto p = *this;
|
||||
// These parameters have nothing to do with how the gradient index was generated in the
|
||||
// first place.
|
||||
p.regen = false;
|
||||
p.forbid_regen = false;
|
||||
return p;
|
||||
}
|
||||
};
|
||||
|
||||
@ -427,7 +463,7 @@ class EllpackPage {
|
||||
* This is used in the in-memory case. The ELLPACK page is constructed from an existing DMatrix
|
||||
* in CSR format.
|
||||
*/
|
||||
explicit EllpackPage(DMatrix* dmat, const BatchParam& param);
|
||||
explicit EllpackPage(Context const* ctx, DMatrix* dmat, const BatchParam& param);
|
||||
|
||||
/*! \brief Destructor. */
|
||||
~EllpackPage();
|
||||
@ -543,7 +579,9 @@ class DMatrix {
|
||||
template <typename T>
|
||||
BatchSet<T> GetBatches();
|
||||
template <typename T>
|
||||
BatchSet<T> GetBatches(const BatchParam& param);
|
||||
BatchSet<T> GetBatches(Context const* ctx);
|
||||
template <typename T>
|
||||
BatchSet<T> GetBatches(Context const* ctx, const BatchParam& param);
|
||||
template <typename T>
|
||||
bool PageExists() const;
|
||||
|
||||
@ -558,21 +596,17 @@ class DMatrix {
|
||||
return Info().num_nonzero_ == Info().num_row_ * Info().num_col_;
|
||||
}
|
||||
|
||||
/*!
|
||||
/**
|
||||
* \brief Load DMatrix from URI.
|
||||
*
|
||||
* \param uri The URI of input.
|
||||
* \param silent Whether print information during loading.
|
||||
* \param data_split_mode In distributed mode, split the input according this mode; otherwise,
|
||||
* it's just an indicator on how the input was split beforehand.
|
||||
* \param file_format The format type of the file, used for dmlc::Parser::Create.
|
||||
* By default "auto" will be able to load in both local binary file.
|
||||
* \param page_size Page size for external memory.
|
||||
* \return The created DMatrix.
|
||||
*/
|
||||
static DMatrix* Load(const std::string& uri,
|
||||
bool silent = true,
|
||||
DataSplitMode data_split_mode = DataSplitMode::kRow,
|
||||
const std::string& file_format = "auto");
|
||||
static DMatrix* Load(const std::string& uri, bool silent = true,
|
||||
DataSplitMode data_split_mode = DataSplitMode::kRow);
|
||||
|
||||
/**
|
||||
* \brief Creates a new DMatrix from an external data adapter.
|
||||
@ -654,11 +688,12 @@ class DMatrix {
|
||||
|
||||
protected:
|
||||
virtual BatchSet<SparsePage> GetRowBatches() = 0;
|
||||
virtual BatchSet<CSCPage> GetColumnBatches() = 0;
|
||||
virtual BatchSet<SortedCSCPage> GetSortedColumnBatches() = 0;
|
||||
virtual BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) = 0;
|
||||
virtual BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam& param) = 0;
|
||||
virtual BatchSet<ExtSparsePage> GetExtBatches(BatchParam const& param) = 0;
|
||||
virtual BatchSet<CSCPage> GetColumnBatches(Context const* ctx) = 0;
|
||||
virtual BatchSet<SortedCSCPage> GetSortedColumnBatches(Context const* ctx) = 0;
|
||||
virtual BatchSet<EllpackPage> GetEllpackBatches(Context const* ctx, BatchParam const& param) = 0;
|
||||
virtual BatchSet<GHistIndexMatrix> GetGradientIndex(Context const* ctx,
|
||||
BatchParam const& param) = 0;
|
||||
virtual BatchSet<ExtSparsePage> GetExtBatches(Context const* ctx, BatchParam const& param) = 0;
|
||||
|
||||
virtual bool EllpackExists() const = 0;
|
||||
virtual bool GHistIndexExists() const = 0;
|
||||
@ -686,28 +721,33 @@ inline bool DMatrix::PageExists<SparsePage>() const {
|
||||
}
|
||||
|
||||
template <>
|
||||
inline BatchSet<CSCPage> DMatrix::GetBatches() {
|
||||
return GetColumnBatches();
|
||||
inline BatchSet<SparsePage> DMatrix::GetBatches(Context const*) {
|
||||
return GetRowBatches();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline BatchSet<SortedCSCPage> DMatrix::GetBatches() {
|
||||
return GetSortedColumnBatches();
|
||||
inline BatchSet<CSCPage> DMatrix::GetBatches(Context const* ctx) {
|
||||
return GetColumnBatches(ctx);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline BatchSet<EllpackPage> DMatrix::GetBatches(const BatchParam& param) {
|
||||
return GetEllpackBatches(param);
|
||||
inline BatchSet<SortedCSCPage> DMatrix::GetBatches(Context const* ctx) {
|
||||
return GetSortedColumnBatches(ctx);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline BatchSet<GHistIndexMatrix> DMatrix::GetBatches(const BatchParam& param) {
|
||||
return GetGradientIndex(param);
|
||||
inline BatchSet<EllpackPage> DMatrix::GetBatches(Context const* ctx, BatchParam const& param) {
|
||||
return GetEllpackBatches(ctx, param);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline BatchSet<ExtSparsePage> DMatrix::GetBatches() {
|
||||
return GetExtBatches(BatchParam{});
|
||||
inline BatchSet<GHistIndexMatrix> DMatrix::GetBatches(Context const* ctx, BatchParam const& param) {
|
||||
return GetGradientIndex(ctx, param);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline BatchSet<ExtSparsePage> DMatrix::GetBatches(Context const* ctx, BatchParam const& param) {
|
||||
return GetExtBatches(ctx, param);
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
|
||||
@ -567,7 +567,7 @@ class RegTree : public Model {
   * \brief drop the trace after fill, must be called after fill.
   * \param inst The sparse instance to drop.
   */
  void Drop(const SparsePage::Inst& inst);
  void Drop();
  /*!
   * \brief returns the size of the feature vector
   * \return the size of the feature vector
@ -807,13 +807,10 @@ inline void RegTree::FVec::Fill(const SparsePage::Inst& inst) {
  has_missing_ = data_.size() != feature_count;
}

inline void RegTree::FVec::Drop(const SparsePage::Inst& inst) {
  for (auto const& entry : inst) {
    if (entry.index >= data_.size()) {
      continue;
    }
    data_[entry.index].flag = -1;
  }
inline void RegTree::FVec::Drop() {
  Entry e{};
  e.flag = -1;
  std::fill_n(data_.data(), data_.size(), e);
  has_missing_ = true;
}
@ -33,16 +33,16 @@
|
||||
<project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
|
||||
<maven.compiler.source>1.8</maven.compiler.source>
|
||||
<maven.compiler.target>1.8</maven.compiler.target>
|
||||
<flink.version>1.8.3</flink.version>
|
||||
<spark.version>3.1.1</spark.version>
|
||||
<scala.version>2.12.8</scala.version>
|
||||
<flink.version>1.17.0</flink.version>
|
||||
<spark.version>3.4.0</spark.version>
|
||||
<scala.version>2.12.17</scala.version>
|
||||
<scala.binary.version>2.12</scala.binary.version>
|
||||
<hadoop.version>3.3.5</hadoop.version>
|
||||
<maven.wagon.http.retryHandler.count>5</maven.wagon.http.retryHandler.count>
|
||||
<log.capi.invocation>OFF</log.capi.invocation>
|
||||
<use.cuda>OFF</use.cuda>
|
||||
<cudf.version>22.12.0</cudf.version>
|
||||
<spark.rapids.version>22.12.0</spark.rapids.version>
|
||||
<cudf.version>23.04.0</cudf.version>
|
||||
<spark.rapids.version>23.04.0</spark.rapids.version>
|
||||
<cudf.classifier>cuda11</cudf.classifier>
|
||||
</properties>
|
||||
<repositories>
|
||||
@ -374,7 +374,7 @@
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-checkstyle-plugin</artifactId>
|
||||
<version>3.2.1</version>
|
||||
<version>3.2.2</version>
|
||||
<configuration>
|
||||
<configLocation>checkstyle.xml</configLocation>
|
||||
<failOnViolation>true</failOnViolation>
|
||||
@ -450,7 +450,7 @@
|
||||
<plugins>
|
||||
<plugin>
|
||||
<artifactId>maven-project-info-reports-plugin</artifactId>
|
||||
<version>3.4.2</version>
|
||||
<version>3.4.3</version>
|
||||
</plugin>
|
||||
<plugin>
|
||||
<groupId>net.alchim31.maven</groupId>
|
||||
@ -469,7 +469,7 @@
|
||||
<dependency>
|
||||
<groupId>com.esotericsoftware</groupId>
|
||||
<artifactId>kryo</artifactId>
|
||||
<version>5.4.0</version>
|
||||
<version>5.5.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.scala-lang</groupId>
|
||||
@ -477,11 +477,6 @@
|
||||
<version>${scala.version}</version>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.scala-lang</groupId>
|
||||
<artifactId>scala-reflect</artifactId>
|
||||
<version>${scala.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.scala-lang</groupId>
|
||||
<artifactId>scala-library</artifactId>
|
||||
@ -495,13 +490,13 @@
|
||||
<dependency>
|
||||
<groupId>org.scalatest</groupId>
|
||||
<artifactId>scalatest_${scala.binary.version}</artifactId>
|
||||
<version>3.0.8</version>
|
||||
<version>3.2.15</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.scalactic</groupId>
|
||||
<artifactId>scalactic_${scala.binary.version}</artifactId>
|
||||
<version>3.0.8</version>
|
||||
<version>3.2.15</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
@ -26,7 +26,7 @@
|
||||
<dependency>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost4j-spark_${scala.binary.version}</artifactId>
|
||||
<version>2.0.0-SNAPSHOT</version>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.spark</groupId>
|
||||
@ -37,12 +37,7 @@
|
||||
<dependency>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost4j-flink_${scala.binary.version}</artifactId>
|
||||
<version>2.0.0-SNAPSHOT</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.commons</groupId>
|
||||
<artifactId>commons-lang3</artifactId>
|
||||
<version>3.12.0</version>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</project>
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014-2021 by Contributors
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -62,8 +62,8 @@ public class BasicWalkThrough {
|
||||
|
||||
public static void main(String[] args) throws IOException, XGBoostError {
|
||||
// load file from text file, also binary buffer generated by xgboost4j
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");
|
||||
|
||||
HashMap<String, Object> params = new HashMap<String, Object>();
|
||||
params.put("eta", 1.0);
|
||||
@ -112,7 +112,8 @@ public class BasicWalkThrough {
|
||||
|
||||
System.out.println("start build dmatrix from csr sparse data ...");
|
||||
//build dmatrix from CSR Sparse Matrix
|
||||
DataLoader.CSRSparseData spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train");
|
||||
DataLoader.CSRSparseData spData =
|
||||
DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train?format=libsvm");
|
||||
|
||||
DMatrix trainMat2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data,
|
||||
DMatrix.SparseType.CSR, 127);
|
||||
|
||||
@ -32,8 +32,8 @@ public class BoostFromPrediction {
|
||||
System.out.println("start running example to start from a initial prediction");
|
||||
|
||||
// load file from text file, also binary buffer generated by xgboost4j
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");
|
||||
|
||||
//specify parameters
|
||||
HashMap<String, Object> params = new HashMap<String, Object>();
|
||||
|
||||
@ -30,7 +30,7 @@ import ml.dmlc.xgboost4j.java.XGBoostError;
|
||||
public class CrossValidation {
|
||||
public static void main(String[] args) throws IOException, XGBoostError {
|
||||
//load train mat
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
|
||||
|
||||
//set params
|
||||
HashMap<String, Object> params = new HashMap<String, Object>();
|
||||
|
||||
@ -139,9 +139,9 @@ public class CustomObjective {
|
||||
|
||||
public static void main(String[] args) throws XGBoostError {
|
||||
//load train mat (svmlight format)
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
|
||||
//load valid mat (svmlight format)
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");
|
||||
|
||||
HashMap<String, Object> params = new HashMap<String, Object>();
|
||||
params.put("eta", 1.0);
|
||||
|
||||
@ -29,9 +29,9 @@ import ml.dmlc.xgboost4j.java.example.util.DataLoader;
|
||||
public class EarlyStopping {
|
||||
public static void main(String[] args) throws IOException, XGBoostError {
|
||||
DataLoader.CSRSparseData trainCSR =
|
||||
DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train");
|
||||
DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train?format=libsvm");
|
||||
DataLoader.CSRSparseData testCSR =
|
||||
DataLoader.loadSVMFile("../../demo/data/agaricus.txt.test");
|
||||
DataLoader.loadSVMFile("../../demo/data/agaricus.txt.test?format=libsvm");
|
||||
|
||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||
{
|
||||
|
||||
@ -32,8 +32,8 @@ public class ExternalMemory {
|
||||
//this is the only difference, add a # followed by a cache prefix name
|
||||
//several cache file with the prefix will be generated
|
||||
//currently only support convert from libsvm file
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train#dtrain.cache");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test#dtest.cache");
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm#dtrain.cache");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm#dtest.cache");
|
||||
|
||||
//specify parameters
|
||||
HashMap<String, Object> params = new HashMap<String, Object>();
|
||||
|
||||
@ -32,8 +32,8 @@ import ml.dmlc.xgboost4j.java.example.util.CustomEval;
|
||||
public class GeneralizedLinearModel {
|
||||
public static void main(String[] args) throws XGBoostError {
|
||||
// load file from text file, also binary buffer generated by xgboost4j
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");
|
||||
|
||||
//specify parameters
|
||||
//change booster to gblinear, so that we are fitting a linear model
|
||||
|
||||
@ -31,8 +31,8 @@ import ml.dmlc.xgboost4j.java.example.util.CustomEval;
|
||||
public class PredictFirstNtree {
|
||||
public static void main(String[] args) throws XGBoostError {
|
||||
// load file from text file, also binary buffer generated by xgboost4j
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");
|
||||
|
||||
//specify parameters
|
||||
HashMap<String, Object> params = new HashMap<String, Object>();
|
||||
|
||||
@ -31,8 +31,8 @@ import ml.dmlc.xgboost4j.java.XGBoostError;
|
||||
public class PredictLeafIndices {
|
||||
public static void main(String[] args) throws XGBoostError {
|
||||
// load file from text file, also binary buffer generated by xgboost4j
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm");
|
||||
|
||||
//specify parameters
|
||||
HashMap<String, Object> params = new HashMap<String, Object>();
|
||||
|
||||
@ -0,0 +1,107 @@
|
||||
/*
|
||||
Copyright (c) 2014-2021 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.java.example.flink;
|
||||
|
||||
import java.nio.file.Path;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
|
||||
import org.apache.flink.api.common.typeinfo.TypeHint;
|
||||
import org.apache.flink.api.common.typeinfo.TypeInformation;
|
||||
import org.apache.flink.api.java.DataSet;
|
||||
import org.apache.flink.api.java.ExecutionEnvironment;
|
||||
import org.apache.flink.api.java.operators.MapOperator;
|
||||
import org.apache.flink.api.java.tuple.Tuple13;
|
||||
import org.apache.flink.api.java.tuple.Tuple2;
|
||||
import org.apache.flink.api.java.utils.DataSetUtils;
|
||||
import org.apache.flink.ml.linalg.DenseVector;
|
||||
import org.apache.flink.ml.linalg.Vector;
|
||||
import org.apache.flink.ml.linalg.Vectors;
|
||||
|
||||
import ml.dmlc.xgboost4j.java.flink.XGBoost;
|
||||
import ml.dmlc.xgboost4j.java.flink.XGBoostModel;
|
||||
|
||||
|
||||
public class DistTrainWithFlinkExample {
|
||||
|
||||
static Tuple2<XGBoostModel, DataSet<Float[]>> runPrediction(
|
||||
ExecutionEnvironment env,
|
||||
java.nio.file.Path trainPath,
|
||||
int percentage) throws Exception {
|
||||
// reading data
|
||||
final DataSet<Tuple2<Long, Tuple2<Vector, Double>>> data =
|
||||
DataSetUtils.zipWithIndex(parseCsv(env, trainPath));
|
||||
final long size = data.count();
|
||||
final long trainCount = Math.round(size * 0.01 * percentage);
|
||||
final DataSet<Tuple2<Vector, Double>> trainData =
|
||||
data
|
||||
.filter(item -> item.f0 < trainCount)
|
||||
.map(t -> t.f1)
|
||||
.returns(TypeInformation.of(new TypeHint<Tuple2<Vector, Double>>(){}));
|
||||
final DataSet<Vector> testData =
|
||||
data
|
||||
.filter(tuple -> tuple.f0 >= trainCount)
|
||||
.map(t -> t.f1.f0)
|
||||
.returns(TypeInformation.of(new TypeHint<Vector>(){}));
|
||||
|
||||
// define parameters
|
||||
HashMap<String, Object> paramMap = new HashMap<String, Object>(3);
|
||||
paramMap.put("eta", 0.1);
|
||||
paramMap.put("max_depth", 2);
|
||||
paramMap.put("objective", "binary:logistic");
|
||||
|
||||
// number of iterations
|
||||
final int round = 2;
|
||||
// train the model
|
||||
XGBoostModel model = XGBoost.train(trainData, paramMap, round);
|
||||
DataSet<Float[]> predTest = model.predict(testData);
|
||||
return new Tuple2<XGBoostModel, DataSet<Float[]>>(model, predTest);
|
||||
}
|
||||
|
||||
private static MapOperator<Tuple13<Double, String, Double, Double, Double, Integer, Integer,
|
||||
Integer, Integer, Integer, Integer, Integer, Integer>,
|
||||
Tuple2<Vector, Double>> parseCsv(ExecutionEnvironment env, Path trainPath) {
|
||||
return env.readCsvFile(trainPath.toString())
|
||||
.ignoreFirstLine()
|
||||
.types(Double.class, String.class, Double.class, Double.class, Double.class,
|
||||
Integer.class, Integer.class, Integer.class, Integer.class, Integer.class,
|
||||
Integer.class, Integer.class, Integer.class)
|
||||
.map(DistTrainWithFlinkExample::mapFunction);
|
||||
}
|
||||
|
||||
private static Tuple2<Vector, Double> mapFunction(Tuple13<Double, String, Double, Double, Double,
|
||||
Integer, Integer, Integer, Integer, Integer, Integer, Integer, Integer> tuple) {
|
||||
final DenseVector dense = Vectors.dense(tuple.f2, tuple.f3, tuple.f4, tuple.f5, tuple.f6,
|
||||
tuple.f7, tuple.f8, tuple.f9, tuple.f10, tuple.f11, tuple.f12);
|
||||
if (tuple.f1.contains("inf")) {
|
||||
return new Tuple2<Vector, Double>(dense, 1.0);
|
||||
} else {
|
||||
return new Tuple2<Vector, Double>(dense, 0.0);
|
||||
}
|
||||
}
|
||||
|
||||
public static void main(String[] args) throws Exception {
|
||||
final java.nio.file.Path parentPath = java.nio.file.Paths.get(Arrays.stream(args)
|
||||
.findFirst().orElse("."));
|
||||
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
|
||||
Tuple2<XGBoostModel, DataSet<Float[]>> tuple2 = runPrediction(
|
||||
env, parentPath.resolve("veterans_lung_cancer.csv"), 70
|
||||
);
|
||||
List<Float[]> list = tuple2.f1.collect();
|
||||
System.out.println(list.size());
|
||||
}
|
||||
}
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -36,8 +36,8 @@ object BasicWalkThrough {
|
||||
}
|
||||
|
||||
def main(args: Array[String]): Unit = {
|
||||
val trainMax = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val testMax = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||
val trainMax = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
|
||||
val testMax = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
|
||||
|
||||
val params = new mutable.HashMap[String, Any]()
|
||||
params += "eta" -> 1.0
|
||||
@ -76,7 +76,7 @@ object BasicWalkThrough {
|
||||
|
||||
// build dmatrix from CSR Sparse Matrix
|
||||
println("start build dmatrix from csr sparse data ...")
|
||||
val spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train")
|
||||
val spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train?format=libsvm")
|
||||
val trainMax2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data,
|
||||
JDMatrix.SparseType.CSR)
|
||||
trainMax2.setLabel(spData.labels)
|
||||
|
||||
@ -24,8 +24,8 @@ object BoostFromPrediction {
|
||||
def main(args: Array[String]): Unit = {
|
||||
println("start running example to start from a initial prediction")
|
||||
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
|
||||
|
||||
val params = new mutable.HashMap[String, Any]()
|
||||
params += "eta" -> 1.0
|
||||
|
||||
@ -21,7 +21,7 @@ import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
|
||||
|
||||
object CrossValidation {
|
||||
def main(args: Array[String]): Unit = {
|
||||
val trainMat: DMatrix = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val trainMat: DMatrix = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
|
||||
|
||||
// set params
|
||||
val params = new mutable.HashMap[String, Any]
|
||||
|
||||
@ -138,8 +138,8 @@ object CustomObjective {
|
||||
}
|
||||
|
||||
def main(args: Array[String]): Unit = {
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
|
||||
val params = new mutable.HashMap[String, Any]()
|
||||
params += "eta" -> 1.0
|
||||
params += "max_depth" -> 2
|
||||
|
||||
@ -25,8 +25,8 @@ object ExternalMemory {
|
||||
// this is the only difference, add a # followed by a cache prefix name
|
||||
// several cache file with the prefix will be generated
|
||||
// currently only support convert from libsvm file
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train#dtrain.cache")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test#dtest.cache")
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm#dtrain.cache")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm#dtest.cache")
|
||||
|
||||
val params = new mutable.HashMap[String, Any]()
|
||||
params += "eta" -> 1.0
|
||||
|
||||
@ -27,8 +27,8 @@ import ml.dmlc.xgboost4j.scala.example.util.CustomEval
|
||||
*/
|
||||
object GeneralizedLinearModel {
|
||||
def main(args: Array[String]): Unit = {
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
|
||||
|
||||
// specify parameters
|
||||
// change booster to gblinear, so that we are fitting a linear model
|
||||
|
||||
@ -23,8 +23,8 @@ import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
|
||||
object PredictFirstNTree {
|
||||
|
||||
def main(args: Array[String]): Unit = {
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
|
||||
|
||||
val params = new mutable.HashMap[String, Any]()
|
||||
params += "eta" -> 1.0
|
||||
|
||||
@ -25,8 +25,8 @@ import ml.dmlc.xgboost4j.scala.{XGBoost, DMatrix}
|
||||
object PredictLeafIndices {
|
||||
|
||||
def main(args: Array[String]): Unit = {
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train?format=libsvm")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test?format=libsvm")
|
||||
|
||||
val params = new mutable.HashMap[String, Any]()
|
||||
params += "eta" -> 1.0
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014 - 2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -15,27 +15,84 @@
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.scala.example.flink
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.flink.XGBoost
|
||||
import org.apache.flink.api.scala.{ExecutionEnvironment, _}
|
||||
import org.apache.flink.ml.MLUtils
|
||||
import java.lang.{Double => JDouble, Long => JLong}
|
||||
import java.nio.file.{Path, Paths}
|
||||
import org.apache.flink.api.java.tuple.{Tuple13, Tuple2}
|
||||
import org.apache.flink.api.java.{DataSet, ExecutionEnvironment}
|
||||
import org.apache.flink.ml.linalg.{Vector, Vectors}
|
||||
import ml.dmlc.xgboost4j.java.flink.{XGBoost, XGBoostModel}
|
||||
import org.apache.flink.api.common.typeinfo.{TypeHint, TypeInformation}
|
||||
import org.apache.flink.api.java.utils.DataSetUtils
|
||||
|
||||
|
||||
object DistTrainWithFlink {
|
||||
def main(args: Array[String]) {
|
||||
val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
|
||||
// read trainining data
|
||||
val trainData =
|
||||
MLUtils.readLibSVM(env, "/path/to/data/agaricus.txt.train")
|
||||
val testData = MLUtils.readLibSVM(env, "/path/to/data/agaricus.txt.test")
|
||||
// define parameters
|
||||
val paramMap = List(
|
||||
"eta" -> 0.1,
|
||||
"max_depth" -> 2,
|
||||
"objective" -> "binary:logistic").toMap
|
||||
import scala.jdk.CollectionConverters._
|
||||
private val rowTypeHint = TypeInformation.of(new TypeHint[Tuple2[Vector, JDouble]]{})
|
||||
private val testDataTypeHint = TypeInformation.of(classOf[Vector])
|
||||
|
||||
private[flink] def parseCsv(trainPath: Path)(implicit env: ExecutionEnvironment):
|
||||
DataSet[Tuple2[JLong, Tuple2[Vector, JDouble]]] = {
|
||||
DataSetUtils.zipWithIndex(
|
||||
env
|
||||
.readCsvFile(trainPath.toString)
|
||||
.ignoreFirstLine
|
||||
.types(
|
||||
classOf[Double], classOf[String], classOf[Double], classOf[Double], classOf[Double],
|
||||
classOf[Integer], classOf[Integer], classOf[Integer], classOf[Integer],
|
||||
classOf[Integer], classOf[Integer], classOf[Integer], classOf[Integer]
|
||||
)
|
||||
.map((row: Tuple13[Double, String, Double, Double, Double,
|
||||
Integer, Integer, Integer, Integer, Integer, Integer, Integer, Integer]) => {
|
||||
val dense = Vectors.dense(row.f2, row.f3, row.f4,
|
||||
row.f5.toDouble, row.f6.toDouble, row.f7.toDouble, row.f8.toDouble,
|
||||
row.f9.toDouble, row.f10.toDouble, row.f11.toDouble, row.f12.toDouble)
|
||||
val label = if (row.f1.contains("inf")) {
|
||||
JDouble.valueOf(1.0)
|
||||
} else {
|
||||
JDouble.valueOf(0.0)
|
||||
}
|
||||
new Tuple2[Vector, JDouble](dense, label)
|
||||
})
|
||||
.returns(rowTypeHint)
|
||||
)
|
||||
}
|
||||
|
||||
private[flink] def runPrediction(trainPath: Path, percentage: Int)
|
||||
(implicit env: ExecutionEnvironment):
|
||||
(XGBoostModel, DataSet[Array[Float]]) = {
|
||||
// read training data
|
||||
val data: DataSet[Tuple2[JLong, Tuple2[Vector, JDouble]]] = parseCsv(trainPath)
|
||||
val trainSize = Math.round(0.01 * percentage * data.count())
|
||||
val trainData: DataSet[Tuple2[Vector, JDouble]] =
|
||||
data.filter(d => d.f0 < trainSize).map(_.f1).returns(rowTypeHint)
|
||||
|
||||
|
||||
val testData: DataSet[Vector] =
|
||||
data
|
||||
.filter(d => d.f0 >= trainSize)
|
||||
.map(_.f1.f0)
|
||||
.returns(testDataTypeHint)
|
||||
|
||||
val paramMap = mapAsJavaMap(Map(
|
||||
("eta", "0.1".asInstanceOf[AnyRef]),
|
||||
("max_depth", "2"),
|
||||
("objective", "binary:logistic"),
|
||||
("verbosity", "1")
|
||||
))
|
||||
|
||||
// number of iterations
|
||||
val round = 2
|
||||
// train the model
|
||||
val model = XGBoost.train(trainData, paramMap, round)
|
||||
val predTest = model.predict(testData.map{x => x.vector})
|
||||
model.saveModelAsHadoopFile("file:///path/to/xgboost.model")
|
||||
val result = model.predict(testData).map(prediction => prediction.map(Float.unbox))
|
||||
(model, result)
|
||||
}
|
||||
|
||||
def main(args: Array[String]): Unit = {
|
||||
implicit val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
|
||||
val parentPath = Paths.get(args.headOption.getOrElse("."))
|
||||
val (_, predTest) = runPrediction(parentPath.resolve("veterans_lung_cancer.csv"), 70)
|
||||
val list = predTest.collect().asScala
|
||||
println(list.length)
|
||||
}
|
||||
}
|
||||
|
||||
@ -0,0 +1,36 @@
|
||||
/*
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.java.example.flink
|
||||
|
||||
import org.apache.flink.api.java.ExecutionEnvironment
|
||||
import org.scalatest.Inspectors._
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
import org.scalatest.matchers.should.Matchers._
|
||||
|
||||
import java.nio.file.Paths
|
||||
|
||||
class DistTrainWithFlinkExampleTest extends AnyFunSuite {
|
||||
private val parentPath = Paths.get("../../").resolve("demo").resolve("data")
|
||||
private val data = parentPath.resolve("veterans_lung_cancer.csv")
|
||||
|
||||
test("Smoke test for scala flink example") {
|
||||
val env = ExecutionEnvironment.createLocalEnvironment(1)
|
||||
val tuple2 = DistTrainWithFlinkExample.runPrediction(env, data, 70)
|
||||
val results = tuple2.f1.collect()
|
||||
results should have size 41
|
||||
forEvery(results)(item => item should have size 1)
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,37 @@
|
||||
/*
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package ml.dmlc.xgboost4j.scala.example.flink
|
||||
|
||||
import org.apache.flink.api.java.ExecutionEnvironment
|
||||
import org.scalatest.Inspectors._
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
import org.scalatest.matchers.should.Matchers._
|
||||
|
||||
import java.nio.file.Paths
|
||||
import scala.jdk.CollectionConverters._
|
||||
|
||||
class DistTrainWithFlinkSuite extends AnyFunSuite {
|
||||
private val parentPath = Paths.get("../../").resolve("demo").resolve("data")
|
||||
private val data = parentPath.resolve("veterans_lung_cancer.csv")
|
||||
|
||||
test("Smoke test for scala flink example") {
|
||||
implicit val env: ExecutionEnvironment = ExecutionEnvironment.createLocalEnvironment(1)
|
||||
val (_, result) = DistTrainWithFlink.runPrediction(data, 70)
|
||||
val results = result.collect().asScala
|
||||
results should have size 41
|
||||
forEvery(results)(item => item should have size 1)
|
||||
}
|
||||
}
|
||||
@ -8,8 +8,11 @@
|
||||
<artifactId>xgboost-jvm_2.12</artifactId>
|
||||
<version>2.0.0-SNAPSHOT</version>
|
||||
</parent>
|
||||
<artifactId>xgboost4j-flink_2.12</artifactId>
|
||||
<artifactId>xgboost4j-flink_${scala.binary.version}</artifactId>
|
||||
<version>2.0.0-SNAPSHOT</version>
|
||||
<properties>
|
||||
<flink-ml.version>2.2.0</flink-ml.version>
|
||||
</properties>
|
||||
<build>
|
||||
<plugins>
|
||||
<plugin>
|
||||
@ -26,32 +29,22 @@
|
||||
<dependency>
|
||||
<groupId>ml.dmlc</groupId>
|
||||
<artifactId>xgboost4j_${scala.binary.version}</artifactId>
|
||||
<version>2.0.0-SNAPSHOT</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.commons</groupId>
|
||||
<artifactId>commons-lang3</artifactId>
|
||||
<version>3.12.0</version>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.flink</groupId>
|
||||
<artifactId>flink-scala_${scala.binary.version}</artifactId>
|
||||
<artifactId>flink-clients</artifactId>
|
||||
<version>${flink.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.flink</groupId>
|
||||
<artifactId>flink-clients_${scala.binary.version}</artifactId>
|
||||
<version>${flink.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.flink</groupId>
|
||||
<artifactId>flink-ml_${scala.binary.version}</artifactId>
|
||||
<version>${flink.version}</version>
|
||||
<artifactId>flink-ml-servable-core</artifactId>
|
||||
<version>${flink-ml.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.hadoop</groupId>
|
||||
<artifactId>hadoop-common</artifactId>
|
||||
<version>3.3.5</version>
|
||||
<version>${hadoop.version}</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
|
||||
@ -0,0 +1,187 @@
|
||||
/*
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.java.flink;
|
||||
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Iterator;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.StreamSupport;
|
||||
|
||||
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
|
||||
import org.apache.flink.api.java.DataSet;
|
||||
import org.apache.flink.api.java.tuple.Tuple2;
|
||||
import org.apache.flink.ml.linalg.SparseVector;
|
||||
import org.apache.flink.ml.linalg.Vector;
|
||||
import org.apache.flink.util.Collector;
|
||||
import org.apache.hadoop.conf.Configuration;
|
||||
import org.apache.hadoop.fs.FSDataInputStream;
|
||||
import org.apache.hadoop.fs.FileSystem;
|
||||
import org.apache.hadoop.fs.Path;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import ml.dmlc.xgboost4j.LabeledPoint;
|
||||
import ml.dmlc.xgboost4j.java.Booster;
|
||||
import ml.dmlc.xgboost4j.java.Communicator;
|
||||
import ml.dmlc.xgboost4j.java.DMatrix;
|
||||
import ml.dmlc.xgboost4j.java.RabitTracker;
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError;
|
||||
|
||||
|
||||
public class XGBoost {
|
||||
private static final Logger logger = LoggerFactory.getLogger(XGBoost.class);
|
||||
|
||||
private static class MapFunction
|
||||
extends RichMapPartitionFunction<Tuple2<Vector, Double>, XGBoostModel> {
|
||||
|
||||
private final Map<String, Object> params;
|
||||
private final int round;
|
||||
private final Map<String, String> workerEnvs;
|
||||
|
||||
public MapFunction(Map<String, Object> params, int round, Map<String, String> workerEnvs) {
|
||||
this.params = params;
|
||||
this.round = round;
|
||||
this.workerEnvs = workerEnvs;
|
||||
}
|
||||
|
||||
public void mapPartition(java.lang.Iterable<Tuple2<Vector, Double>> it,
|
||||
Collector<XGBoostModel> collector) throws XGBoostError {
|
||||
workerEnvs.put(
|
||||
"DMLC_TASK_ID",
|
||||
String.valueOf(this.getRuntimeContext().getIndexOfThisSubtask())
|
||||
);
|
||||
|
||||
if (logger.isInfoEnabled()) {
|
||||
logger.info("start with env: {}", workerEnvs.entrySet().stream()
|
||||
.map(e -> String.format("\"%s\": \"%s\"", e.getKey(), e.getValue()))
|
||||
.collect(Collectors.joining(", "))
|
||||
);
|
||||
}
|
||||
|
||||
final Iterator<LabeledPoint> dataIter =
|
||||
StreamSupport
|
||||
.stream(it.spliterator(), false)
|
||||
.map(VectorToPointMapper.INSTANCE)
|
||||
.iterator();
|
||||
|
||||
if (dataIter.hasNext()) {
|
||||
final DMatrix trainMat = new DMatrix(dataIter, null);
|
||||
int numEarlyStoppingRounds =
|
||||
Optional.ofNullable(params.get("numEarlyStoppingRounds"))
|
||||
.map(x -> Integer.parseInt(x.toString()))
|
||||
.orElse(0);
|
||||
|
||||
final Booster booster = trainBooster(trainMat, numEarlyStoppingRounds);
|
||||
collector.collect(new XGBoostModel(booster));
|
||||
} else {
|
||||
logger.warn("Nothing to train with.");
|
||||
}
|
||||
}
|
||||
|
||||
private Booster trainBooster(DMatrix trainMat,
|
||||
int numEarlyStoppingRounds) throws XGBoostError {
|
||||
Booster booster;
|
||||
final Map<String, DMatrix> watches =
|
||||
new HashMap<String, DMatrix>() {{ put("train", trainMat); }};
|
||||
try {
|
||||
Communicator.init(workerEnvs);
|
||||
booster = ml.dmlc.xgboost4j.java.XGBoost
|
||||
.train(
|
||||
trainMat,
|
||||
params,
|
||||
round,
|
||||
watches,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
numEarlyStoppingRounds);
|
||||
} catch (XGBoostError xgbException) {
|
||||
final String identifier = String.valueOf(this.getRuntimeContext().getIndexOfThisSubtask());
|
||||
logger.warn(
|
||||
String.format("XGBooster worker %s has failed due to", identifier),
|
||||
xgbException
|
||||
);
|
||||
throw xgbException;
|
||||
} finally {
|
||||
Communicator.shutdown();
|
||||
}
|
||||
return booster;
|
||||
}
|
||||
|
||||
private static class VectorToPointMapper
|
||||
implements Function<Tuple2<Vector, Double>, LabeledPoint> {
|
||||
public static VectorToPointMapper INSTANCE = new VectorToPointMapper();
|
||||
@Override
|
||||
public LabeledPoint apply(Tuple2<Vector, Double> tuple) {
|
||||
final SparseVector vector = tuple.f0.toSparse();
|
||||
final double[] values = vector.values;
|
||||
final int size = values.length;
|
||||
final float[] array = new float[size];
|
||||
for (int i = 0; i < size; i++) {
|
||||
array[i] = (float) values[i];
|
||||
}
|
||||
return new LabeledPoint(
|
||||
tuple.f1.floatValue(),
|
||||
vector.size(),
|
||||
vector.indices,
|
||||
array);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Load XGBoost model from path, using Hadoop Filesystem API.
|
||||
*
|
||||
* @param modelPath The path that is accessible by the Hadoop filesystem API.
|
||||
* @return The loaded model
|
||||
*/
|
||||
public static XGBoostModel loadModelFromHadoopFile(final String modelPath) throws Exception {
|
||||
final FileSystem fileSystem = FileSystem.get(new Configuration());
|
||||
final Path f = new Path(modelPath);
|
||||
|
||||
try (FSDataInputStream opened = fileSystem.open(f)) {
|
||||
return new XGBoostModel(ml.dmlc.xgboost4j.java.XGBoost.loadModel(opened));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Train an XGBoost model with Flink.
|
||||
*
|
||||
* @param dtrain The training data.
|
||||
* @param params XGBoost parameters.
|
||||
* @param numBoostRound Number of rounds to train.
|
||||
*/
|
||||
public static XGBoostModel train(DataSet<Tuple2<Vector, Double>> dtrain,
|
||||
Map<String, Object> params,
|
||||
int numBoostRound) throws Exception {
|
||||
final RabitTracker tracker =
|
||||
new RabitTracker(dtrain.getExecutionEnvironment().getParallelism());
|
||||
if (tracker.start(0L)) {
|
||||
return dtrain
|
||||
.mapPartition(new MapFunction(params, numBoostRound, tracker.getWorkerEnvs()))
|
||||
.reduce((x, y) -> x)
|
||||
.collect()
|
||||
.get(0);
|
||||
} else {
|
||||
throw new Error("Tracker cannot be started");
|
||||
}
|
||||
}
|
||||
}
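
The Java port of the Flink bindings ends here. As a rough illustration of how the new API fits together, the following is a minimal, hedged driver-program sketch that trains and saves a model with it; the ExecutionEnvironment wiring, the toy feature vectors, the parameter values and the HDFS output path are illustrative assumptions, not part of this commit.

import java.util.HashMap;
import java.util.Map;

import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;

import ml.dmlc.xgboost4j.java.flink.XGBoost;
import ml.dmlc.xgboost4j.java.flink.XGBoostModel;

public class FlinkTrainingSketch {
  public static void main(String[] args) throws Exception {
    final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

    // Toy (features, label) pairs; a real job would read these from a source
    // with explicit TypeInformation rather than fromElements.
    final DataSet<Tuple2<Vector, Double>> train = env.fromElements(
        Tuple2.of((Vector) Vectors.dense(1.0, 0.0), 1.0),
        Tuple2.of((Vector) Vectors.dense(0.0, 1.0), 0.0));

    final Map<String, Object> params = new HashMap<>();
    params.put("objective", "binary:logistic");
    params.put("eta", 0.3);
    params.put("max_depth", 3);

    // Train for 10 boosting rounds and persist the result through the Hadoop FS API
    // (the output path is hypothetical).
    final XGBoostModel model = XGBoost.train(train, params, 10);
    model.saveModelAsHadoopFile("hdfs:///tmp/xgboost-flink.ubj", "ubj");
  }
}
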
|
||||
@ -0,0 +1,136 @@
|
||||
/*
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.java.flink;
|
||||
import java.io.IOException;
|
||||
import java.io.Serializable;
|
||||
import java.util.Arrays;
|
||||
import java.util.Iterator;
|
||||
import java.util.stream.StreamSupport;
|
||||
|
||||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.apache.flink.api.common.functions.MapPartitionFunction;
|
||||
import org.apache.flink.api.java.DataSet;
|
||||
import org.apache.flink.ml.linalg.SparseVector;
|
||||
import org.apache.flink.ml.linalg.Vector;
|
||||
import org.apache.flink.util.Collector;
|
||||
import org.apache.hadoop.conf.Configuration;
|
||||
import org.apache.hadoop.fs.FileSystem;
|
||||
import org.apache.hadoop.fs.Path;
|
||||
|
||||
import ml.dmlc.xgboost4j.LabeledPoint;
|
||||
import ml.dmlc.xgboost4j.java.Booster;
|
||||
import ml.dmlc.xgboost4j.java.DMatrix;
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError;
|
||||
|
||||
|
||||
public class XGBoostModel implements Serializable {
|
||||
private static final org.slf4j.Logger logger =
|
||||
org.slf4j.LoggerFactory.getLogger(XGBoostModel.class);
|
||||
|
||||
private final Booster booster;
|
||||
private final PredictorFunction predictorFunction;
|
||||
|
||||
|
||||
public XGBoostModel(Booster booster) {
|
||||
this.booster = booster;
|
||||
this.predictorFunction = new PredictorFunction(booster);
|
||||
}
|
||||
|
||||
/**
|
||||
* Save the model as a Hadoop filesystem file.
|
||||
*
|
||||
* @param modelPath The model path as in Hadoop path.
|
||||
*/
|
||||
public void saveModelAsHadoopFile(String modelPath) throws IOException, XGBoostError {
|
||||
booster.saveModel(FileSystem.get(new Configuration()).create(new Path(modelPath)));
|
||||
}
|
||||
|
||||
public byte[] toByteArray(String format) throws XGBoostError {
|
||||
return booster.toByteArray(format);
|
||||
}
|
||||
|
||||
/**
|
||||
* Save the model as a Hadoop filesystem file in the given format.
|
||||
*
|
||||
* @param modelPath The model path as in Hadoop path.
|
||||
* @param format The model format (ubj, json, deprecated)
|
||||
* @throws XGBoostError internal error
|
||||
* @throws IOException save error
|
||||
*/
|
||||
public void saveModelAsHadoopFile(String modelPath, String format)
|
||||
throws IOException, XGBoostError {
|
||||
booster.saveModel(FileSystem.get(new Configuration()).create(new Path(modelPath)), format);
|
||||
}
|
||||
|
||||
/**
|
||||
* predict with the given DMatrix
|
||||
*
|
||||
* @param testSet the local test set represented as DMatrix
|
||||
* @return prediction result
|
||||
*/
|
||||
public float[][] predict(DMatrix testSet) throws XGBoostError {
|
||||
return booster.predict(testSet, true, 0);
|
||||
}
|
||||
|
||||
/**
|
||||
* Predict given vector dataset.
|
||||
*
|
||||
* @param data The dataset to be predicted.
|
||||
* @return The prediction result.
|
||||
*/
|
||||
public DataSet<Float[]> predict(DataSet<Vector> data) {
|
||||
return data.mapPartition(predictorFunction);
|
||||
}
|
||||
|
||||
|
||||
private static class PredictorFunction implements MapPartitionFunction<Vector, Float[]> {
|
||||
|
||||
private final Booster booster;
|
||||
|
||||
public PredictorFunction(Booster booster) {
|
||||
this.booster = booster;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void mapPartition(Iterable<Vector> it, Collector<Float[]> out) throws Exception {
|
||||
final Iterator<LabeledPoint> dataIter =
|
||||
StreamSupport.stream(it.spliterator(), false)
|
||||
.map(Vector::toSparse)
|
||||
.map(PredictorFunction::fromVector)
|
||||
.iterator();
|
||||
|
||||
if (dataIter.hasNext()) {
|
||||
final DMatrix data = new DMatrix(dataIter, null);
|
||||
float[][] predictions = booster.predict(data, true, 2);
|
||||
Arrays.stream(predictions).map(ArrayUtils::toObject).forEach(out::collect);
|
||||
} else {
|
||||
logger.debug("Empty partition");
|
||||
}
|
||||
}
|
||||
|
||||
private static LabeledPoint fromVector(SparseVector vector) {
|
||||
final int[] index = vector.indices;
|
||||
final double[] value = vector.values;
|
||||
int size = value.length;
|
||||
final float[] values = new float[size];
|
||||
for (int i = 0; i < size; i++) {
|
||||
values[i] = (float) value[i];
|
||||
}
|
||||
return new LabeledPoint(0.0f, vector.size(), index, values);
|
||||
}
|
||||
}
|
||||
}
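
A model written by saveModelAsHadoopFile can also be read back and scored locally, outside any Flink job, via loadModelFromHadoopFile and a plain DMatrix. A minimal sketch follows; the HDFS path is hypothetical, and the assumption that a null index array in LabeledPoint denotes a dense row is the author's reading of the four-argument constructor used above.

import java.util.Arrays;
import java.util.Iterator;

import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.flink.XGBoost;
import ml.dmlc.xgboost4j.java.flink.XGBoostModel;

public class FlinkModelScoringSketch {
  public static void main(String[] args) throws Exception {
    // Hypothetical path; it must point at a model previously written by saveModelAsHadoopFile.
    final XGBoostModel model = XGBoost.loadModelFromHadoopFile("hdfs:///tmp/xgboost-flink.ubj");

    // Two rows as LabeledPoint(label, size, indices, values); a null index array is
    // assumed to denote a dense row.
    final Iterator<LabeledPoint> rows = Arrays.asList(
        new LabeledPoint(0.0f, 2, null, new float[]{1.0f, 0.0f}),
        new LabeledPoint(0.0f, 2, null, new float[]{0.0f, 1.0f})).iterator();

    final DMatrix testSet = new DMatrix(rows, null);

    // predict(DMatrix) returns raw margins, since the model calls Booster.predict
    // with outputMargin = true (see above).
    final float[][] margins = model.predict(testSet);
    System.out.println(Arrays.deepToString(margins));
  }
}
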
|
||||
@ -1,99 +0,0 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.flink
|
||||
|
||||
import scala.collection.JavaConverters.asScalaIteratorConverter
|
||||
|
||||
import ml.dmlc.xgboost4j.LabeledPoint
|
||||
import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => XGBoostScala}
|
||||
|
||||
import org.apache.commons.logging.LogFactory
|
||||
import org.apache.flink.api.common.functions.RichMapPartitionFunction
|
||||
import org.apache.flink.api.scala.{DataSet, _}
|
||||
import org.apache.flink.ml.common.LabeledVector
|
||||
import org.apache.flink.util.Collector
|
||||
import org.apache.hadoop.conf.Configuration
|
||||
import org.apache.hadoop.fs.{FileSystem, Path}
|
||||
|
||||
object XGBoost {
|
||||
/**
|
||||
* Helper map function to start the job.
|
||||
*
|
||||
* @param workerEnvs
|
||||
*/
|
||||
private class MapFunction(paramMap: Map[String, Any],
|
||||
round: Int,
|
||||
workerEnvs: java.util.Map[String, String])
|
||||
extends RichMapPartitionFunction[LabeledVector, XGBoostModel] {
|
||||
val logger = LogFactory.getLog(this.getClass)
|
||||
|
||||
def mapPartition(it: java.lang.Iterable[LabeledVector],
|
||||
collector: Collector[XGBoostModel]): Unit = {
|
||||
workerEnvs.put("DMLC_TASK_ID", String.valueOf(this.getRuntimeContext.getIndexOfThisSubtask))
|
||||
logger.info("start with env" + workerEnvs.toString)
|
||||
Communicator.init(workerEnvs)
|
||||
val mapper = (x: LabeledVector) => {
|
||||
val (index, value) = x.vector.toSeq.unzip
|
||||
LabeledPoint(x.label.toFloat, x.vector.size, index.toArray, value.map(_.toFloat).toArray)
|
||||
}
|
||||
val dataIter = for (x <- it.iterator().asScala) yield mapper(x)
|
||||
val trainMat = new DMatrix(dataIter, null)
|
||||
val watches = List("train" -> trainMat).toMap
|
||||
val round = 2
|
||||
val numEarlyStoppingRounds = paramMap.get("numEarlyStoppingRounds")
|
||||
.map(_.toString.toInt).getOrElse(0)
|
||||
val booster = XGBoostScala.train(trainMat, paramMap, round, watches,
|
||||
earlyStoppingRound = numEarlyStoppingRounds)
|
||||
Communicator.shutdown()
|
||||
collector.collect(new XGBoostModel(booster))
|
||||
}
|
||||
}
|
||||
|
||||
val logger = LogFactory.getLog(this.getClass)
|
||||
|
||||
/**
|
||||
* Load XGBoost model from path, using Hadoop Filesystem API.
|
||||
*
|
||||
* @param modelPath The path that is accessible by hadoop filesystem API.
|
||||
* @return The loaded model
|
||||
*/
|
||||
def loadModelFromHadoopFile(modelPath: String) : XGBoostModel = {
|
||||
new XGBoostModel(
|
||||
XGBoostScala.loadModel(FileSystem.get(new Configuration).open(new Path(modelPath))))
|
||||
}
|
||||
|
||||
/**
|
||||
* Train an XGBoost model with Flink.
|
||||
*
|
||||
* @param dtrain The training data.
|
||||
* @param params The parameters to XGBoost.
|
||||
* @param round Number of rounds to train.
|
||||
*/
|
||||
def train(dtrain: DataSet[LabeledVector], params: Map[String, Any], round: Int):
|
||||
XGBoostModel = {
|
||||
val tracker = new RabitTracker(dtrain.getExecutionEnvironment.getParallelism)
|
||||
if (tracker.start(0L)) {
|
||||
dtrain
|
||||
.mapPartition(new MapFunction(params, round, tracker.getWorkerEnvs))
|
||||
.reduce((x, y) => x).collect().head
|
||||
} else {
|
||||
throw new Error("Tracker cannot be started")
|
||||
null
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,67 +0,0 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.flink
|
||||
|
||||
import ml.dmlc.xgboost4j.LabeledPoint
|
||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
|
||||
|
||||
import org.apache.flink.api.scala.{DataSet, _}
|
||||
import org.apache.flink.ml.math.Vector
|
||||
import org.apache.hadoop.conf.Configuration
|
||||
import org.apache.hadoop.fs.{FileSystem, Path}
|
||||
|
||||
class XGBoostModel (booster: Booster) extends Serializable {
|
||||
/**
|
||||
* Save the model as a Hadoop filesystem file.
|
||||
*
|
||||
* @param modelPath The model path as in Hadoop path.
|
||||
*/
|
||||
def saveModelAsHadoopFile(modelPath: String): Unit = {
|
||||
booster.saveModel(FileSystem
|
||||
.get(new Configuration)
|
||||
.create(new Path(modelPath)))
|
||||
}
|
||||
|
||||
/**
|
||||
* predict with the given DMatrix
|
||||
* @param testSet the local test set represented as DMatrix
|
||||
* @return prediction result
|
||||
*/
|
||||
def predict(testSet: DMatrix): Array[Array[Float]] = {
|
||||
booster.predict(testSet, true, 0)
|
||||
}
|
||||
|
||||
/**
|
||||
* Predict given vector dataset.
|
||||
*
|
||||
* @param data The dataset to be predicted.
|
||||
* @return The prediction result.
|
||||
*/
|
||||
def predict(data: DataSet[Vector]) : DataSet[Array[Float]] = {
|
||||
val predictMap: Iterator[Vector] => Traversable[Array[Float]] =
|
||||
(it: Iterator[Vector]) => {
|
||||
val mapper = (x: Vector) => {
|
||||
val (index, value) = x.toSeq.unzip
|
||||
LabeledPoint(0.0f, x.size, index.toArray, value.map(_.toFloat).toArray)
|
||||
}
|
||||
val dataIter = for (x <- it) yield mapper(x)
|
||||
val dmat = new DMatrix(dataIter, null)
|
||||
this.booster.predict(dmat)
|
||||
}
|
||||
data.mapPartition(predictMap)
|
||||
}
|
||||
}
|
||||
@ -38,22 +38,10 @@
|
||||
<version>4.13.2</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.typesafe.akka</groupId>
|
||||
<artifactId>akka-actor_${scala.binary.version}</artifactId>
|
||||
<version>2.6.20</version>
|
||||
<scope>compile</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.typesafe.akka</groupId>
|
||||
<artifactId>akka-testkit_${scala.binary.version}</artifactId>
|
||||
<version>2.6.20</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.scalatest</groupId>
|
||||
<artifactId>scalatest_${scala.binary.version}</artifactId>
|
||||
<version>3.0.5</version>
|
||||
<version>3.2.15</version>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
|
||||
@ -19,10 +19,10 @@ package ml.dmlc.xgboost4j.scala
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
|
||||
import ai.rapids.cudf.Table
|
||||
import org.scalatest.FunSuite
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
import ml.dmlc.xgboost4j.gpu.java.CudfColumnBatch
|
||||
|
||||
class QuantileDMatrixSuite extends FunSuite {
|
||||
class QuantileDMatrixSuite extends AnyFunSuite {
|
||||
|
||||
test("QuantileDMatrix test") {
|
||||
|
||||
|
||||
@ -44,13 +44,6 @@
|
||||
<version>${spark.version}</version>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>ai.rapids</groupId>
|
||||
<artifactId>cudf</artifactId>
|
||||
<version>${cudf.version}</version>
|
||||
<classifier>${cudf.classifier}</classifier>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.nvidia</groupId>
|
||||
<artifactId>rapids-4-spark_${scala.binary.version}</artifactId>
|
||||
|
||||
@ -20,14 +20,15 @@ import java.nio.file.{Files, Path}
|
||||
import java.sql.{Date, Timestamp}
|
||||
import java.util.{Locale, TimeZone}
|
||||
|
||||
import org.scalatest.{BeforeAndAfterAll, FunSuite}
|
||||
import org.scalatest.BeforeAndAfterAll
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
|
||||
import org.apache.spark.{GpuTestUtils, SparkConf}
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.network.util.JavaUtils
|
||||
import org.apache.spark.sql.{Row, SparkSession}
|
||||
|
||||
trait GpuTestSuite extends FunSuite with TmpFolderSuite {
|
||||
trait GpuTestSuite extends AnyFunSuite with TmpFolderSuite {
|
||||
import SparkSessionHolder.withSparkSession
|
||||
|
||||
protected def getResourcePath(resource: String): String = {
|
||||
@ -200,7 +201,7 @@ trait GpuTestSuite extends FunSuite with TmpFolderSuite {
|
||||
|
||||
}
|
||||
|
||||
trait TmpFolderSuite extends BeforeAndAfterAll { self: FunSuite =>
|
||||
trait TmpFolderSuite extends BeforeAndAfterAll { self: AnyFunSuite =>
|
||||
protected var tempDir: Path = _
|
||||
|
||||
override def beforeAll(): Unit = {
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2021-2022 by Contributors
|
||||
Copyright (c) 2021-2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -22,7 +22,6 @@ import java.util.ServiceLoader
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.{AbstractIterator, Iterator, mutable}
|
||||
|
||||
import ml.dmlc.xgboost4j.java.Communicator
|
||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
|
||||
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.PackedParams
|
||||
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
|
||||
@ -35,7 +34,6 @@ import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||
import org.apache.commons.logging.LogFactory
|
||||
|
||||
import org.apache.spark.TaskContext
|
||||
import org.apache.spark.broadcast.Broadcast
|
||||
import org.apache.spark.ml.{Estimator, Model}
|
||||
import org.apache.spark.ml.linalg.Vector
|
||||
import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}
|
||||
@ -263,12 +261,6 @@ object PreXGBoost extends PreXGBoostProvider {
|
||||
private var batchCnt = 0
|
||||
|
||||
private val batchIterImpl = rowIterator.grouped(inferBatchSize).flatMap { batchRow =>
|
||||
if (batchCnt == 0) {
|
||||
val rabitEnv = Array(
|
||||
"DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
|
||||
Communicator.init(rabitEnv.asJava)
|
||||
}
|
||||
|
||||
val features = batchRow.iterator.map(row => row.getAs[Vector](featuresCol))
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
|
||||
@ -295,13 +287,8 @@ object PreXGBoost extends PreXGBoostProvider {
|
||||
|
||||
override def hasNext: Boolean = batchIterImpl.hasNext
|
||||
|
||||
override def next(): Row = {
|
||||
val ret = batchIterImpl.next()
|
||||
if (!batchIterImpl.hasNext) {
|
||||
Communicator.shutdown()
|
||||
}
|
||||
ret
|
||||
}
|
||||
override def next(): Row = batchIterImpl.next()
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014-2022 by Contributors
|
||||
Copyright (c) 2014-2023 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -23,7 +23,6 @@ import scala.util.Random
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{Communicator, IRabitTracker, XGBoostError, RabitTracker => PyRabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
|
||||
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
|
||||
import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
||||
@ -44,21 +43,16 @@ import org.apache.spark.sql.SparkSession
|
||||
* Use a finite, non-zero timeout value to prevent tracker from
|
||||
* hanging indefinitely (in milliseconds)
|
||||
* (supported by "scala" implementation only.)
|
||||
* @param trackerImpl Choice between "python" or "scala". The former utilizes the Java wrapper of
|
||||
* the Python Rabit tracker (in dmlc_core), whereas the latter is implemented
|
||||
* in Scala without Python components, and with full support of timeouts.
|
||||
* The Scala implementation is currently experimental, use at your own risk.
|
||||
*
|
||||
* @param hostIp The Rabit Tracker host IP address, which is only used by the Python implementation.
|
||||
* This is only needed if the host IP cannot be automatically guessed.
|
||||
* @param pythonExec The Python executable path for the Rabit Tracker,
|
||||
* which is only used by the Python implementation.
|
||||
*/
|
||||
case class TrackerConf(workerConnectionTimeout: Long, trackerImpl: String,
|
||||
case class TrackerConf(workerConnectionTimeout: Long,
|
||||
hostIp: String = "", pythonExec: String = "")
|
||||
|
||||
object TrackerConf {
|
||||
def apply(): TrackerConf = TrackerConf(0L, "python")
|
||||
def apply(): TrackerConf = TrackerConf(0L)
|
||||
}
|
||||
|
||||
private[scala] case class XGBoostExecutionEarlyStoppingParams(numEarlyStoppingRounds: Int,
|
||||
@ -349,11 +343,9 @@ object XGBoost extends Serializable {
|
||||
|
||||
/** visible for testing */
|
||||
private[scala] def getTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = {
|
||||
val tracker: IRabitTracker = trackerConf.trackerImpl match {
|
||||
case "scala" => new RabitTracker(nWorkers)
|
||||
case "python" => new PyRabitTracker(nWorkers, trackerConf.hostIp, trackerConf.pythonExec)
|
||||
case _ => new PyRabitTracker(nWorkers)
|
||||
}
|
||||
val tracker: IRabitTracker = new PyRabitTracker(
|
||||
nWorkers, trackerConf.hostIp, trackerConf.pythonExec
|
||||
)
|
||||
tracker
|
||||
}
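
After this change getTracker always returns the Java wrapper around the Python Rabit tracker. For readers who only knew the removed Scala tracker, the following hedged Java sketch shows the wrapper's basic lifecycle as it is used by the Flink and Spark code in this commit; the thread-based worker fan-out and the worker count are illustrative assumptions, not how the production code launches workers.

import java.util.HashMap;
import java.util.Map;

import ml.dmlc.xgboost4j.java.Communicator;
import ml.dmlc.xgboost4j.java.RabitTracker;

public class TrackerLifecycleSketch {
  public static void main(String[] args) throws Exception {
    final int numWorkers = 2;  // illustrative only
    final RabitTracker tracker = new RabitTracker(numWorkers);

    // start(0L) launches the tracker; the connection-timeout argument was only
    // honoured by the removed Scala tracker (see the javadoc above).
    if (!tracker.start(0L)) {
      throw new IllegalStateException("Tracker cannot be started");
    }
    final Map<String, String> workerEnvs = tracker.getWorkerEnvs();

    // Each task copies the tracker environment, adds its own DMLC_TASK_ID, and brackets
    // its work between Communicator.init and Communicator.shutdown.
    for (int rank = 0; rank < numWorkers; rank++) {
      final int taskId = rank;
      new Thread(() -> {
        try {
          final Map<String, String> envs = new HashMap<>(workerEnvs);
          envs.put("DMLC_TASK_ID", String.valueOf(taskId));
          Communicator.init(envs);
          // ... training or allreduce work would run here ...
          Communicator.shutdown();
        } catch (Exception e) {
          throw new RuntimeException(e);
        }
      }).start();
    }

    // Blocks until every worker has reported shutdown; 0 means success.
    System.out.println("tracker exit code: " + tracker.waitFor(0L));
  }
}
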
|
||||
|
||||
|
||||
@ -22,11 +22,10 @@ import scala.util.Random
|
||||
|
||||
import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker => PyRabitTracker}
|
||||
import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus
|
||||
import ml.dmlc.xgboost4j.scala.rabit.{RabitTracker => ScalaRabitTracker}
|
||||
import ml.dmlc.xgboost4j.scala.DMatrix
|
||||
import org.scalatest.FunSuite
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
|
||||
class CommunicatorRobustnessSuite extends FunSuite with PerTest {
|
||||
class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
|
||||
|
||||
private def getXGBoostExecutionParams(paramMap: Map[String, Any]): XGBoostExecutionParams = {
|
||||
val classifier = new XGBoostClassifier(paramMap)
|
||||
@ -40,7 +39,7 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
|
||||
|
||||
val paramMap = Map(
|
||||
"num_workers" -> numWorkers,
|
||||
"tracker_conf" -> TrackerConf(0L, "python", hostIp))
|
||||
"tracker_conf" -> TrackerConf(0L, hostIp))
|
||||
val xgbExecParams = getXGBoostExecutionParams(paramMap)
|
||||
val tracker = XGBoost.getTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
|
||||
tracker match {
|
||||
@ -53,7 +52,7 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
|
||||
|
||||
val paramMap1 = Map(
|
||||
"num_workers" -> numWorkers,
|
||||
"tracker_conf" -> TrackerConf(0L, "python", "", pythonExec))
|
||||
"tracker_conf" -> TrackerConf(0L, "", pythonExec))
|
||||
val xgbExecParams1 = getXGBoostExecutionParams(paramMap1)
|
||||
val tracker1 = XGBoost.getTracker(xgbExecParams1.numWorkers, xgbExecParams1.trackerConf)
|
||||
tracker1 match {
|
||||
@ -66,7 +65,7 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
|
||||
|
||||
val paramMap2 = Map(
|
||||
"num_workers" -> numWorkers,
|
||||
"tracker_conf" -> TrackerConf(0L, "python", hostIp, pythonExec))
|
||||
"tracker_conf" -> TrackerConf(0L, hostIp, pythonExec))
|
||||
val xgbExecParams2 = getXGBoostExecutionParams(paramMap2)
|
||||
val tracker2 = XGBoost.getTracker(xgbExecParams2.numWorkers, xgbExecParams2.trackerConf)
|
||||
tracker2 match {
|
||||
@ -78,58 +77,6 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
|
||||
}
|
||||
}
|
||||
|
||||
test("training with Scala-implemented Rabit tracker") {
|
||||
val eval = new EvalError()
|
||||
val training = buildDataFrame(Classification.train)
|
||||
val testDM = new DMatrix(Classification.test.iterator)
|
||||
val paramMap = Map("eta" -> "1", "max_depth" -> "6",
|
||||
"objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
|
||||
"tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala"))
|
||||
val model = new XGBoostClassifier(paramMap).fit(training)
|
||||
assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
|
||||
}
|
||||
|
||||
test("test Communicator allreduce to validate Scala-implemented Rabit tracker") {
|
||||
val vectorLength = 100
|
||||
val rdd = sc.parallelize(
|
||||
(1 to numWorkers * vectorLength).toArray.map { _ => Random.nextFloat() }, numWorkers).cache()
|
||||
|
||||
val tracker = new ScalaRabitTracker(numWorkers)
|
||||
tracker.start(0)
|
||||
val trackerEnvs = tracker.getWorkerEnvs
|
||||
val collectedAllReduceResults = new LinkedBlockingDeque[Array[Float]]()
|
||||
|
||||
val rawData = rdd.mapPartitions { iter =>
|
||||
Iterator(iter.toArray)
|
||||
}.collect()
|
||||
|
||||
val maxVec = (0 until vectorLength).toArray.map { j =>
|
||||
(0 until numWorkers).toArray.map { i => rawData(i)(j) }.max
|
||||
}
|
||||
|
||||
val allReduceResults = rdd.mapPartitions { iter =>
|
||||
Communicator.init(trackerEnvs)
|
||||
val arr = iter.toArray
|
||||
val results = Communicator.allReduce(arr, Communicator.OpType.MAX)
|
||||
Communicator.shutdown()
|
||||
Iterator(results)
|
||||
}.cache()
|
||||
|
||||
val sparkThread = new Thread() {
|
||||
override def run(): Unit = {
|
||||
allReduceResults.foreachPartition(() => _)
|
||||
val byPartitionResults = allReduceResults.collect()
|
||||
assert(byPartitionResults(0).length == vectorLength)
|
||||
collectedAllReduceResults.put(byPartitionResults(0))
|
||||
}
|
||||
}
|
||||
sparkThread.start()
|
||||
assert(tracker.waitFor(0L) == 0)
|
||||
sparkThread.join()
|
||||
|
||||
assert(collectedAllReduceResults.poll().sameElements(maxVec))
|
||||
}
|
||||
|
||||
test("test Java RabitTracker wrapper's exception handling: it should not hang forever.") {
|
||||
/*
|
||||
Deliberately create new instances of SparkContext in each unit test to avoid reusing the
|
||||
@ -193,68 +140,6 @@ class CommunicatorRobustnessSuite extends FunSuite with PerTest {
|
||||
assert(tracker.waitFor(0) != 0)
|
||||
}
|
||||
|
||||
test("test Scala RabitTracker's exception handling: it should not hang forever.") {
|
||||
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
|
||||
|
||||
val tracker = new ScalaRabitTracker(numWorkers)
|
||||
tracker.start(0)
|
||||
val trackerEnvs = tracker.getWorkerEnvs
|
||||
|
||||
val workerCount: Int = numWorkers
|
||||
val dummyTasks = rdd.mapPartitions { iter =>
|
||||
Communicator.init(trackerEnvs)
|
||||
val index = iter.next()
|
||||
Thread.sleep(100 + index * 10)
|
||||
if (index == workerCount) {
|
||||
// kill the worker by throwing an exception
|
||||
throw new RuntimeException("Worker exception.")
|
||||
}
|
||||
Communicator.shutdown()
|
||||
Iterator(index)
|
||||
}.cache()
|
||||
|
||||
val sparkThread = new Thread() {
|
||||
override def run(): Unit = {
|
||||
// forces a Spark job.
|
||||
dummyTasks.foreachPartition(() => _)
|
||||
}
|
||||
}
|
||||
sparkThread.setUncaughtExceptionHandler(tracker)
|
||||
sparkThread.start()
|
||||
assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
|
||||
}
|
||||
|
||||
test("test Scala RabitTracker's workerConnectionTimeout") {
|
||||
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
|
||||
|
||||
val tracker = new ScalaRabitTracker(numWorkers)
|
||||
tracker.start(500)
|
||||
val trackerEnvs = tracker.getWorkerEnvs
|
||||
|
||||
val dummyTasks = rdd.mapPartitions { iter =>
|
||||
val index = iter.next()
|
||||
// simulate that the first worker cannot connect to tracker due to network issues.
|
||||
if (index != 1) {
|
||||
Communicator.init(trackerEnvs)
|
||||
Thread.sleep(1000)
|
||||
Communicator.shutdown()
|
||||
}
|
||||
|
||||
Iterator(index)
|
||||
}.cache()
|
||||
|
||||
val sparkThread = new Thread() {
|
||||
override def run(): Unit = {
|
||||
// forces a Spark job.
|
||||
dummyTasks.foreachPartition(() => _)
|
||||
}
|
||||
}
|
||||
sparkThread.setUncaughtExceptionHandler(tracker)
|
||||
sparkThread.start()
|
||||
// should fail due to connection timeout
|
||||
assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
|
||||
}
|
||||
|
||||
test("should allow the dataframe containing communicator calls to be partially evaluated for" +
|
||||
" multiple times (ISSUE-4406)") {
|
||||
val paramMap = Map(
|
||||
|
||||
@ -17,13 +17,13 @@
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import org.apache.spark.ml.linalg.Vectors
|
||||
import org.scalatest.FunSuite
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils
|
||||
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.PackedParams
|
||||
|
||||
import org.apache.spark.sql.functions._
|
||||
|
||||
class DeterministicPartitioningSuite extends FunSuite with TmpFolderPerSuite with PerTest {
|
||||
class DeterministicPartitioningSuite extends AnyFunSuite with TmpFolderPerSuite with PerTest {
|
||||
|
||||
test("perform deterministic partitioning when checkpointInternal and" +
|
||||
" checkpointPath is set (Classifier)") {
|
||||
|
||||
@ -19,10 +19,10 @@ package ml.dmlc.xgboost4j.scala.spark
|
||||
import java.io.File
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, ExternalCheckpointManager, XGBoost => SXGBoost}
|
||||
import org.scalatest.FunSuite
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
import org.apache.hadoop.fs.{FileSystem, Path}
|
||||
|
||||
class ExternalCheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTest {
|
||||
class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite with PerTest {
|
||||
|
||||
private def produceParamMap(checkpointPath: String, checkpointInterval: Int):
|
||||
Map[String, Any] = {
|
||||
|
||||
@ -18,12 +18,12 @@ package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import org.apache.spark.Partitioner
|
||||
import org.apache.spark.ml.feature.VectorAssembler
|
||||
import org.scalatest.FunSuite
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
import org.apache.spark.sql.functions._
|
||||
|
||||
import scala.util.Random
|
||||
|
||||
class FeatureSizeValidatingSuite extends FunSuite with PerTest {
|
||||
class FeatureSizeValidatingSuite extends AnyFunSuite with PerTest {
|
||||
|
||||
test("transform throwing exception if feature size of dataset is greater than model's") {
|
||||
val modelPath = getClass.getResource("/model/0.82/model").getPath
|
||||
|
||||
@ -19,12 +19,12 @@ package ml.dmlc.xgboost4j.scala.spark
|
||||
import org.apache.spark.ml.feature.VectorAssembler
|
||||
import org.apache.spark.ml.linalg.Vectors
|
||||
import org.apache.spark.sql.DataFrame
|
||||
import org.scalatest.FunSuite
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
import scala.util.Random
|
||||
|
||||
import org.apache.spark.SparkException
|
||||
|
||||
class MissingValueHandlingSuite extends FunSuite with PerTest {
|
||||
class MissingValueHandlingSuite extends AnyFunSuite with PerTest {
|
||||
test("dense vectors containing missing value") {
|
||||
def buildDenseDataFrame(): DataFrame = {
|
||||
val numRows = 100
|
||||
|
||||
@ -16,12 +16,13 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import org.scalatest.{BeforeAndAfterAll, FunSuite}
|
||||
import org.scalatest.BeforeAndAfterAll
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
|
||||
import org.apache.spark.SparkException
|
||||
import org.apache.spark.ml.param.ParamMap
|
||||
|
||||
class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll {
|
||||
class ParameterSuite extends AnyFunSuite with PerTest with BeforeAndAfterAll {
|
||||
|
||||
test("XGBoost and Spark parameters synchronize correctly") {
|
||||
val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic",
|
||||
|
||||
@ -22,13 +22,14 @@ import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.sql._
|
||||
import org.scalatest.{BeforeAndAfterEach, FunSuite}
|
||||
import org.scalatest.BeforeAndAfterEach
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
import scala.math.min
|
||||
import scala.util.Random
|
||||
|
||||
import org.apache.commons.io.IOUtils
|
||||
|
||||
trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
|
||||
trait PerTest extends BeforeAndAfterEach { self: AnyFunSuite =>
|
||||
|
||||
protected val numWorkers: Int = min(Runtime.getRuntime.availableProcessors(), 4)
|
||||
|
||||
|
||||
@ -25,9 +25,9 @@ import scala.util.Random
|
||||
import org.apache.spark.ml.feature._
|
||||
import org.apache.spark.ml.{Pipeline, PipelineModel}
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.scalatest.FunSuite
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
|
||||
class PersistenceSuite extends FunSuite with TmpFolderPerSuite with PerTest {
|
||||
class PersistenceSuite extends AnyFunSuite with TmpFolderPerSuite with PerTest {
|
||||
|
||||
test("test persistence of XGBoostClassifier and XGBoostClassificationModel") {
|
||||
val eval = new EvalError()
|
||||
|
||||
@ -19,9 +19,10 @@ package ml.dmlc.xgboost4j.scala.spark
|
||||
import java.nio.file.{Files, Path}
|
||||
|
||||
import org.apache.spark.network.util.JavaUtils
|
||||
import org.scalatest.{BeforeAndAfterAll, FunSuite}
|
||||
import org.scalatest.BeforeAndAfterAll
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
|
||||
trait TmpFolderPerSuite extends BeforeAndAfterAll { self: FunSuite =>
|
||||
trait TmpFolderPerSuite extends BeforeAndAfterAll { self: AnyFunSuite =>
|
||||
protected var tempDir: Path = _
|
||||
|
||||
override def beforeAll(): Unit = {
|
||||
|
||||
@ -22,13 +22,13 @@ import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
||||
|
||||
import org.apache.spark.ml.linalg._
|
||||
import org.apache.spark.sql._
|
||||
import org.scalatest.FunSuite
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
import org.apache.commons.io.IOUtils
|
||||
|
||||
import org.apache.spark.Partitioner
|
||||
import org.apache.spark.ml.feature.VectorAssembler
|
||||
|
||||
class XGBoostClassifierSuite extends FunSuite with PerTest with TmpFolderPerSuite {
|
||||
class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite {
|
||||
|
||||
protected val treeMethod: String = "auto"
|
||||
|
||||
|
||||
@ -21,11 +21,11 @@ import ml.dmlc.xgboost4j.scala.Booster
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import org.apache.spark.sql._
|
||||
import org.scalatest.FunSuite
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
|
||||
import org.apache.spark.SparkException
|
||||
|
||||
class XGBoostCommunicatorRegressionSuite extends FunSuite with PerTest {
|
||||
class XGBoostCommunicatorRegressionSuite extends AnyFunSuite with PerTest {
|
||||
val predictionErrorMin = 0.00001f
|
||||
val maxFailure = 2;
|
||||
|
||||
|
||||
@ -19,9 +19,9 @@ package ml.dmlc.xgboost4j.scala.spark
|
||||
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
|
||||
|
||||
import org.apache.spark.sql._
|
||||
import org.scalatest.FunSuite
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
|
||||
class XGBoostConfigureSuite extends FunSuite with PerTest {
|
||||
class XGBoostConfigureSuite extends AnyFunSuite with PerTest {
|
||||
|
||||
override def sparkSessionBuilder: SparkSession.Builder = super.sparkSessionBuilder
|
||||
.config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
|
||||
|
||||
@ -22,12 +22,12 @@ import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
|
||||
import ml.dmlc.xgboost4j.scala.DMatrix
|
||||
|
||||
import org.apache.spark.{SparkException, TaskContext}
|
||||
import org.scalatest.FunSuite
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
|
||||
import org.apache.spark.ml.feature.VectorAssembler
|
||||
import org.apache.spark.sql.functions.lit
|
||||
|
||||
class XGBoostGeneralSuite extends FunSuite with TmpFolderPerSuite with PerTest {
|
||||
class XGBoostGeneralSuite extends AnyFunSuite with TmpFolderPerSuite with PerTest {
|
||||
|
||||
test("distributed training with the specified worker number") {
|
||||
val trainingRDD = sc.parallelize(Classification.train)
|
||||
|
||||
@ -23,11 +23,11 @@ import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
|
||||
import org.apache.spark.ml.linalg.{Vector, Vectors}
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.{DataFrame, Row}
|
||||
import org.scalatest.FunSuite
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
|
||||
import org.apache.spark.ml.feature.VectorAssembler
|
||||
|
||||
class XGBoostRegressorSuite extends FunSuite with PerTest with TmpFolderPerSuite {
|
||||
class XGBoostRegressorSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite {
|
||||
protected val treeMethod: String = "auto"
|
||||
|
||||
test("XGBoost-Spark XGBoostRegressor output should match XGBoost4j") {
|
||||
|
||||
@ -69,7 +69,7 @@ pom_template = """
|
||||
<dependency>
|
||||
<groupId>org.scalactic</groupId>
|
||||
<artifactId>scalactic_${{scala.binary.version}}</artifactId>
|
||||
<version>3.0.8</version>
|
||||
<version>3.2.15</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
|
||||
@ -31,22 +31,10 @@
|
||||
<version>4.13.2</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.typesafe.akka</groupId>
|
||||
<artifactId>akka-actor_${scala.binary.version}</artifactId>
|
||||
<version>2.6.20</version>
|
||||
<scope>compile</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.typesafe.akka</groupId>
|
||||
<artifactId>akka-testkit_${scala.binary.version}</artifactId>
|
||||
<version>2.6.20</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.scalatest</groupId>
|
||||
<artifactId>scalatest_${scala.binary.version}</artifactId>
|
||||
<version>3.0.5</version>
|
||||
<version>3.2.15</version>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
@ -1,195 +0,0 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.rabit
|
||||
|
||||
import java.net.{InetAddress, InetSocketAddress}
|
||||
|
||||
import akka.actor.ActorSystem
|
||||
import akka.pattern.ask
|
||||
import ml.dmlc.xgboost4j.java.{IRabitTracker, TrackerProperties}
|
||||
import ml.dmlc.xgboost4j.scala.rabit.handler.RabitTrackerHandler
|
||||
|
||||
import scala.concurrent.duration._
|
||||
import scala.concurrent.{Await, Future}
|
||||
import scala.util.{Failure, Success, Try}
|
||||
|
||||
/**
|
||||
* Scala implementation of the Rabit tracker interface without Python dependency.
|
||||
* The Scala Rabit tracker fully implements the timeout logic, effectively preventing the tracker
|
||||
* (and thus any distributed tasks) from hanging indefinitely due to network issues or worker node
|
||||
* failures.
|
||||
*
|
||||
* Note that this implementation is currently experimental, and should be used at your own risk.
|
||||
*
|
||||
* Example usage:
|
||||
* {{{
|
||||
* import scala.concurrent.duration._
|
||||
*
|
||||
* val tracker = new RabitTracker(32)
|
||||
* // allow up to 10 minutes for all workers to connect to the tracker.
|
||||
* tracker.start(10 minutes)
|
||||
*
|
||||
* /* ...
|
||||
* launching workers in parallel
|
||||
* ...
|
||||
* */
|
||||
*
|
||||
* // wait for worker execution up to 6 hours.
|
||||
* // providing a finite timeout prevents a long-running task from hanging forever in
|
||||
* // catastrophic events, like the loss of an executor during model training.
|
||||
* tracker.waitFor(6 hours)
|
||||
* }}}
|
||||
*
|
||||
* @param numWorkers Number of distributed workers from which the tracker expects connections.
|
||||
* @param port The minimum port number that the tracker binds to.
|
||||
* If port is omitted, or given as None, a random ephemeral port is chosen at runtime.
|
||||
* @param maxPortTrials The maximum number of trials of socket binding, by sequentially
|
||||
* increasing the port number.
|
||||
*/
|
||||
private[scala] class RabitTracker(numWorkers: Int, port: Option[Int] = None,
|
||||
maxPortTrials: Int = 1000)
|
||||
extends IRabitTracker {
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
require(numWorkers >=1, "numWorkers must be greater than or equal to one (1).")
|
||||
|
||||
val system = ActorSystem.create("RabitTracker")
|
||||
val handler = system.actorOf(RabitTrackerHandler.props(numWorkers), "Handler")
|
||||
implicit val askTimeout: akka.util.Timeout = akka.util.Timeout(30 seconds)
|
||||
private[this] val tcpBindingTimeout: Duration = 1 minute
|
||||
|
||||
var workerEnvs: Map[String, String] = Map.empty
|
||||
|
||||
override def uncaughtException(t: Thread, e: Throwable): Unit = {
|
||||
handler ? RabitTrackerHandler.InterruptTracker(e)
|
||||
}
|
||||
|
||||
/**
|
||||
* Start the Rabit tracker.
|
||||
*
|
||||
* @param timeout The timeout for awaiting connections from worker nodes.
|
||||
* Note that when used in Spark applications, because all Spark transformations are
|
||||
* lazily executed, the I/O time for loading RDDs/DataFrames from external sources
|
||||
* (local dist, HDFS, S3 etc.) must be taken into account for the timeout value.
|
||||
* If the timeout value is too small, the Rabit tracker will likely timeout before workers
|
||||
* establishing connections to the tracker, due to the overhead of loading data.
|
||||
* Using a finite timeout is encouraged, as it prevents the tracker (thus the Spark driver
|
||||
* running it) from hanging indefinitely due to worker connection issues (e.g. firewall.)
|
||||
* @return Boolean flag indicating if the Rabit tracker starts successfully.
|
||||
*/
|
||||
private def start(timeout: Duration): Boolean = {
|
||||
val hostAddress = Option(TrackerProperties.getInstance().getHostIp)
|
||||
.map(InetAddress.getByName).getOrElse(InetAddress.getLocalHost)
|
||||
|
||||
handler ? RabitTrackerHandler.StartTracker(
|
||||
new InetSocketAddress(hostAddress, port.getOrElse(0)), maxPortTrials, timeout)
|
||||
|
||||
// block by waiting for the actor to bind to a port
|
||||
Try(Await.result(handler ? RabitTrackerHandler.RequestBoundFuture, askTimeout.duration)
|
||||
.asInstanceOf[Future[Map[String, String]]]) match {
|
||||
case Success(futurePortBound) =>
|
||||
// The success of the Future is contingent on binding to an InetSocketAddress.
|
||||
val isBound = Try(Await.ready(futurePortBound, tcpBindingTimeout)).isSuccess
|
||||
if (isBound) {
|
||||
workerEnvs = Await.result(futurePortBound, 0 nano)
|
||||
}
|
||||
isBound
|
||||
case Failure(ex: Throwable) =>
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Start the Rabit tracker.
|
||||
*
|
||||
* @param connectionTimeoutMillis Timeout, in milliseconds, for the tracker to wait for worker
|
||||
* connections. If a non-positive value is provided, the tracker
|
||||
* waits for incoming worker connections indefinitely.
|
||||
* @return Boolean flag indicating if the Rabit tracker starts successfully.
|
||||
*/
|
||||
def start(connectionTimeoutMillis: Long): Boolean = {
|
||||
if (connectionTimeoutMillis <= 0) {
|
||||
start(Duration.Inf)
|
||||
} else {
|
||||
start(Duration.fromNanos(connectionTimeoutMillis * 1e6))
|
||||
}
|
||||
}
|
||||
|
||||
def stop(): Unit = {
|
||||
system.terminate()
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a Map of necessary environment variables to initiate Rabit workers.
|
||||
*
|
||||
* @return HashMap containing tracker information.
|
||||
*/
|
||||
def getWorkerEnvs: java.util.Map[String, String] = {
|
||||
new java.util.HashMap((workerEnvs ++ Map(
|
||||
"DMLC_NUM_WORKER" -> numWorkers.toString,
|
||||
"DMLC_NUM_SERVER" -> "0"
|
||||
)).asJava)
|
||||
}
|
||||
|
||||
/**
|
||||
* Await workers to complete assigned tasks for at most 'atMostMillis' milliseconds.
|
||||
* This method blocks until timeout or task completion.
|
||||
*
|
||||
* @param atMost the maximum execution time for the workers. By default,
|
||||
* the tracker waits for the workers indefinitely.
|
||||
* @return 0 if the tasks complete successfully, and non-zero otherwise.
|
||||
*/
|
||||
private def waitFor(atMost: Duration): Int = {
|
||||
// request the completion Future from the tracker actor
|
||||
Try(Await.result(handler ? RabitTrackerHandler.RequestCompletionFuture, askTimeout.duration)
|
||||
.asInstanceOf[Future[Int]]) match {
|
||||
case Success(futureCompleted) =>
|
||||
// wait for all workers to complete synchronously.
|
||||
val statusCode = Try(Await.result(futureCompleted, atMost)) match {
|
||||
case Success(n) if n == numWorkers =>
|
||||
IRabitTracker.TrackerStatus.SUCCESS.getStatusCode
|
||||
case Success(n) if n < numWorkers =>
|
||||
IRabitTracker.TrackerStatus.TIMEOUT.getStatusCode
|
||||
case Failure(e) =>
|
||||
IRabitTracker.TrackerStatus.FAILURE.getStatusCode
|
||||
}
|
||||
system.terminate()
|
||||
statusCode
|
||||
case Failure(ex: Throwable) =>
|
||||
system.terminate()
|
||||
IRabitTracker.TrackerStatus.FAILURE.getStatusCode
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Await workers to complete assigned tasks for at most 'atMostMillis' milliseconds.
|
||||
* This method blocks until timeout or task completion.
|
||||
*
|
||||
* @param atMostMillis Number of milliseconds for the tracker to wait for workers. If a
|
||||
* non-positive number is given, the tracker waits indefinitely.
|
||||
* @return 0 if the tasks complete successfully, and non-zero otherwise
|
||||
*/
|
||||
def waitFor(atMostMillis: Long): Int = {
|
||||
if (atMostMillis <= 0) {
|
||||
waitFor(Duration.Inf)
|
||||
} else {
|
||||
waitFor(Duration.fromNanos(atMostMillis * 1e6))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,361 +0,0 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.rabit.handler
|
||||
|
||||
import java.net.InetSocketAddress
|
||||
import java.util.UUID
|
||||
|
||||
import scala.concurrent.duration._
|
||||
import scala.collection.mutable
|
||||
import scala.concurrent.{Promise, TimeoutException}
|
||||
import akka.io.{IO, Tcp}
|
||||
import akka.actor._
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError
|
||||
import ml.dmlc.xgboost4j.scala.rabit.util.{AssignedRank, LinkMap}
|
||||
|
||||
import scala.util.{Failure, Random, Success, Try}
|
||||
|
||||
/** The Akka actor for handling and coordinating Rabit worker connections.
|
||||
* This is the main actor for handling socket connections, interacting with the synchronous
|
||||
* tracker interface, and resolving tree/ring/parent dependencies between workers.
|
||||
*
|
||||
* @param numWorkers Number of workers to track.
|
||||
*/
|
||||
private[scala] class RabitTrackerHandler(numWorkers: Int)
|
||||
extends Actor with ActorLogging {
|
||||
|
||||
import context.system
|
||||
import RabitWorkerHandler._
|
||||
import RabitTrackerHandler._
|
||||
|
||||
private[this] val promisedWorkerEnvs = Promise[Map[String, String]]()
|
||||
private[this] val promisedShutdownWorkers = Promise[Int]()
|
||||
private[this] val tcpManager = IO(Tcp)
|
||||
|
||||
// resolves worker connection dependency.
|
||||
val resolver = context.actorOf(Props(classOf[WorkerDependencyResolver], self), "Resolver")
|
||||
|
||||
// workers that have sent "shutdown" signal
|
||||
private[this] val shutdownWorkers = mutable.Set.empty[Int]
|
||||
private[this] val jobToRankMap = mutable.HashMap.empty[String, Int]
|
||||
private[this] val actorRefToHost = mutable.HashMap.empty[ActorRef, String]
|
||||
private[this] val ranksToAssign = mutable.ListBuffer(0 until numWorkers: _*)
|
||||
private[this] var maxPortTrials = 0
|
||||
private[this] var workerConnectionTimeout: Duration = Duration.Inf
|
||||
private[this] var portTrials = 0
|
||||
private[this] val startedWorkers = mutable.Set.empty[Int]
|
||||
|
||||
val linkMap = new LinkMap(numWorkers)
|
||||
|
||||
def decideRank(rank: Int, jobId: String = "NULL"): Option[Int] = {
|
||||
rank match {
|
||||
case r if r >= 0 => Some(r)
|
||||
case _ =>
|
||||
jobId match {
|
||||
case "NULL" => None
|
||||
case jid => jobToRankMap.get(jid)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handler for all Akka Tcp connection/binding events. Read/write over the socket is handled
|
||||
* by the RabitWorkerHandler.
|
||||
*
|
||||
* @param event Generic Tcp.Event
|
||||
*/
|
||||
private def handleTcpEvents(event: Tcp.Event): Unit = event match {
|
||||
case Tcp.Bound(local) =>
|
||||
// expect all workers to connect within timeout
|
||||
log.info(s"Tracker listening @ ${local.getAddress.getHostAddress}:${local.getPort}")
|
||||
log.info(s"Worker connection timeout is $workerConnectionTimeout.")
|
||||
|
||||
context.setReceiveTimeout(workerConnectionTimeout)
|
||||
promisedWorkerEnvs.success(Map(
|
||||
"DMLC_TRACKER_URI" -> local.getAddress.getHostAddress,
|
||||
"DMLC_TRACKER_PORT" -> local.getPort.toString,
|
||||
// not required because the world size will be communicated to the
|
||||
// worker node after the rank is assigned.
|
||||
"rabit_world_size" -> numWorkers.toString
|
||||
))
|
||||
|
||||
case Tcp.CommandFailed(cmd: Tcp.Bind) =>
|
||||
if (portTrials < maxPortTrials) {
|
||||
portTrials += 1
|
||||
tcpManager ! Tcp.Bind(self,
|
||||
new InetSocketAddress(cmd.localAddress.getAddress, cmd.localAddress.getPort + 1),
|
||||
backlog = 256)
|
||||
}
|
||||
|
||||
case Tcp.Connected(remote, local) =>
|
||||
log.debug(s"Incoming connection from worker @ ${remote.getAddress.getHostAddress}")
|
||||
// revoke timeout if all workers have connected.
|
||||
val workerHandler = context.actorOf(RabitWorkerHandler.props(
|
||||
remote.getAddress.getHostAddress, numWorkers, self, sender()
|
||||
), s"ConnectionHandler-${UUID.randomUUID().toString}")
|
||||
val connection = sender()
|
||||
connection ! Tcp.Register(workerHandler, keepOpenOnPeerClosed = true)
|
||||
|
||||
actorRefToHost.put(workerHandler, remote.getAddress.getHostName)
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles external tracker control messages sent by RabitTracker (usually in ask patterns)
|
||||
* to interact with the tracker interface.
|
||||
*
|
||||
* @param trackerMsg control messages sent by RabitTracker class.
|
||||
*/
|
||||
private def handleTrackerControlMessage(trackerMsg: TrackerControlMessage): Unit =
|
||||
trackerMsg match {
|
||||
|
||||
case msg: StartTracker =>
|
||||
maxPortTrials = msg.maxPortTrials
|
||||
workerConnectionTimeout = msg.connectionTimeout
|
||||
|
||||
// if the port number is missing, try binding to a random ephemeral port.
|
||||
if (msg.addr.getPort == 0) {
|
||||
tcpManager ! Tcp.Bind(self,
|
||||
new InetSocketAddress(msg.addr.getAddress, new Random().nextInt(61000 - 32768) + 32768),
|
||||
backlog = 256)
|
||||
} else {
|
||||
tcpManager ! Tcp.Bind(self, msg.addr, backlog = 256)
|
||||
}
|
||||
sender() ! true
|
||||
|
||||
case RequestBoundFuture =>
|
||||
sender() ! promisedWorkerEnvs.future
|
||||
|
||||
case RequestCompletionFuture =>
|
||||
sender() ! promisedShutdownWorkers.future
|
||||
|
||||
case InterruptTracker(e) =>
|
||||
log.error(e, "Uncaught exception thrown by worker.")
|
||||
// make sure that waitFor() does not hang indefinitely.
|
||||
promisedShutdownWorkers.failure(e)
|
||||
context.stop(self)
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles messages sent by child actors representing connecting Rabit workers, by brokering
|
||||
* messages to the dependency resolver, and processing worker commands.
|
||||
*
|
||||
* @param workerMsg Message sent by RabitWorkerHandler actors.
|
||||
*/
|
||||
private def handleRabitWorkerMessage(workerMsg: RabitWorkerRequest): Unit = workerMsg match {
|
||||
case req @ RequestAwaitConnWorkers(_, _) =>
|
||||
// since the requester may request to connect to other workers
|
||||
// that have not fully set up, delegate this request to the
|
||||
// dependency resolver which handles the dependencies properly.
|
||||
resolver forward req
|
||||
|
||||
// ---- Rabit worker commands: start/recover/shutdown/print ----
|
||||
case WorkerTrackerPrint(_, _, _, msg) =>
|
||||
log.info(msg.trim)
|
||||
|
||||
case WorkerShutdown(rank, _, _) =>
|
||||
assert(rank >= 0, "Invalid rank.")
|
||||
assert(!shutdownWorkers.contains(rank))
|
||||
shutdownWorkers.add(rank)
|
||||
|
||||
log.info(s"Received shutdown signal from $rank")
|
||||
|
||||
if (shutdownWorkers.size == numWorkers) {
|
||||
promisedShutdownWorkers.success(shutdownWorkers.size)
|
||||
}
|
||||
|
||||
case WorkerRecover(prevRank, worldSize, jobId) =>
|
||||
assert(prevRank >= 0)
|
||||
sender() ! linkMap.assignRank(prevRank)
|
||||
|
||||
case WorkerStart(rank, worldSize, jobId) =>
|
||||
assert(worldSize == numWorkers || worldSize == -1,
|
||||
s"Purported worldSize ($worldSize) does not match worker count ($numWorkers)."
|
||||
)
|
||||
|
||||
Try(decideRank(rank, jobId).getOrElse(ranksToAssign.remove(0))) match {
|
||||
case Success(wkRank) =>
|
||||
if (jobId != "NULL") {
|
||||
jobToRankMap.put(jobId, wkRank)
|
||||
}
|
||||
|
||||
val assignedRank = linkMap.assignRank(wkRank)
|
||||
sender() ! assignedRank
|
||||
resolver ! assignedRank
|
||||
|
||||
log.info("Received start signal from " +
|
||||
s"${actorRefToHost.getOrElse(sender(), "")} [rank: $wkRank]")
|
||||
|
||||
case Failure(ex: IndexOutOfBoundsException) =>
|
||||
// More than worldSize workers have connected, likely due to executor loss.
|
||||
// Since Rabit currently does not support crash recovery (because the Allreduce results
|
||||
// are not cached by the tracker, and because existing workers cannot reestablish
|
||||
// connections to newly spawned executor/worker), the most reasonable action here is to
|
||||
// interrupt the tracker immediately with a failure state.
|
||||
log.error("Received invalid start signal from " +
|
||||
s"${actorRefToHost.getOrElse(sender(), "")}: all $worldSize workers have started."
|
||||
)
|
||||
promisedShutdownWorkers.failure(new XGBoostError("Invalid start signal" +
|
||||
" received from worker, likely due to executor loss."))
|
||||
|
||||
case Failure(ex) =>
|
||||
log.error(ex, "Unexpected error")
|
||||
promisedShutdownWorkers.failure(ex)
|
||||
}
|
||||
|
||||
|
||||
// ---- Dependency resolving related messages ----
|
||||
case msg @ WorkerStarted(host, rank, awaitingAcceptance) =>
|
||||
log.info(s"Worker $host (rank: $rank) has started.")
|
||||
resolver forward msg
|
||||
|
||||
startedWorkers.add(rank)
|
||||
if (startedWorkers.size == numWorkers) {
|
||||
log.info("All workers have started.")
|
||||
}
|
||||
|
||||
case req @ DropFromWaitingList(_) =>
|
||||
// all peer workers in dependency link map have connected;
|
||||
// forward message to resolver to update dependencies.
|
||||
resolver forward req
|
||||
|
||||
case _ =>
|
||||
}
|
||||
|
||||
def receive: Actor.Receive = {
|
||||
case tcpEvent: Tcp.Event => handleTcpEvents(tcpEvent)
|
||||
case trackerMsg: TrackerControlMessage => handleTrackerControlMessage(trackerMsg)
|
||||
case workerMsg: RabitWorkerRequest => handleRabitWorkerMessage(workerMsg)
|
||||
|
||||
case akka.actor.ReceiveTimeout =>
|
||||
if (startedWorkers.size < numWorkers) {
|
||||
promisedShutdownWorkers.failure(
|
||||
new TimeoutException("Timed out waiting for workers to connect: " +
|
||||
s"${numWorkers - startedWorkers.size} of $numWorkers did not start/connect.")
|
||||
)
|
||||
context.stop(self)
|
||||
}
|
||||
|
||||
context.setReceiveTimeout(Duration.Undefined)
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Resolves the dependencies between worker nodes as they connect to the tracker.
 * The enforced dependency is that a worker of rank K depends on its neighbors (from the treeMap
 * and ringMap) whose ranks are smaller than K. Since ranks are assigned in the order in which
 * workers connect, this dependency constraint assumes that a worker node that connects first is
 * likely to finish its setup first.
 */
private[rabit] class WorkerDependencyResolver(handler: ActorRef) extends Actor with ActorLogging {
  import RabitWorkerHandler._

  context.watch(handler)

  case class Fulfillment(toConnectSet: Set[Int], promise: Promise[AwaitingConnections])

  // worker nodes that have connected, but have not yet sent the WorkerStarted message.
  private val dependencyMap = mutable.Map.empty[Int, Set[Int]]
  private val startedWorkers = mutable.Set.empty[Int]
  // worker nodes that have started and are awaiting connections.
  private val awaitConnWorkers = mutable.Map.empty[Int, ActorRef]
  private val pendingFulfillment = mutable.Map.empty[Int, Fulfillment]

  def awaitingWorkers(linkSet: Set[Int]): AwaitingConnections = {
    val connSet = awaitConnWorkers.toMap
      .filterKeys(k => linkSet.contains(k))
    AwaitingConnections(connSet, linkSet.size - connSet.size)
  }

  def receive: Actor.Receive = {
    // a copy of the AssignedRank message that is also sent to the worker
    case AssignedRank(rank, tree_neighbors, ring, parent) =>
      // the workers that the worker of given `rank` depends on:
      // a worker of rank K only depends on workers with ranks smaller than K.
      val dependentWorkers = (tree_neighbors.toSet ++ Set(ring._1, ring._2))
        .filter{ r => r != -1 && r < rank}

      log.debug(s"Rank $rank connected, dependencies: $dependentWorkers")
      dependencyMap.put(rank, dependentWorkers)

    case RequestAwaitConnWorkers(rank, toConnectSet) =>
      val promise = Promise[AwaitingConnections]()

      assert(dependencyMap.contains(rank))

      val updatedDependency = dependencyMap(rank) diff startedWorkers
      if (updatedDependency.isEmpty) {
        // all dependencies are satisfied
        log.debug(s"Rank $rank has all dependencies satisfied.")
        promise.success(awaitingWorkers(toConnectSet))
      } else {
        log.debug(s"Rank $rank's request for AwaitConnWorkers is pending fulfillment.")
        // promise is pending fulfillment due to unresolved dependencies
        pendingFulfillment.put(rank, Fulfillment(toConnectSet, promise))
      }

      sender() ! promise.future

    case WorkerStarted(_, started, awaitingAcceptance) =>
      startedWorkers.add(started)
      if (awaitingAcceptance > 0) {
        awaitConnWorkers.put(started, sender())
      }

      // remove the started rank from all dependencies.
      dependencyMap.remove(started)
      dependencyMap.foreach { case (r, dset) =>
        val updatedDependency = dset diff startedWorkers
        // fulfill the pending promise if all dependencies are met (i.e., have started.)
        if (updatedDependency.isEmpty) {
          log.debug(s"Rank $r has all dependencies satisfied.")
          pendingFulfillment.remove(r).map{
            case Fulfillment(toConnectSet, promise) =>
              promise.success(awaitingWorkers(toConnectSet))
          }
        }

        dependencyMap.update(r, updatedDependency)
      }

    case DropFromWaitingList(rank) =>
      assert(awaitConnWorkers.remove(rank).isDefined)

    case Terminated(ref) =>
      if (ref.equals(handler)) {
        context.stop(self)
      }
  }
}
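
// --------------------------------------------------------------------------
// Illustration only (not part of the original file): a minimal sketch of the
// dependency rule the resolver applies. Given the tree neighbors and the ring
// neighbors carried by an AssignedRank message, a worker of rank K only waits
// on linked workers whose rank is strictly smaller than K; -1 entries (meaning
// "no neighbor") are ignored. The object and method names are hypothetical.
private[rabit] object DependencyRuleSketch {
  def dependenciesOf(rank: Int, treeNeighbors: Seq[Int], ring: (Int, Int)): Set[Int] =
    (treeNeighbors.toSet ++ Set(ring._1, ring._2)).filter(r => r != -1 && r < rank)

  // Example: rank 2 with tree neighbors (0, 3) and ring (1, 3) depends on ranks 0 and 1.
  // dependenciesOf(2, Seq(0, 3), (1, 3)) == Set(0, 1)
}
// --------------------------------------------------------------------------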

private[scala] object RabitTrackerHandler {
  // Messages sent by RabitTracker to this RabitTrackerHandler actor
  trait TrackerControlMessage
  case object RequestCompletionFuture extends TrackerControlMessage
  case object RequestBoundFuture extends TrackerControlMessage
  // Start the Rabit tracker at the given socket address, awaiting worker connections.
  // All workers must connect to the tracker before connectionTimeout; otherwise the tracker will
  // shut down due to timeout.
  case class StartTracker(addr: InetSocketAddress,
                          maxPortTrials: Int,
                          connectionTimeout: Duration) extends TrackerControlMessage
  // Interrupt the tracker handler due to an uncaught exception thrown by the thread acting as
  // the driver for the distributed training.
  case class InterruptTracker(e: Throwable) extends TrackerControlMessage

  def props(numWorkers: Int): Props =
    Props(new RabitTrackerHandler(numWorkers))
}
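
// --------------------------------------------------------------------------
// Illustration only (not part of the original file): a rough sketch, under
// assumed ActorSystem and timeout settings, of how a driver thread might
// exercise these control messages with Akka's ask pattern. The actual
// RabitTracker wrapper in this package wires the handler up differently; the
// object name and all literal values below are placeholders.
object TrackerControlSketch {
  import java.net.InetSocketAddress

  import akka.actor.ActorSystem
  import akka.pattern.ask
  import akka.util.Timeout

  import scala.concurrent.{Await, Future}
  import scala.concurrent.duration._

  def runSketch(numWorkers: Int): Unit = {
    val system = ActorSystem("RabitTrackerSketch")
    implicit val askTimeout: Timeout = Timeout(30.seconds)

    val handler = system.actorOf(RabitTrackerHandler.props(numWorkers), "handler")
    // Port 0 asks the handler to bind to a random ephemeral port.
    handler ? RabitTrackerHandler.StartTracker(
      new InetSocketAddress("0.0.0.0", 0), maxPortTrials = 10, connectionTimeout = 10.minutes)

    // The handler answers RequestCompletionFuture with a future that completes
    // once every worker has sent its "shutdown" command.
    val completion = Await.result(
      (handler ? RabitTrackerHandler.RequestCompletionFuture).mapTo[Future[Int]],
      askTimeout.duration)
    Await.result(completion, Duration.Inf)
    // Shut down the ActorSystem afterwards (exact API depends on the Akka version in use).
  }
}
// --------------------------------------------------------------------------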
@ -1,467 +0,0 @@
/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.rabit.handler

import java.nio.{ByteBuffer, ByteOrder}

import akka.io.Tcp
import akka.actor._
import akka.util.ByteString
import ml.dmlc.xgboost4j.scala.rabit.util.{AssignedRank, RabitTrackerHelpers}

import scala.concurrent.{Await, Future}
import scala.concurrent.duration._
import scala.util.Try

/**
 * Actor to handle socket communication from a worker node.
 * To handle fragmentation in received data, this class acts like an FSM
 * (finite-state machine) to keep track of the internal states.
 *
 * @param host IP address of the remote worker
 * @param worldSize number of total workers
 * @param tracker the RabitTrackerHandler actor reference
 */
private[scala] class RabitWorkerHandler(host: String, worldSize: Int, tracker: ActorRef,
                                        connection: ActorRef)
  extends FSM[RabitWorkerHandler.State, RabitWorkerHandler.DataStruct]
    with ActorLogging with Stash {

  import RabitWorkerHandler._
  import RabitTrackerHelpers._

  private[this] var rank: Int = 0
  private[this] var port: Int = 0

  // indicates whether the connection is transient (like "print" or "shutdown")
  private[this] var transient: Boolean = false
  private[this] var peerClosed: Boolean = false

  // number of workers pending acceptance of the current worker
  private[this] var awaitingAcceptance: Int = 0
  private[this] var neighboringWorkers = Set.empty[Int]

  // TODO: use a single memory allocation to host all buffers,
  // including the transient ones for writing.
  private[this] val readBuffer = ByteBuffer.allocate(4096)
    .order(ByteOrder.nativeOrder())
  // in case the received message is longer than needed,
  // stash the spilled-over part in this buffer, and send
  // it to self when a transition occurs.
  private[this] val spillOverBuffer = ByteBuffer.allocate(4096)
    .order(ByteOrder.nativeOrder())
  // when setup is complete, need to notify peer handlers
  // to reduce the awaiting-connection counter.
  private[this] var pendingAcknowledgement: Option[AcknowledgeAcceptance] = None

  private def resetBuffers(): Unit = {
    readBuffer.clear()
    if (spillOverBuffer.position() > 0) {
      spillOverBuffer.flip()
      self ! Tcp.Received(ByteString.fromByteBuffer(spillOverBuffer))
      spillOverBuffer.clear()
    }
  }

  private def stashSpillOver(buf: ByteBuffer): Unit = {
    if (buf.remaining() > 0) spillOverBuffer.put(buf)
  }

  def getNeighboringWorkers: Set[Int] = neighboringWorkers

  def decodeCommand(buffer: ByteBuffer): TrackerCommand = {
    val readBuffer = buffer.duplicate().order(ByteOrder.nativeOrder())
    readBuffer.flip()

    val rank = readBuffer.getInt()
    val worldSize = readBuffer.getInt()
    val jobId = readBuffer.getString

    val command = readBuffer.getString
    val trackerCommand = command match {
      case "start" => WorkerStart(rank, worldSize, jobId)
      case "shutdown" =>
        transient = true
        WorkerShutdown(rank, worldSize, jobId)
      case "recover" =>
        require(rank >= 0, "Invalid rank for recovering worker.")
        WorkerRecover(rank, worldSize, jobId)
      case "print" =>
        transient = true
        WorkerTrackerPrint(rank, worldSize, jobId, readBuffer.getString)
    }

    stashSpillOver(readBuffer)
    trackerCommand
  }

  startWith(AwaitingHandshake, DataStruct())

  when(AwaitingHandshake) {
    case Event(Tcp.Received(magic), _) =>
      assert(magic.length == 4)
      val purportedMagic = magic.asNativeOrderByteBuffer.getInt
      assert(purportedMagic == MAGIC_NUMBER, s"invalid magic number $purportedMagic from $host")

      // echo back the magic number
      connection ! Tcp.Write(magic)
      goto(AwaitingCommand) using StructTrackerCommand
  }

  when(AwaitingCommand) {
    case Event(Tcp.Received(bytes), validator) =>
      bytes.asByteBuffers.foreach { buf => readBuffer.put(buf) }
      if (validator.verify(readBuffer)) {
        Try(decodeCommand(readBuffer)) match {
          case scala.util.Success(decodedCommand) =>
            tracker ! decodedCommand
          case scala.util.Failure(th: java.nio.BufferUnderflowException) =>
            // BufferUnderflowException would occur if the message to print has not arrived yet.
            // Do nothing, wait for the next Tcp.Received event.
          case scala.util.Failure(th: Throwable) => throw th
        }
      }

      stay
    // when a rank is assigned to a worker, send the encoded rank information
    // back to the worker over the Tcp socket.
    case Event(aRank @ AssignedRank(assignedRank, neighbors, ring, parent), _) =>
      log.debug(s"Assigned rank [$assignedRank] for $host, T: $neighbors, R: $ring, P: $parent")

      rank = assignedRank
      // ranks from the ring
      val ringRanks = List(
        // ringPrev
        if (ring._1 != -1 && ring._1 != rank) ring._1 else -1,
        // ringNext
        if (ring._2 != -1 && ring._2 != rank) ring._2 else -1
      )

      // update the set of all workers linked to the current worker.
      neighboringWorkers = neighbors.toSet ++ ringRanks.filterNot(_ == -1).toSet

      connection ! Tcp.Write(ByteString.fromByteBuffer(aRank.toByteBuffer(worldSize)))
      // to prevent reading before state transition
      connection ! Tcp.SuspendReading
      goto(BuildingLinkMap) using StructNodes
  }

  when(BuildingLinkMap) {
    case Event(Tcp.Received(bytes), validator) =>
      bytes.asByteBuffers.foreach { buf =>
        readBuffer.put(buf)
      }

      if (validator.verify(readBuffer)) {
        readBuffer.flip()
        // for a freshly started worker, numConnected should be 0.
        val numConnected = readBuffer.getInt()
        val toConnectSet = neighboringWorkers.diff(
          (0 until numConnected).map { index => readBuffer.getInt() }.toSet)

        // check which workers are currently awaiting connections
        tracker ! RequestAwaitConnWorkers(rank, toConnectSet)
      }
      stay

    // got a Future from the tracker (resolver) about workers that are
    // currently awaiting connections (particularly from this node).
    case Event(future: Future[_], _) =>
      // block execution until all dependencies for the current worker are resolved.
      Await.result(future, 1 minute).asInstanceOf[AwaitingConnections] match {
        // numNotReachable is the number of workers that currently
        // cannot be connected to (pending connection or setup). Instead, this worker will AWAIT
        // connections from those currently non-reachable nodes in the future.
        case AwaitingConnections(waitConnNodes, numNotReachable) =>
          log.debug(s"Rank $rank needs to connect to: $waitConnNodes, # bad: $numNotReachable")
          val buf = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder())
          buf.putInt(waitConnNodes.size).putInt(numNotReachable)
          buf.flip()

          // cache this message until the final state (SetupComplete)
          pendingAcknowledgement = Some(AcknowledgeAcceptance(
            waitConnNodes, numNotReachable))

          connection ! Tcp.Write(ByteString.fromByteBuffer(buf))
          if (waitConnNodes.isEmpty) {
            connection ! Tcp.SuspendReading
            goto(AwaitingErrorCount)
          }
          else {
            waitConnNodes.foreach { case (peerRank, peerRef) =>
              peerRef ! RequestWorkerHostPort
            }

            // a countdown for DivulgedWorkerHostPort messages.
            stay using DataStruct(Seq.empty[DataField], waitConnNodes.size - 1)
          }
      }

    case Event(DivulgedWorkerHostPort(peerRank, peerHost, peerPort), data) =>
      val hostBytes = peerHost.getBytes()
      val buffer = ByteBuffer.allocate(4 * 3 + hostBytes.length)
        .order(ByteOrder.nativeOrder())
      buffer.putInt(peerHost.length).put(hostBytes)
        .putInt(peerPort).putInt(peerRank)

      buffer.flip()
      connection ! Tcp.Write(ByteString.fromByteBuffer(buffer))

      if (data.counter == 0) {
        // to prevent reading before state transition
        connection ! Tcp.SuspendReading
        goto(AwaitingErrorCount)
      }
      else {
        stay using data.decrement()
      }
  }

  when(AwaitingErrorCount) {
    case Event(Tcp.Received(numErrors), _) =>
      val buf = numErrors.asNativeOrderByteBuffer

      buf.getInt match {
        case 0 =>
          stashSpillOver(buf)
          goto(AwaitingPortNumber)
        case _ =>
          stashSpillOver(buf)
          goto(BuildingLinkMap) using StructNodes
      }
  }

  when(AwaitingPortNumber) {
    case Event(Tcp.Received(assignedPort), _) =>
      assert(assignedPort.length == 4)
      port = assignedPort.asNativeOrderByteBuffer.getInt
      log.debug(s"Rank $rank listening @ $host:$port")
      // wait until the worker closes the connection.
      if (peerClosed) goto(SetupComplete) else stay

    case Event(Tcp.PeerClosed, _) =>
      peerClosed = true
      if (port == 0) stay else goto(SetupComplete)
  }

  when(SetupComplete) {
    case Event(ReduceWaitCount(count: Int), _) =>
      awaitingAcceptance -= count
      // check peerClosed to avoid prematurely stopping this actor (which sends RST to the worker)
      if (awaitingAcceptance == 0 && peerClosed) {
        tracker ! DropFromWaitingList(rank)
        // no longer needed.
        context.stop(self)
      }
      stay

    case Event(AcknowledgeAcceptance(peers, numBad), _) =>
      awaitingAcceptance = numBad
      tracker ! WorkerStarted(host, rank, awaitingAcceptance)
      peers.values.foreach { peer =>
        peer ! ReduceWaitCount(1)
      }

      if (awaitingAcceptance == 0 && peerClosed) self ! PoisonPill

      stay

    // can only divulge the complete host and port information
    // when this worker is declared fully connected (otherwise
    // the port information is still missing).
    case Event(RequestWorkerHostPort, _) =>
      sender() ! DivulgedWorkerHostPort(rank, host, port)
      stay
  }

  onTransition {
    // reset buffers when the state transitions, as the data becomes stale
    case _ -> SetupComplete =>
      connection ! Tcp.ResumeReading
      resetBuffers()
      if (pendingAcknowledgement.isDefined) {
        self ! pendingAcknowledgement.get
      }
    case _ =>
      connection ! Tcp.ResumeReading
      resetBuffers()
  }

  // default message handler
  whenUnhandled {
    case Event(Tcp.PeerClosed, _) =>
      peerClosed = true
      if (transient) context.stop(self)
      stay
  }
}
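
// --------------------------------------------------------------------------
// Illustration only (not part of the original file): a bare-bones sketch of
// the first handshake step this FSM expects from a connecting worker (the
// AwaitingHandshake state). The worker writes the 4-byte magic number (0xff99)
// in native byte order and the tracker echoes it back. Host/port and the
// object name are placeholders; real workers use the Rabit client library,
// not this snippet.
object HandshakeSketch {
  import java.net.Socket
  import java.nio.{ByteBuffer, ByteOrder}

  def handshake(trackerHost: String, trackerPort: Int): Boolean = {
    val socket = new Socket(trackerHost, trackerPort)
    try {
      val out = socket.getOutputStream
      val in = socket.getInputStream

      // send the magic number in native byte order, as the handler asserts on it
      val magic = ByteBuffer.allocate(4).order(ByteOrder.nativeOrder())
      magic.putInt(RabitWorkerHandler.MAGIC_NUMBER).flip()
      out.write(magic.array())
      out.flush()

      // the tracker echoes the same 4 bytes back before expecting a command
      val echo = new Array[Byte](4)
      var read = 0
      while (read < 4) {
        val n = in.read(echo, read, 4 - read)
        if (n < 0) throw new java.io.EOFException("tracker closed the connection")
        read += n
      }
      ByteBuffer.wrap(echo).order(ByteOrder.nativeOrder()).getInt == RabitWorkerHandler.MAGIC_NUMBER
    } finally {
      socket.close()
    }
  }
}
// --------------------------------------------------------------------------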

private[scala] object RabitWorkerHandler {
  val MAGIC_NUMBER = 0xff99

  // Finite states of this actor, which acts like an FSM.
  // The following states are defined in the order in which the FSM progresses.
  sealed trait State

  // [1] Initial state, awaiting the worker to send the magic number per protocol.
  case object AwaitingHandshake extends State
  // [2] Awaiting the worker to send a command (start/print/recover/shutdown etc.)
  case object AwaitingCommand extends State
  // [3] Brokers connections between workers per ring/tree/parent link map.
  case object BuildingLinkMap extends State
  // [4] A transient state in which the worker reports the number of errors in establishing
  // connections to other peer workers. If there are no errors, transition to the next state.
  case object AwaitingErrorCount extends State
  // [5] Awaiting the worker to report its port number for accepting connections from peer workers.
  // This port number information is later forwarded to linked workers.
  case object AwaitingPortNumber extends State
  // [6] Final state after completing the setup with the connecting worker. At this stage, the
  // worker will have closed the Tcp connection. The actor remains alive to handle messages from
  // peer actors representing workers with pending setups.
  case object SetupComplete extends State

  sealed trait DataField
  case object IntField extends DataField
  // a string preceded by an integer holding its length
  case object StringField extends DataField
  case object IntSeqField extends DataField

  object DataStruct {
    def apply(): DataStruct = DataStruct(Seq.empty[DataField], 0)
  }

  // Internal data pertaining to an individual state, used to verify the validity of packets sent
  // by workers.
  case class DataStruct(fields: Seq[DataField], counter: Int) {
    /**
     * Validate whether the provided buffer is complete (i.e., contains
     * all data fields specified for this DataStruct).
     *
     * @param buf a byte buffer containing received data.
     */
    def verify(buf: ByteBuffer): Boolean = {
      if (fields.isEmpty) return true

      val dupBuf = buf.duplicate().order(ByteOrder.nativeOrder())
      dupBuf.flip()

      Try(fields.foldLeft(true) {
        case (complete, field) =>
          val remBytes = dupBuf.remaining()
          complete && (remBytes > 0) && (remBytes >= (field match {
            case IntField =>
              dupBuf.position(dupBuf.position() + 4)
              4
            case StringField =>
              val strLen = dupBuf.getInt
              dupBuf.position(dupBuf.position() + strLen)
              4 + strLen
            case IntSeqField =>
              val seqLen = dupBuf.getInt
              dupBuf.position(dupBuf.position() + seqLen * 4)
              4 + seqLen * 4
          }))
      }).getOrElse(false)
    }

    def increment(): DataStruct = DataStruct(fields, counter + 1)
    def decrement(): DataStruct = DataStruct(fields, counter - 1)
  }

  val StructNodes = DataStruct(List(IntSeqField), 0)
  val StructTrackerCommand = DataStruct(List(
    IntField, IntField, StringField, StringField
  ), 0)

  // ---- Messages between RabitTrackerHandler and RabitTrackerConnectionHandler ----

  // RabitWorkerHandler --> RabitTrackerHandler
  sealed trait RabitWorkerRequest
  // RabitWorkerHandler <-- RabitTrackerHandler
  sealed trait RabitWorkerResponse

  // Representations of decoded worker commands.
  abstract class TrackerCommand(val command: String) extends RabitWorkerRequest {
    def rank: Int
    def worldSize: Int
    def jobId: String

    def encode: ByteString = {
      val buf = ByteBuffer.allocate(4 * 4 + jobId.length + command.length)
        .order(ByteOrder.nativeOrder())

      buf.putInt(rank).putInt(worldSize).putInt(jobId.length).put(jobId.getBytes())
        .putInt(command.length).put(command.getBytes()).flip()

      ByteString.fromByteBuffer(buf)
    }
  }

  case class WorkerStart(rank: Int, worldSize: Int, jobId: String)
    extends TrackerCommand("start")
  case class WorkerShutdown(rank: Int, worldSize: Int, jobId: String)
    extends TrackerCommand("shutdown")
  case class WorkerRecover(rank: Int, worldSize: Int, jobId: String)
    extends TrackerCommand("recover")
  case class WorkerTrackerPrint(rank: Int, worldSize: Int, jobId: String, msg: String)
    extends TrackerCommand("print") {

    override def encode: ByteString = {
      val buf = ByteBuffer.allocate(4 * 5 + jobId.length + command.length + msg.length)
        .order(ByteOrder.nativeOrder())

      buf.putInt(rank).putInt(worldSize).putInt(jobId.length).put(jobId.getBytes())
        .putInt(command.length).put(command.getBytes())
        .putInt(msg.length).put(msg.getBytes()).flip()

      ByteString.fromByteBuffer(buf)
    }
  }

  // Request to remove the worker of given rank from the list of workers awaiting peer connections.
  case class DropFromWaitingList(rank: Int) extends RabitWorkerRequest
  // Notify the tracker that the worker of given rank has finished setup and started.
  case class WorkerStarted(host: String, rank: Int, awaitingAcceptance: Int)
    extends RabitWorkerRequest
  // Request the set of workers to connect to, according to the LinkMap structure.
  case class RequestAwaitConnWorkers(rank: Int, toConnectSet: Set[Int])
    extends RabitWorkerRequest

  // Response from the tracker: the set of worker nodes to connect to.
  case class AwaitingConnections(workers: Map[Int, ActorRef], numBad: Int)
    extends RabitWorkerResponse

  // ---- Messages between ConnectionHandler actors ----
  sealed trait IntraWorkerMessage

  // Notify neighboring workers to decrease the counter of awaiting workers by `count`.
  case class ReduceWaitCount(count: Int) extends IntraWorkerMessage
  // Request host and port information from peer ConnectionHandler actors (acting on behalf of
  // connecting workers). This message will be brokered by RabitTrackerHandler.
  case object RequestWorkerHostPort extends IntraWorkerMessage
  // Response to the above request
  case class DivulgedWorkerHostPort(rank: Int, host: String, port: Int) extends IntraWorkerMessage
  // A reminder to send ReduceWaitCount messages once the actor is in state "SetupComplete".
  case class AcknowledgeAcceptance(peers: Map[Int, ActorRef], numBad: Int)
    extends IntraWorkerMessage

  // ---- End of message definitions ----

  def props(host: String, worldSize: Int, tracker: ActorRef, connection: ActorRef): Props = {
    Props(new RabitWorkerHandler(host, worldSize, tracker, connection))
  }
}
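
// --------------------------------------------------------------------------
// Illustration only (not part of the original file): a small round-trip sketch
// showing how the wire format produced by TrackerCommand.encode lines up with
// the field layout that StructTrackerCommand.verify expects (int rank,
// int worldSize, length-prefixed jobId, length-prefixed command). The object
// name and the literal values are placeholders; rank -1 here stands for "not
// yet assigned".
object CommandRoundTripSketch {
  import java.nio.{ByteBuffer, ByteOrder}

  import RabitWorkerHandler._

  def sketch(): Boolean = {
    val encoded = WorkerStart(rank = -1, worldSize = 4, jobId = "NULL").encode

    // Simulate the handler's read buffer: bytes accumulated, not yet flipped.
    val readBuffer = ByteBuffer.allocate(4096).order(ByteOrder.nativeOrder())
    encoded.asByteBuffers.foreach(b => readBuffer.put(b))

    // verify() walks the declared fields over a duplicate of the buffer and
    // returns true once all of them are present.
    StructTrackerCommand.verify(readBuffer)
  }
}
// --------------------------------------------------------------------------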
@ -1,136 +0,0 @@
/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.rabit.util

import java.nio.{ByteBuffer, ByteOrder}

/**
 * The rank assigned to a connecting Rabit worker, along with the ranks of its linked peer
 * workers, which are critical for performing Allreduce.
 * When RabitWorkerHandler delegates "start" or "recover" commands from the connecting worker
 * client, RabitTrackerHandler uses LinkMap to figure out the linkage relationships and responds
 * with this class as a message, which is later encoded to a byte string and sent over the socket
 * connection to the worker client.
 *
 * @param rank assigned rank (ranked by worker connection order: the first worker connecting to
 *             the tracker is assigned rank 0, the second rank 1, etc.)
 * @param neighbors ranks of neighboring workers in a tree map.
 * @param ring ranks of neighboring workers in a ring map.
 * @param parent rank of the parent worker.
 */
private[rabit] case class AssignedRank(rank: Int, neighbors: Seq[Int],
                                       ring: (Int, Int), parent: Int) {
  /**
   * Encode the AssignedRank message into a byte sequence for socket communication with the Rabit
   * worker client.
   * @param worldSize the number of total distributed workers. Must match `numWorkers` used in
   *                  LinkMap.
   * @return a ByteBuffer containing the encoded data.
   */
  def toByteBuffer(worldSize: Int): ByteBuffer = {
    val buffer = ByteBuffer.allocate(4 * (neighbors.length + 6)).order(ByteOrder.nativeOrder())
    buffer.putInt(rank).putInt(parent).putInt(worldSize).putInt(neighbors.length)
    // neighbors in tree structure
    neighbors.foreach { n => buffer.putInt(n) }
    buffer.putInt(if (ring._1 != -1 && ring._1 != rank) ring._1 else -1)
    buffer.putInt(if (ring._2 != -1 && ring._2 != rank) ring._2 else -1)

    buffer.flip()
    buffer
  }
}
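
// --------------------------------------------------------------------------
// Illustration only (not part of the original file): a decoding sketch that
// mirrors the layout written by AssignedRank.toByteBuffer, as a worker-side
// reader would consume it: rank, parent, worldSize, number of tree neighbors,
// the neighbor ranks, then the ring-prev and ring-next ranks (-1 when absent).
// The object and method names are hypothetical.
private[rabit] object AssignedRankLayoutSketch {
  import java.nio.ByteBuffer

  def decode(buffer: ByteBuffer): (Int, Int, Int, Seq[Int], (Int, Int)) = {
    val rank = buffer.getInt()
    val parent = buffer.getInt()
    val worldSize = buffer.getInt()
    val numNeighbors = buffer.getInt()
    val neighbors = (0 until numNeighbors).map(_ => buffer.getInt())
    val ring = (buffer.getInt(), buffer.getInt())
    (rank, parent, worldSize, neighbors, ring)
  }

  // Example: decode(AssignedRank(0, Seq(1, 2), (3, 1), -1).toByteBuffer(4))
  // yields (0, -1, 4, Vector(1, 2), (3, 1)).
}
// --------------------------------------------------------------------------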

private[rabit] class LinkMap(numWorkers: Int) {
  private def getNeighbors(rank: Int): Seq[Int] = {
    val rank1 = rank + 1
    Vector(rank1 / 2 - 1, rank1 * 2 - 1, rank1 * 2).filter { r =>
      r >= 0 && r < numWorkers
    }
  }

  /**
   * Construct a ring structure that tends to share nodes with the tree.
   *
   * @param treeMap
   * @param parentMap
   * @param rank
   * @return Seq[Int] instance starting from rank.
   */
  private def constructShareRing(treeMap: Map[Int, Seq[Int]],
                                 parentMap: Map[Int, Int],
                                 rank: Int = 0): Seq[Int] = {
    treeMap(rank).toSet - parentMap(rank) match {
      case emptySet if emptySet.isEmpty =>
        List(rank)
      case connectionSet =>
        connectionSet.zipWithIndex.foldLeft(List(rank)) {
          case (ringSeq, (v, cnt)) =>
            val vConnSeq = constructShareRing(treeMap, parentMap, v)
            vConnSeq match {
              case vconn if vconn.size == cnt + 1 =>
                ringSeq ++ vconn.reverse
              case vconn =>
                ringSeq ++ vconn
            }
        }
    }
  }
  /**
   * Construct a ring connection used to recover local data.
   *
   * @param treeMap
   * @param parentMap
   */
  private def constructRingMap(treeMap: Map[Int, Seq[Int]], parentMap: Map[Int, Int]) = {
    assert(parentMap(0) == -1)

    val sharedRing = constructShareRing(treeMap, parentMap, 0).toVector
    assert(sharedRing.length == treeMap.size)

    (0 until numWorkers).map { r =>
      val rPrev = (r + numWorkers - 1) % numWorkers
      val rNext = (r + 1) % numWorkers
      sharedRing(r) -> (sharedRing(rPrev), sharedRing(rNext))
    }.toMap
  }

  private[this] val treeMap_ = (0 until numWorkers).map { r => r -> getNeighbors(r) }.toMap
  private[this] val parentMap_ = (0 until numWorkers).map{ r => r -> ((r + 1) / 2 - 1) }.toMap
  private[this] val ringMap_ = constructRingMap(treeMap_, parentMap_)
  val rMap_ = (0 until (numWorkers - 1)).foldLeft((Map(0 -> 0), 0)) {
    case ((rmap, k), i) =>
      val kNext = ringMap_(k)._2
      (rmap ++ Map(kNext -> (i + 1)), kNext)
  }._1

  val ringMap = ringMap_.map {
    case (k, (v0, v1)) => rMap_(k) -> (rMap_(v0), rMap_(v1))
  }
  val treeMap = treeMap_.map {
    case (k, vSeq) => rMap_(k) -> vSeq.map{ v => rMap_(v) }
  }
  val parentMap = parentMap_.map {
    case (k, v) if k == 0 =>
      rMap_(k) -> -1
    case (k, v) =>
      rMap_(k) -> rMap_(v)
  }

  def assignRank(rank: Int): AssignedRank = {
    AssignedRank(rank, treeMap(rank), ringMap(rank), parentMap(rank))
  }
}
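
// --------------------------------------------------------------------------
// Illustration only (not part of the original file): what LinkMap produces for
// a small cluster. With 4 workers, rank 0 is the root of the implied binary
// tree and every rank also gets (prev, next) ring neighbors used for data
// recovery; the exact neighbor ranks depend on the ring remapping above. The
// object name is hypothetical.
private[rabit] object LinkMapSketch {
  def sketch(): Unit = {
    val linkMap = new LinkMap(4)
    val assigned = linkMap.assignRank(0)
    // assigned.rank == 0, assigned.parent == -1 (the root has no parent),
    // assigned.neighbors are rank 0's tree neighbors, and
    // assigned.ring is its (prev, next) pair on the recovery ring.
    println(assigned)
  }
}
// --------------------------------------------------------------------------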
@ -1,39 +0,0 @@
/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.rabit.util

import java.nio.{ByteOrder, ByteBuffer}
import akka.util.ByteString

private[rabit] object RabitTrackerHelpers {
  implicit class ByteStringHelplers(bs: ByteString) {
    // Java by default uses big endian. Enforce native endian so that
    // the byte order is consistent with the workers.
    def asNativeOrderByteBuffer: ByteBuffer = {
      bs.asByteBuffer.order(ByteOrder.nativeOrder())
    }
  }

  implicit class ByteBufferHelpers(buf: ByteBuffer) {
    def getString: String = {
      val len = buf.getInt()
      val stringBuffer = ByteBuffer.allocate(len).order(ByteOrder.nativeOrder())
      buf.get(stringBuffer.array(), 0, len)
      new String(stringBuffer.array(), "utf-8")
    }
  }
}
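
// --------------------------------------------------------------------------
// Illustration only (not part of the original file): how these implicit
// helpers are meant to be used when reading a length-prefixed string out of a
// ByteString received from a worker. The object name and the "start" payload
// are placeholders.
private[rabit] object HelpersUsageSketch {
  import java.nio.{ByteBuffer, ByteOrder}

  import akka.util.ByteString

  import RabitTrackerHelpers._

  def sketch(): String = {
    // Build a length-prefixed "start" string in native byte order, as a worker would.
    val payload = ByteBuffer.allocate(4 + 5).order(ByteOrder.nativeOrder())
    payload.putInt(5).put("start".getBytes()).flip()

    // asNativeOrderByteBuffer re-wraps the ByteString with native byte order,
    // and getString consumes the int length followed by that many bytes.
    ByteString.fromByteBuffer(payload).asNativeOrderByteBuffer.getString // == "start"
  }
}
// --------------------------------------------------------------------------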