diff --git a/src/common/random.h b/src/common/random.h
index 66b73fdbc..00b7046de 100644
--- a/src/common/random.h
+++ b/src/common/random.h
@@ -89,9 +89,10 @@ class ColumnSampler {
   float colsample_bylevel_{1.0f};
   float colsample_bytree_{1.0f};
   float colsample_bynode_{1.0f};
+  GlobalRandomEngine rng_;
 
   std::shared_ptr<std::vector<int>> ColSample
-    (std::shared_ptr<std::vector<int>> p_features, float colsample) const {
+    (std::shared_ptr<std::vector<int>> p_features, float colsample) {
     if (colsample == 1.0f) return p_features;
     const auto& features = *p_features;
     CHECK_GT(features.size(), 0);
@@ -100,17 +101,24 @@
     auto& new_features = *p_new_features;
     new_features.resize(features.size());
     std::copy(features.begin(), features.end(), new_features.begin());
-    std::shuffle(new_features.begin(), new_features.end(), common::GlobalRandom());
+    std::shuffle(new_features.begin(), new_features.end(), rng_);
     new_features.resize(n);
     std::sort(new_features.begin(), new_features.end());
 
-    // ensure that new_features are the same across ranks
-    rabit::Broadcast(&new_features, 0);
-
     return p_new_features;
   }
 
  public:
+  /**
+   * \brief Column sampler constructor.
+   * \note This constructor synchronizes the RNG seed across processes.
+   */
+  ColumnSampler() {
+    uint32_t seed = common::GlobalRandom()();
+    rabit::Broadcast(&seed, sizeof(seed), 0);
+    rng_.seed(seed);
+  }
+
   /**
    * \brief Initialise this object before use.
    *
@@ -153,6 +161,9 @@ class ColumnSampler {
    * \return The sampled feature set.
    * \note If colsample_bynode_ < 1.0, this method creates a new feature set each time it
    * is called. Therefore, it should be called only once per node.
+   * \note With distributed xgboost, this function must be called exactly once for the
+   * construction of each tree node, and must be called the same number of times in each
+   * process and with the same parameters to return the same feature set across processes.
    */
   std::shared_ptr<std::vector<int>> GetFeatureSet(int depth) {
     if (colsample_bylevel_ == 1.0f && colsample_bynode_ == 1.0f) {
diff --git a/tests/ci_build/test_mgpu.sh b/tests/ci_build/test_mgpu.sh
index 2dfafcc2e..a1b56549b 100755
--- a/tests/ci_build/test_mgpu.sh
+++ b/tests/ci_build/test_mgpu.sh
@@ -7,5 +7,5 @@
 cd ..
 pytest -v -s --fulltrace -m "(not slow) and mgpu" tests/python-gpu
 ./testxgboost --gtest_filter=*.MGPU_*
-cd tests/distributed-gpu
-./runtests-gpu.sh
\ No newline at end of file
+cd tests/distributed
+./runtests-gpu.sh
diff --git a/tests/distributed-gpu/runtests-gpu.sh b/tests/distributed-gpu/runtests-gpu.sh
deleted file mode 100755
index e3fa8a0d3..000000000
--- a/tests/distributed-gpu/runtests-gpu.sh
+++ /dev/null
@@ -1,19 +0,0 @@
-#!/bin/bash
-
-rm -f *.model*
-
-echo -e "\n ====== 1. Basic distributed-gpu test with Python: 4 workers; 1 GPU per worker ====== \n"
-PYTHONPATH=../../python-package/ python ../../dmlc-core/tracker/dmlc-submit --cluster=local --num-workers=4 \
-    python test_gpu_basic_1x4.py
-
-echo -e "\n ====== 2. Basic distributed-gpu test with Python: 2 workers; 2 GPUs per worker ====== \n"
-PYTHONPATH=../../python-package/ python ../../dmlc-core/tracker/dmlc-submit --cluster=local --num-workers=2 \
-    python test_gpu_basic_2x2.py
-
-echo -e "\n ====== 3. Basic distributed-gpu test with Python: 2 workers; Rank 0: 1 GPU, Rank 1: 3 GPUs ====== \n"
-PYTHONPATH=../../python-package/ python ../../dmlc-core/tracker/dmlc-submit --cluster=local --num-workers=2 \
-    python test_gpu_basic_asym.py
-
-echo -e "\n ====== 4. Basic distributed-gpu test with Python: 1 worker; 4 GPUs per worker ====== \n"
-PYTHONPATH=../../python-package/ python ../../dmlc-core/tracker/dmlc-submit --cluster=local --num-workers=1 \
-    python test_gpu_basic_4x1.py
\ No newline at end of file
diff --git a/tests/distributed-gpu/test_gpu_basic_1x4.py b/tests/distributed-gpu/test_gpu_basic_1x4.py
deleted file mode 100644
index d325a167a..000000000
--- a/tests/distributed-gpu/test_gpu_basic_1x4.py
+++ /dev/null
@@ -1,51 +0,0 @@
-#!/usr/bin/python
-import xgboost as xgb
-import time
-from collections import OrderedDict
-
-# Always call this before using distributed module
-xgb.rabit.init()
-rank = xgb.rabit.get_rank()
-world = xgb.rabit.get_world_size()
-
-# Load file, file will be automatically sharded in distributed mode.
-dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
-dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')
-
-# Specify parameters via map, definition are same as c++ version
-param = {'n_gpus': 1, 'gpu_id': rank, 'tree_method': 'gpu_hist', 'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic' }
-
-# Specify validations set to watch performance
-watchlist = [(dtest,'eval'), (dtrain,'train')]
-num_round = 20
-
-# Run training, all the features in training API is available.
-# Currently, this script only support calling train once for fault recovery purpose.
-bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)
-
-# Have each worker save its model
-model_name = "test.model.1x4." + str(rank)
-bst.dump_model(model_name, with_stats=True); time.sleep(2)
-xgb.rabit.tracker_print("Finished training\n")
-
-fail = False
-if (rank == 0):
-    for i in range(0, world):
-        model_name_root = "test.model.1x4." + str(i)
-        for j in range(0, world):
-            if i != j:
-                with open(model_name_root, 'r') as model_root:
-                    model_name_rank = "test.model.1x4." + str(j)
-                    with open(model_name_rank, 'r') as model_rank:
-                        diff = set(model_root).difference(model_rank)
-                        if len(diff) != 0:
-                            fail = True
-                            xgb.rabit.finalize()
-                            raise Exception('Worker models diverged: test.model.1x4.{} differs from test.model.1x4.{}'.format(i, j))
-
-if (rank != 0) and (fail):
-    xgb.rabit.finalize()
-
-# Notify the tracker all training has been successful
-# This is only needed in distributed training.
-xgb.rabit.finalize()
diff --git a/tests/distributed-gpu/test_gpu_basic_2x2.py b/tests/distributed-gpu/test_gpu_basic_2x2.py
deleted file mode 100644
index b1669560c..000000000
--- a/tests/distributed-gpu/test_gpu_basic_2x2.py
+++ /dev/null
@@ -1,51 +0,0 @@
-#!/usr/bin/python
-import xgboost as xgb
-import time
-from collections import OrderedDict
-
-# Always call this before using distributed module
-xgb.rabit.init()
-rank = xgb.rabit.get_rank()
-world = xgb.rabit.get_world_size()
-
-# Load file, file will be automatically sharded in distributed mode.
-dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
-dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')
-
-# Specify parameters via map, definition are same as c++ version
-param = {'n_gpus': 2, 'gpu_id': 2*rank, 'tree_method': 'gpu_hist', 'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic' }
-
-# Specify validations set to watch performance
-watchlist = [(dtest,'eval'), (dtrain,'train')]
-num_round = 20
-
-# Run training, all the features in training API is available.
-# Currently, this script only support calling train once for fault recovery purpose.
-bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)
-
-# Have each worker save its model
-model_name = "test.model.2x2." + str(rank)
-bst.dump_model(model_name, with_stats=True); time.sleep(2)
-xgb.rabit.tracker_print("Finished training\n")
-
-fail = False
-if (rank == 0):
-    for i in range(0, world):
-        model_name_root = "test.model.2x2." + str(i)
-        for j in range(0, world):
-            if i != j:
-                with open(model_name_root, 'r') as model_root:
-                    model_name_rank = "test.model.2x2." + str(j)
-                    with open(model_name_rank, 'r') as model_rank:
-                        diff = set(model_root).difference(model_rank)
-                        if len(diff) != 0:
-                            fail = True
-                            xgb.rabit.finalize()
-                            raise Exception('Worker models diverged: test.model.2x2.{} differs from test.model.2x2.{}'.format(i, j))
-
-if (rank != 0) and (fail):
-    xgb.rabit.finalize()
-
-# Notify the tracker all training has been successful
-# This is only needed in distributed training.
-xgb.rabit.finalize()
diff --git a/tests/distributed-gpu/test_gpu_basic_4x1.py b/tests/distributed-gpu/test_gpu_basic_4x1.py
deleted file mode 100644
index 6662a3ac6..000000000
--- a/tests/distributed-gpu/test_gpu_basic_4x1.py
+++ /dev/null
@@ -1,34 +0,0 @@
-#!/usr/bin/python
-import xgboost as xgb
-import time
-from collections import OrderedDict
-
-# Always call this before using distributed module
-xgb.rabit.init()
-rank = xgb.rabit.get_rank()
-world = xgb.rabit.get_world_size()
-
-# Load file, file will be automatically sharded in distributed mode.
-dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
-dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')
-
-# Specify parameters via map, definition are same as c++ version
-param = {'n_gpus': 4, 'gpu_id': rank, 'tree_method': 'gpu_hist', 'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic' }
-
-# Specify validations set to watch performance
-watchlist = [(dtest,'eval'), (dtrain,'train')]
-num_round = 20
-
-# Run training, all the features in training API is available.
-# Currently, this script only support calling train once for fault recovery purpose.
-bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)
-
-# Have root save its model
-if(rank == 0):
-    model_name = "test.model.4x1." + str(rank)
-    bst.dump_model(model_name, with_stats=True)
-xgb.rabit.tracker_print("Finished training\n")
-
-# Notify the tracker all training has been successful
-# This is only needed in distributed training.
-xgb.rabit.finalize()
diff --git a/tests/distributed-gpu/test_gpu_basic_asym.py b/tests/distributed-gpu/test_gpu_basic_asym.py
deleted file mode 100644
index e20304ef2..000000000
--- a/tests/distributed-gpu/test_gpu_basic_asym.py
+++ /dev/null
@@ -1,54 +0,0 @@
-#!/usr/bin/python
-import xgboost as xgb
-import time
-from collections import OrderedDict
-
-# Always call this before using distributed module
-xgb.rabit.init()
-rank = xgb.rabit.get_rank()
-world = xgb.rabit.get_world_size()
-
-# Load file, file will be automatically sharded in distributed mode.
-dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
-dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')
-
-# Specify parameters via map, definition are same as c++ version
-if rank == 0:
-    param = {'n_gpus': 1, 'gpu_id': rank, 'tree_method': 'gpu_hist', 'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic' }
-else:
-    param = {'n_gpus': 3, 'gpu_id': rank, 'tree_method': 'gpu_hist', 'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic' }
-
-# Specify validations set to watch performance
-watchlist = [(dtest,'eval'), (dtrain,'train')]
-num_round = 20
-
-# Run training, all the features in training API is available.
-# Currently, this script only support calling train once for fault recovery purpose.
-bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)
-
-# Have each worker save its model
-model_name = "test.model.asym." + str(rank)
-bst.dump_model(model_name, with_stats=True); time.sleep(2)
-xgb.rabit.tracker_print("Finished training\n")
-
-fail = False
-if (rank == 0):
-    for i in range(0, world):
-        model_name_root = "test.model.asym." + str(i)
-        for j in range(0, world):
-            if i != j:
-                with open(model_name_root, 'r') as model_root:
-                    model_name_rank = "test.model.asym." + str(j)
-                    with open(model_name_rank, 'r') as model_rank:
-                        diff = set(model_root).difference(model_rank)
-                        if len(diff) != 0:
-                            fail = True
-                            xgb.rabit.finalize()
-                            raise Exception('Worker models diverged: test.model.asym.{} differs from test.model.asym.{}'.format(i, j))
-
-if (rank != 0) and (fail):
-    xgb.rabit.finalize()
-
-# Notify the tracker all training has been successful
-# This is only needed in distributed training.
-xgb.rabit.finalize()
diff --git a/tests/distributed/distributed_gpu.py b/tests/distributed/distributed_gpu.py
new file mode 100644
index 000000000..172b0443d
--- /dev/null
+++ b/tests/distributed/distributed_gpu.py
@@ -0,0 +1,113 @@
+"""Distributed GPU tests."""
+import sys
+import time
+import xgboost as xgb
+
+
+def run_test(name, params_fun):
+    """Runs a distributed GPU test."""
+    # Always call this before using distributed module
+    xgb.rabit.init()
+    rank = xgb.rabit.get_rank()
+    world = xgb.rabit.get_world_size()
+
+    # Load file, file will be automatically sharded in distributed mode.
+    dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
+    dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')
+
+    params, n_rounds = params_fun(rank)
+
+    # Specify validations set to watch performance
+    watchlist = [(dtest, 'eval'), (dtrain, 'train')]
+
+    # Run training, all the features in training API is available.
+    # Currently, this script only support calling train once for fault recovery purpose.
+    bst = xgb.train(params, dtrain, n_rounds, watchlist, early_stopping_rounds=2)
+
+    # Have each worker save its model
+    model_name = "test.model.%s.%d" % (name, rank)
+    bst.dump_model(model_name, with_stats=True)
+    time.sleep(2)
+    xgb.rabit.tracker_print("Finished training\n")
+
+    if (rank == 0):
+        for i in range(0, world):
+            model_name_root = "test.model.%s.%d" % (name, i)
+            for j in range(0, world):
+                if i == j:
+                    continue
+                with open(model_name_root, 'r') as model_root:
+                    contents_root = model_root.read()
+                    model_name_rank = "test.model.%s.%d" % (name, j)
+                    with open(model_name_rank, 'r') as model_rank:
+                        contents_rank = model_rank.read()
+                        if contents_root != contents_rank:
+                            raise Exception(
+                                ('Worker models diverged: test.model.%s.%d '
+                                 'differs from test.model.%s.%d') % (name, i, name, j))
+
+    xgb.rabit.finalize()
+
+
+base_params = {
+    'tree_method': 'gpu_hist',
+    'max_depth': 2,
+    'eta': 1,
+    'verbosity': 0,
+    'objective': 'binary:logistic'
+}
+
+
+def params_basic_1x4(rank):
+    return dict(base_params, **{
+        'n_gpus': 1,
+        'gpu_id': rank,
+    }), 20
+
+
+def params_basic_2x2(rank):
+    return dict(base_params, **{
+        'n_gpus': 2,
+        'gpu_id': 2*rank,
+    }), 20
+
+
+def params_basic_4x1(rank):
+    return dict(base_params, **{
+        'n_gpus': 4,
+        'gpu_id': rank,
+    }), 20
+
+
+def params_basic_asym(rank):
+    return dict(base_params, **{
+        'n_gpus': 1 if rank == 0 else 3,
+        'gpu_id': rank,
+    }), 20
+
+
+rf_update_params = {
+    'subsample': 0.5,
+    'colsample_bynode': 0.5
+}
+
+
+def wrap_rf(params_fun):
+    def wrapped_params_fun(rank):
+        params, n_estimators = params_fun(rank)
+        rf_params = dict(rf_update_params, num_parallel_tree=n_estimators)
+        return dict(params, **rf_params), 1
+    return wrapped_params_fun
+
+
+params_rf_1x4 = wrap_rf(params_basic_1x4)
+
+params_rf_2x2 = wrap_rf(params_basic_2x2)
+
+params_rf_4x1 = wrap_rf(params_basic_4x1)
+
+params_rf_asym = wrap_rf(params_basic_asym)
+
+
+test_name = sys.argv[1]
+run_test(test_name, globals()['params_%s' % test_name])
diff --git a/tests/distributed/runtests-gpu.sh b/tests/distributed/runtests-gpu.sh
new file mode 100755
index 000000000..de8e71fac
--- /dev/null
+++ b/tests/distributed/runtests-gpu.sh
@@ -0,0 +1,31 @@
+#!/bin/bash
+
+rm -f *.model*
+
+export DMLC_SUBMIT_CLUSTER=local
+export PYTHONPATH=../../python-package
+submit="timeout 30 python ../../dmlc-core/tracker/dmlc-submit"
+
+echo -e "\n ====== 1. Basic distributed-gpu test with Python: 4 workers; 1 GPU per worker ====== \n"
+$submit --num-workers=4 python distributed_gpu.py basic_1x4 || exit 1
+
+echo -e "\n ====== 2. Basic distributed-gpu test with Python: 2 workers; 2 GPUs per worker ====== \n"
+$submit --num-workers=2 python distributed_gpu.py basic_2x2 || exit 1
+
+echo -e "\n ====== 3. Basic distributed-gpu test with Python: 2 workers; Rank 0: 1 GPU, Rank 1: 3 GPUs ====== \n"
+$submit --num-workers=2 python distributed_gpu.py basic_asym || exit 1
+
+echo -e "\n ====== 4. Basic distributed-gpu test with Python: 1 worker; 4 GPUs per worker ====== \n"
+$submit --num-workers=1 python distributed_gpu.py basic_4x1 || exit 1
+
+echo -e "\n ====== 5. RF distributed-gpu test with Python: 4 workers; 1 GPU per worker ====== \n"
+$submit --num-workers=4 python distributed_gpu.py rf_1x4 || exit 1
+
+echo -e "\n ====== 6. RF distributed-gpu test with Python: 2 workers; 2 GPUs per worker ====== \n"
+$submit --num-workers=2 python distributed_gpu.py rf_2x2 || exit 1
+
+echo -e "\n ====== 7. RF distributed-gpu test with Python: 2 workers; Rank 0: 1 GPU, Rank 1: 3 GPUs ====== \n"
+$submit --num-workers=2 python distributed_gpu.py rf_asym || exit 1
+
+echo -e "\n ====== 8. RF distributed-gpu test with Python: 1 worker; 4 GPUs per worker ====== \n"
+$submit --num-workers=1 python distributed_gpu.py rf_4x1 || exit 1
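Why the random.h change is sufficient: once every rank seeds its private rng_ from the value broadcast by rank 0, identical sequences of draws produce identical column samples on every process, so the per-call rabit::Broadcast of the sampled feature list is no longer needed. The sketch below is not xgboost code; it is a minimal, self-contained Python illustration of that invariant, where a list of random.Random objects stands in for one GlobalRandomEngine per worker, sharing a single seed value stands in for rabit::Broadcast, and the hypothetical helper col_sample mirrors the shuffle/truncate/sort steps of ColumnSampler::ColSample.

# Minimal sketch (not the xgboost implementation): every simulated rank seeds its
# own RNG from one shared value, mirroring the constructor's broadcast of the seed,
# so identical call sequences yield identical per-node feature sets.
import random

N_RANKS = 4        # stand-ins for the rabit worker processes
N_FEATURES = 10
COLSAMPLE = 0.5

# "Rank 0" draws a seed; in xgboost, rabit::Broadcast ships it to the other ranks.
seed = random.SystemRandom().randint(0, 2**32 - 1)
rngs = [random.Random(seed) for _ in range(N_RANKS)]  # one seeded RNG per rank


def col_sample(rng, features, colsample):
    """Hypothetical helper mirroring ColSample: shuffle a copy, truncate, sort."""
    n = max(1, int(colsample * len(features)))
    sampled = list(features)
    rng.shuffle(sampled)
    return sorted(sampled[:n])


features = list(range(N_FEATURES))
for node in range(3):  # each rank samples once per tree node, in the same order
    per_rank = [col_sample(rng, features, COLSAMPLE) for rng in rngs]
    assert all(s == per_rank[0] for s in per_rank), "ranks diverged"
    print("node", node, "-> features", per_rank[0])

The same sketch also shows what the new GetFeatureSet note guards against: if one rank inserted an extra col_sample call (or skipped one), its RNG stream would shift and every later sample on that rank would diverge from its peers, which is exactly the kind of divergence the model-comparison loop in distributed_gpu.py is meant to catch.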
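The RF test cases reuse the basic GPU layouts but train a random forest: wrap_rf moves the 20-round budget into num_parallel_tree, trains for a single boosting round, and sets subsample and colsample_bynode to 0.5, which exercises the per-node column-sampling code touched by the random.h change. The small sketch below reuses the helpers as they appear in distributed_gpu.py (only the final prints are illustrative) to show the configuration produced for rank 0 of the 2x2 layout:

# Reproduces the distributed_gpu.py helpers standalone to show what wrap_rf builds.
base_params = {
    'tree_method': 'gpu_hist',
    'max_depth': 2,
    'eta': 1,
    'verbosity': 0,
    'objective': 'binary:logistic'
}

rf_update_params = {
    'subsample': 0.5,
    'colsample_bynode': 0.5
}


def params_basic_2x2(rank):
    return dict(base_params, **{'n_gpus': 2, 'gpu_id': 2 * rank}), 20


def wrap_rf(params_fun):
    def wrapped_params_fun(rank):
        params, n_estimators = params_fun(rank)
        rf_params = dict(rf_update_params, num_parallel_tree=n_estimators)
        return dict(params, **rf_params), 1
    return wrapped_params_fun


params_rf_2x2 = wrap_rf(params_basic_2x2)

params, n_rounds = params_rf_2x2(0)
print(n_rounds)                        # 1: a single boosting round ...
print(params['num_parallel_tree'])     # 20: ... growing 20 trees in parallel
print(params['colsample_bynode'])      # 0.5: exercises per-node column sampling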