[BREAKING] prevent multi-gpu usage (#4749)

* prevent multi-gpu usage

* fix distributed test

* combine gpu predictor tests

* set upper bound on n_gpus
This commit is contained in:
Rong Ou 2019-08-12 14:11:35 -07:00 committed by Rory Mitchell
parent 198f3a6c4a
commit c5b229632d
14 changed files with 59 additions and 298 deletions

View File

@ -40,10 +40,10 @@ struct GenericParameter : public dmlc::Parameter<GenericParameter> {
.describe("The primary GPU device ordinal."); .describe("The primary GPU device ordinal.");
DMLC_DECLARE_FIELD(n_gpus) DMLC_DECLARE_FIELD(n_gpus)
.set_default(0) .set_default(0)
.set_lower_bound(-1) .set_range(0, 1)
.describe("Deprecated, please use distributed training with one " .describe("Deprecated. Single process multi-GPU training is no longer supported. "
"process per GPU. " "Please switch to distributed training with one process per GPU. "
"Number of GPUs to use for multi-gpu algorithms."); "This can be done using Dask or Spark.");
} }
}; };
} // namespace xgboost } // namespace xgboost

View File

@ -580,9 +580,16 @@ class LearnerImpl : public Learner {
} }
gbm_->Configure(args); gbm_->Configure(args);
if (this->gbm_->UseGPU() && cfg_.find("n_gpus") == cfg_.cend()) { if (this->gbm_->UseGPU()) {
if (cfg_.find("n_gpus") == cfg_.cend()) {
generic_param_.n_gpus = 1; generic_param_.n_gpus = 1;
} }
if (generic_param_.n_gpus != 1) {
LOG(FATAL) << "Single process multi-GPU training is no longer supported. "
"Please switch to distributed GPU training with one process per GPU. "
"This can be done using Dask or Spark.";
}
}
} }
// set number of features correctly. // set number of features correctly.

View File

@ -88,19 +88,5 @@ TEST(gpu_hist_util, DeviceSketch_ExternalMemory) {
TestDeviceSketch(GPUSet::Range(0, 1), true); TestDeviceSketch(GPUSet::Range(0, 1), true);
} }
#if defined(XGBOOST_USE_NCCL)
TEST(gpu_hist_util, MGPU_DeviceSketch) {
auto devices = GPUSet::AllVisible();
CHECK_GT(devices.Size(), 1);
TestDeviceSketch(devices, false);
}
TEST(gpu_hist_util, MGPU_DeviceSketch_ExternalMemory) {
auto devices = GPUSet::AllVisible();
CHECK_GT(devices.Size(), 1);
TestDeviceSketch(devices, true);
}
#endif
} // namespace common } // namespace common
} // namespace xgboost } // namespace xgboost

View File

@ -24,47 +24,4 @@ TEST(Linear, GPUCoordinate) {
delete mat; delete mat;
} }
#if defined(XGBOOST_USE_NCCL)
TEST(Linear, MGPU_GPUCoordinate) {
{
auto mat = xgboost::CreateDMatrix(10, 10, 0);
auto lparam = CreateEmptyGenericParam(0, -1);
lparam.n_gpus = -1;
auto updater = std::unique_ptr<xgboost::LinearUpdater>(
xgboost::LinearUpdater::Create("gpu_coord_descent", &lparam));
updater->Configure({{"eta", "1."}});
xgboost::HostDeviceVector<xgboost::GradientPair> gpair(
(*mat)->Info().num_row_, xgboost::GradientPair(-5, 1.0));
xgboost::gbm::GBLinearModel model;
model.param.num_feature = (*mat)->Info().num_col_;
model.param.num_output_group = 1;
model.LazyInitModel();
updater->Update(&gpair, (*mat).get(), &model, gpair.Size());
ASSERT_EQ(model.bias()[0], 5.0f);
delete mat;
}
{
auto lparam = CreateEmptyGenericParam(1, -1);
lparam.n_gpus = -1;
auto mat = xgboost::CreateDMatrix(10, 10, 0);
auto updater = std::unique_ptr<xgboost::LinearUpdater>(
xgboost::LinearUpdater::Create("gpu_coord_descent", &lparam));
updater->Configure({{"eta", "1."}});
xgboost::HostDeviceVector<xgboost::GradientPair> gpair(
(*mat)->Info().num_row_, xgboost::GradientPair(-5, 1.0));
xgboost::gbm::GBLinearModel model;
model.param.num_feature = (*mat)->Info().num_col_;
model.param.num_output_group = 1;
model.LazyInitModel();
updater->Update(&gpair, (*mat).get(), &model, gpair.Size());
ASSERT_EQ(model.bias()[0], 5.0f);
delete mat;
}
}
#endif
} // namespace xgboost } // namespace xgboost

View File

@ -101,32 +101,3 @@ TEST(Metric, DeclareUnifiedTest(PoissionNegLogLik)) {
1.1280f, 0.001f); 1.1280f, 0.001f);
delete metric; delete metric;
} }
#if defined(XGBOOST_USE_NCCL) && defined(__CUDACC__)
TEST(Metric, MGPU_RMSE) {
{
auto lparam = xgboost::CreateEmptyGenericParam(0, -1);
xgboost::Metric * metric = xgboost::Metric::Create("rmse", &lparam);
metric->Configure({});
ASSERT_STREQ(metric->Name(), "rmse");
EXPECT_NEAR(GetMetricEval(metric, {0}, {0}), 0, 1e-10);
EXPECT_NEAR(GetMetricEval(metric,
{0.1f, 0.9f, 0.1f, 0.9f},
{ 0, 0, 1, 1}),
0.6403f, 0.001f);
delete metric;
}
{
auto lparam = xgboost::CreateEmptyGenericParam(1, -1);
xgboost::Metric * metric = xgboost::Metric::Create("rmse", &lparam);
ASSERT_STREQ(metric->Name(), "rmse");
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 0, 1e-10);
EXPECT_NEAR(GetMetricEval(metric,
{0.1f, 0.9f, 0.1f, 0.9f},
{ 0, 0, 1, 1}),
0.6403f, 0.001f);
delete metric;
}
}
#endif

View File

@ -12,7 +12,6 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "../helpers.h" #include "../helpers.h"
#if defined(XGBOOST_USE_NCCL)
namespace { namespace {
inline void CheckCAPICall(int ret) { inline void CheckCAPICall(int ret) {
@ -20,7 +19,6 @@ inline void CheckCAPICall(int ret) {
} }
} // namespace anonymous } // namespace anonymous
#endif
const std::map<std::string, std::string>& const std::map<std::string, std::string>&
QueryBoosterConfigurationArguments(BoosterHandle handle) { QueryBoosterConfigurationArguments(BoosterHandle handle) {
@ -46,26 +44,28 @@ TEST(gpu_predictor, Test) {
gpu_predictor->Configure({}, {}); gpu_predictor->Configure({}, {});
cpu_predictor->Configure({}, {}); cpu_predictor->Configure({}, {});
int n_row = 5; for (size_t i = 1; i < 33; i *= 2) {
int n_col = 5; int n_row = i, n_col = i;
auto dmat = CreateDMatrix(n_row, n_col, 0);
gbm::GBTreeModel model = CreateTestModel(); gbm::GBTreeModel model = CreateTestModel();
model.param.num_feature = n_col; model.param.num_feature = n_col;
auto dmat = CreateDMatrix(n_row, n_col, 0);
// Test predict batch // Test predict batch
HostDeviceVector<float> gpu_out_predictions; HostDeviceVector<float> gpu_out_predictions;
HostDeviceVector<float> cpu_out_predictions; HostDeviceVector<float> cpu_out_predictions;
gpu_predictor->PredictBatch((*dmat).get(), &gpu_out_predictions, model, 0); gpu_predictor->PredictBatch((*dmat).get(), &gpu_out_predictions, model, 0);
cpu_predictor->PredictBatch((*dmat).get(), &cpu_out_predictions, model, 0); cpu_predictor->PredictBatch((*dmat).get(), &cpu_out_predictions, model, 0);
std::vector<float>& gpu_out_predictions_h = gpu_out_predictions.HostVector(); std::vector<float>& gpu_out_predictions_h = gpu_out_predictions.HostVector();
std::vector<float>& cpu_out_predictions_h = cpu_out_predictions.HostVector(); std::vector<float>& cpu_out_predictions_h = cpu_out_predictions.HostVector();
float abs_tolerance = 0.001; float abs_tolerance = 0.001;
for (int i = 0; i < gpu_out_predictions.Size(); i++) { for (int j = 0; j < gpu_out_predictions.Size(); j++) {
ASSERT_NEAR(gpu_out_predictions_h[i], cpu_out_predictions_h[i], abs_tolerance); ASSERT_NEAR(gpu_out_predictions_h[j], cpu_out_predictions_h[j], abs_tolerance);
} }
delete dmat; delete dmat;
}
} }
TEST(gpu_predictor, ExternalMemoryTest) { TEST(gpu_predictor, ExternalMemoryTest) {
@ -74,25 +74,35 @@ TEST(gpu_predictor, ExternalMemoryTest) {
std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor", &lparam)); std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor", &lparam));
gpu_predictor->Configure({}, {}); gpu_predictor->Configure({}, {});
gbm::GBTreeModel model = CreateTestModel(); gbm::GBTreeModel model = CreateTestModel();
int n_col = 3; model.param.num_feature = 3;
model.param.num_feature = n_col; const int n_classes = 3;
model.param.num_output_group = n_classes;
std::vector<std::unique_ptr<DMatrix>> dmats;
dmlc::TemporaryDirectory tmpdir; dmlc::TemporaryDirectory tmpdir;
std::string filename = tmpdir.path + "/big.libsvm"; std::string file0 = tmpdir.path + "/big_0.libsvm";
std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(32, 64, filename); std::string file1 = tmpdir.path + "/big_1.libsvm";
std::string file2 = tmpdir.path + "/big_2.libsvm";
dmats.push_back(CreateSparsePageDMatrix(9, 64UL, file0));
dmats.push_back(CreateSparsePageDMatrix(128, 128UL, file1));
dmats.push_back(CreateSparsePageDMatrix(1024, 1024UL, file2));
for (const auto& dmat: dmats) {
// Test predict batch // Test predict batch
HostDeviceVector<float> out_predictions; HostDeviceVector<float> out_predictions;
gpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0); gpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0);
EXPECT_EQ(out_predictions.Size(), dmat->Info().num_row_); EXPECT_EQ(out_predictions.Size(), dmat->Info().num_row_ * n_classes);
for (const auto& v : out_predictions.HostVector()) { const std::vector<float> &host_vector = out_predictions.ConstHostVector();
ASSERT_EQ(v, 1.5); for (int i = 0; i < host_vector.size() / n_classes; i++) {
ASSERT_EQ(host_vector[i * n_classes], 1.5);
ASSERT_EQ(host_vector[i * n_classes + 1], 0.);
ASSERT_EQ(host_vector[i * n_classes + 2], 0.);
}
} }
} }
#if defined(XGBOOST_USE_NCCL)
// Test whether pickling preserves predictor parameters // Test whether pickling preserves predictor parameters
TEST(gpu_predictor, MGPU_PicklingTest) { TEST(gpu_predictor, PicklingTest) {
int const ngpu = GPUSet::AllVisible().Size(); int const ngpu = 1;
dmlc::TemporaryDirectory tempdir; dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm"; const std::string tmp_file = tempdir.path + "/simple.libsvm";
@ -153,12 +163,6 @@ TEST(gpu_predictor, MGPU_PicklingTest) {
ASSERT_EQ(kwargs.at("n_gpus"), std::to_string(ngpu).c_str()); ASSERT_EQ(kwargs.at("n_gpus"), std::to_string(ngpu).c_str());
} }
{ // Change n_gpus and query again
CheckCAPICall(XGBoosterSetParam(bst2, "n_gpus", "1"));
const auto& kwargs = QueryBoosterConfigurationArguments(bst2);
ASSERT_EQ(kwargs.at("n_gpus"), "1");
}
{ // Change predictor and query again { // Change predictor and query again
CheckCAPICall(XGBoosterSetParam(bst2, "predictor", "cpu_predictor")); CheckCAPICall(XGBoosterSetParam(bst2, "predictor", "cpu_predictor"));
const auto& kwargs = QueryBoosterConfigurationArguments(bst2); const auto& kwargs = QueryBoosterConfigurationArguments(bst2);
@ -167,77 +171,5 @@ TEST(gpu_predictor, MGPU_PicklingTest) {
CheckCAPICall(XGBoosterFree(bst2)); CheckCAPICall(XGBoosterFree(bst2));
} }
// multi-GPU predictor test
TEST(gpu_predictor, MGPU_Test) {
auto cpu_lparam = CreateEmptyGenericParam(0, 0);
auto gpu_lparam = CreateEmptyGenericParam(0, -1);
std::unique_ptr<Predictor> gpu_predictor =
std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor", &gpu_lparam));
std::unique_ptr<Predictor> cpu_predictor =
std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor", &cpu_lparam));
cpu_predictor->Configure({}, {});
for (size_t i = 1; i < 33; i *= 2) {
int n_row = i, n_col = i;
auto dmat = CreateDMatrix(n_row, n_col, 0);
gbm::GBTreeModel model = CreateTestModel();
model.param.num_feature = n_col;
// Test predict batch
HostDeviceVector<float> gpu_out_predictions;
HostDeviceVector<float> cpu_out_predictions;
gpu_predictor->PredictBatch((*dmat).get(), &gpu_out_predictions, model, 0);
cpu_predictor->PredictBatch((*dmat).get(), &cpu_out_predictions, model, 0);
std::vector<float>& gpu_out_predictions_h = gpu_out_predictions.HostVector();
std::vector<float>& cpu_out_predictions_h = cpu_out_predictions.HostVector();
float abs_tolerance = 0.001;
for (int j = 0; j < gpu_out_predictions.Size(); j++) {
ASSERT_NEAR(gpu_out_predictions_h[j], cpu_out_predictions_h[j], abs_tolerance);
}
delete dmat;
}
}
// multi-GPU predictor external memory test
TEST(gpu_predictor, MGPU_ExternalMemoryTest) {
auto gpu_lparam = CreateEmptyGenericParam(0, -1);
std::unique_ptr<Predictor> gpu_predictor =
std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor", &gpu_lparam));
gpu_predictor->Configure({}, {});
gbm::GBTreeModel model = CreateTestModel();
model.param.num_feature = 3;
const int n_classes = 3;
model.param.num_output_group = n_classes;
std::vector<std::unique_ptr<DMatrix>> dmats;
dmlc::TemporaryDirectory tmpdir;
std::string file0 = tmpdir.path + "/big_0.libsvm";
std::string file1 = tmpdir.path + "/big_1.libsvm";
std::string file2 = tmpdir.path + "/big_2.libsvm";
dmats.push_back(CreateSparsePageDMatrix(9, 64UL, file0));
dmats.push_back(CreateSparsePageDMatrix(128, 128UL, file1));
dmats.push_back(CreateSparsePageDMatrix(1024, 1024UL, file2));
for (const auto& dmat: dmats) {
// Test predict batch
HostDeviceVector<float> out_predictions;
gpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0);
EXPECT_EQ(out_predictions.Size(), dmat->Info().num_row_ * n_classes);
const std::vector<float> &host_vector = out_predictions.ConstHostVector();
for (int i = 0; i < host_vector.size() / n_classes; i++) {
ASSERT_EQ(host_vector[i * n_classes], 1.5);
ASSERT_EQ(host_vector[i * n_classes + 1], 0.);
ASSERT_EQ(host_vector[i * n_classes + 2], 0.);
}
}
}
#endif // defined(XGBOOST_USE_NCCL)
} // namespace predictor } // namespace predictor
} // namespace xgboost } // namespace xgboost

View File

@ -168,10 +168,10 @@ TEST(Learner, IO) {
std::unique_ptr<Learner> learner {Learner::Create(mat)}; std::unique_ptr<Learner> learner {Learner::Create(mat)};
learner->SetParams({Arg{"tree_method", "auto"}, learner->SetParams({Arg{"tree_method", "auto"},
Arg{"predictor", "gpu_predictor"}, Arg{"predictor", "gpu_predictor"},
Arg{"n_gpus", "-1"}}); Arg{"n_gpus", "1"}});
learner->UpdateOneIter(0, p_dmat.get()); learner->UpdateOneIter(0, p_dmat.get());
ASSERT_EQ(learner->GetGenericParameter().gpu_id, 0); ASSERT_EQ(learner->GetGenericParameter().gpu_id, 0);
ASSERT_EQ(learner->GetGenericParameter().n_gpus, -1); ASSERT_EQ(learner->GetGenericParameter().n_gpus, 1);
dmlc::TemporaryDirectory tempdir; dmlc::TemporaryDirectory tempdir;
const std::string fname = tempdir.path + "/model.bst"; const std::string fname = tempdir.path + "/model.bst";

View File

@ -415,13 +415,5 @@ TEST(GpuHist, TestHistogramIndex) {
TestHistogramIndexImpl(1); TestHistogramIndexImpl(1);
} }
#if defined(XGBOOST_USE_NCCL)
TEST(GpuHist, MGPU_TestHistogramIndex) {
auto devices = GPUSet::AllVisible();
CHECK_GT(devices.Size(), 1);
TestHistogramIndexImpl(-1);
}
#endif
} // namespace tree } // namespace tree
} // namespace xgboost } // namespace xgboost

View File

@ -66,27 +66,6 @@ def params_basic_1x4(rank):
}), 20 }), 20
def params_basic_2x2(rank):
return dict(base_params, **{
'n_gpus': 2,
'gpu_id': 2*rank,
}), 20
def params_basic_4x1(rank):
return dict(base_params, **{
'n_gpus': 4,
'gpu_id': rank,
}), 20
def params_basic_asym(rank):
return dict(base_params, **{
'n_gpus': 1 if rank == 0 else 3,
'gpu_id': rank,
}), 20
rf_update_params = { rf_update_params = {
'subsample': 0.5, 'subsample': 0.5,
'colsample_bynode': 0.5 'colsample_bynode': 0.5
@ -103,11 +82,6 @@ def wrap_rf(params_fun):
params_rf_1x4 = wrap_rf(params_basic_1x4) params_rf_1x4 = wrap_rf(params_basic_1x4)
params_rf_2x2 = wrap_rf(params_basic_2x2)
params_rf_4x1 = wrap_rf(params_basic_4x1)
params_rf_asym = wrap_rf(params_basic_asym)
test_name = sys.argv[1] test_name = sys.argv[1]

View File

@ -8,23 +8,5 @@ submit="timeout 30 python ../../dmlc-core/tracker/dmlc-submit"
echo -e "\n ====== 1. Basic distributed-gpu test with Python: 4 workers; 1 GPU per worker ====== \n" echo -e "\n ====== 1. Basic distributed-gpu test with Python: 4 workers; 1 GPU per worker ====== \n"
$submit --num-workers=4 python distributed_gpu.py basic_1x4 || exit 1 $submit --num-workers=4 python distributed_gpu.py basic_1x4 || exit 1
echo -e "\n ====== 2. Basic distributed-gpu test with Python: 2 workers; 2 GPUs per worker ====== \n" echo -e "\n ====== 2. RF distributed-gpu test with Python: 4 workers; 1 GPU per worker ====== \n"
$submit --num-workers=2 python distributed_gpu.py basic_2x2 || exit 1
echo -e "\n ====== 3. Basic distributed-gpu test with Python: 2 workers; Rank 0: 1 GPU, Rank 1: 3 GPUs ====== \n"
$submit --num-workers=2 python distributed_gpu.py basic_asym || exit 1
echo -e "\n ====== 4. Basic distributed-gpu test with Python: 1 worker; 4 GPUs per worker ====== \n"
$submit --num-workers=1 python distributed_gpu.py basic_4x1 || exit 1
echo -e "\n ====== 5. RF distributed-gpu test with Python: 4 workers; 1 GPU per worker ====== \n"
$submit --num-workers=4 python distributed_gpu.py rf_1x4 || exit 1 $submit --num-workers=4 python distributed_gpu.py rf_1x4 || exit 1
echo -e "\n ====== 6. RF distributed-gpu test with Python: 2 workers; 2 GPUs per worker ====== \n"
$submit --num-workers=2 python distributed_gpu.py rf_2x2 || exit 1
echo -e "\n ====== 7. RF distributed-gpu test with Python: 2 workers; Rank 0: 1 GPU, Rank 1: 3 GPUs ====== \n"
$submit --num-workers=2 python distributed_gpu.py rf_asym || exit 1
echo -e "\n ====== 8. RF distributed-gpu test with Python: 1 worker; 4 GPUs per worker ====== \n"
$submit --num-workers=1 python distributed_gpu.py rf_4x1 || exit 1

View File

@ -29,15 +29,3 @@ class TestGPULinear(unittest.TestCase):
param, 150, self.datasets, scale_features=True) param, 150, self.datasets, scale_features=True)
test_linear.assert_regression_result(results, 1e-2) test_linear.assert_regression_result(results, 1e-2)
test_linear.assert_classification_result(results) test_linear.assert_classification_result(results)
@pytest.mark.mgpu
@pytest.mark.skipif(**tm.no_sklearn())
def test_gpu_coordinate_mgpu(self):
parameters = self.common_param.copy()
parameters['n_gpus'] = [-1]
parameters['gpu_id'] = [1]
for param in test_linear.parameter_combinations(parameters):
results = test_linear.run_suite(
param, 150, self.datasets, scale_features=True)
test_linear.assert_regression_result(results, 1e-2)
test_linear.assert_classification_result(results)

View File

@ -36,17 +36,6 @@ class TestGPU(unittest.TestCase):
cpu_results = run_suite(param, select_datasets=datasets) cpu_results = run_suite(param, select_datasets=datasets)
assert_gpu_results(cpu_results, gpu_results) assert_gpu_results(cpu_results, gpu_results)
@pytest.mark.mgpu
def test_gpu_hist_mgpu(self):
variable_param = {'n_gpus': [-1], 'max_depth': [2, 10],
'max_leaves': [255, 4],
'max_bin': [2, 256],
'grow_policy': ['lossguide'], 'debug_synchronize': [True]}
for param in parameter_combinations(variable_param):
param['tree_method'] = 'gpu_hist'
gpu_results = run_suite(param, select_datasets=datasets)
assert_results_non_increasing(gpu_results, 1e-2)
@pytest.mark.mgpu @pytest.mark.mgpu
def test_specified_gpu_id_gpu_update(self): def test_specified_gpu_id_gpu_update(self):
variable_param = {'n_gpus': [1], variable_param = {'n_gpus': [1],

View File

@ -25,7 +25,7 @@ cols = 31
# reduced to fit onto 1 gpu but still be large # reduced to fit onto 1 gpu but still be large
rows3 = 5000 # small rows3 = 5000 # small
rows2 = 4360032 # medium rows2 = 4360032 # medium
rows1 = 42360032 # large rows1 = 32360032 # large
# rows1 = 152360032 # can do this for multi-gpu test (very large) # rows1 = 152360032 # can do this for multi-gpu test (very large)
rowslist = [rows1, rows2, rows3] rowslist = [rows1, rows2, rows3]
@ -67,15 +67,6 @@ class TestGPU(unittest.TestCase):
'objective': 'binary:logistic', 'objective': 'binary:logistic',
'max_bin': max_bin, 'max_bin': max_bin,
'eval_metric': 'auc'} 'eval_metric': 'auc'}
ag_param3 = {'max_depth': max_depth,
'tree_method': 'gpu_hist',
'nthread': 0,
'eta': 1,
'verbosity': 3,
'n_gpus': -1,
'objective': 'binary:logistic',
'max_bin': max_bin,
'eval_metric': 'auc'}
ag_res = {} ag_res = {}
ag_resb = {} ag_resb = {}
ag_res2 = {} ag_res2 = {}
@ -93,9 +84,3 @@ class TestGPU(unittest.TestCase):
xgb.train(ag_param2, ag_dtrain, num_rounds, [(ag_dtrain, 'train')], xgb.train(ag_param2, ag_dtrain, num_rounds, [(ag_dtrain, 'train')],
evals_result=ag_res2) evals_result=ag_res2)
print("Time to Train: %s seconds" % (str(time.time() - tmp))) print("Time to Train: %s seconds" % (str(time.time() - tmp)))
tmp = time.time()
eprint("gpu_hist updater all gpus")
xgb.train(ag_param3, ag_dtrain, num_rounds, [(ag_dtrain, 'train')],
evals_result=ag_res3)
print("Time to Train: %s seconds" % (str(time.time() - tmp)))

View File

@ -35,8 +35,6 @@ class TestPickling(unittest.TestCase):
x, y = build_dataset() x, y = build_dataset()
train_x = xgb.DMatrix(x, label=y) train_x = xgb.DMatrix(x, label=y)
param = {'tree_method': 'gpu_hist', param = {'tree_method': 'gpu_hist',
'gpu_id': 0,
'n_gpus': -1,
'verbosity': 1} 'verbosity': 1}
bst = xgb.train(param, train_x) bst = xgb.train(param, train_x)