Fix wrapping GPU ID and prevent data copying. (#5160)

* Removed some data copying.

* Make sure gpu_id is valid before any configuration is carried out.
This commit is contained in:
Jiaming Yuan 2019-12-27 16:51:08 +08:00 committed by GitHub
parent ee81ba8e1f
commit 61286c6e8f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 55 additions and 17 deletions

View File

@ -41,7 +41,7 @@ BatchSet<SparsePage> SimpleDMatrix::GetRowBatches() {
BatchSet<CSCPage> SimpleDMatrix::GetColumnBatches() {
// column page doesn't exist, generate it
if (!column_page_) {
auto page = dynamic_cast<SimpleCSRSource*>(source_.get())->page_;
auto const& page = dynamic_cast<SimpleCSRSource*>(source_.get())->page_;
column_page_.reset(new CSCPage(page.GetTranspose(source_->info.num_col_)));
}
auto begin_iter =
@ -52,7 +52,7 @@ BatchSet<CSCPage> SimpleDMatrix::GetColumnBatches() {
BatchSet<SortedCSCPage> SimpleDMatrix::GetSortedColumnBatches() {
// Sorted column page doesn't exist, generate it
if (!sorted_column_page_) {
auto page = dynamic_cast<SimpleCSRSource*>(source_.get())->page_;
auto const& page = dynamic_cast<SimpleCSRSource*>(source_.get())->page_;
sorted_column_page_.reset(
new SortedCSCPage(page.GetTranspose(source_->info.num_col_)));
sorted_column_page_->SortRows();

View File

@ -354,7 +354,6 @@ class SparsePageSource : public DataSource<T> {
writer.Alloc(&page);
page->Clear();
MetaInfo info = src->Info();
size_t bytes_write = 0;
double tstart = dmlc::GetTime();
for (auto& batch : src->GetBatches<SparsePage>()) {

View File

@ -275,7 +275,8 @@ class LearnerImpl : public Learner {
// `verbosity` in logger is not saved, we should move it into generic_param_.
// FIXME(trivialfis): Make eval_metric a training parameter.
if (kv.first != "num_feature" && kv.first != "verbosity" &&
kv.first != "num_class" && kv.first != kEvalMetric) {
kv.first != "num_class" && kv.first != "num_output_group" &&
kv.first != kEvalMetric) {
provided.push_back(kv.first);
}
}
@ -399,6 +400,8 @@ class LearnerImpl : public Learner {
}
fromJson(learner_parameters.at("generic_param"), &generic_parameters_);
// make sure the GPU ID is valid in the new environment before running configuration.
generic_parameters_.ConfigureGpuId(false);
this->need_configuration_ = true;
}

View File

@ -51,14 +51,14 @@ TEST(SparsePageDMatrix, ColAccess) {
EXPECT_EQ(dmat->GetColDensity(1), 0.5);
// Loop over the batches and assert the data is as expected
for (auto col_batch : dmat->GetBatches<xgboost::SortedCSCPage>()) {
for (auto const& col_batch : dmat->GetBatches<xgboost::SortedCSCPage>()) {
EXPECT_EQ(col_batch.Size(), dmat->Info().num_col_);
EXPECT_EQ(col_batch[1][0].fvalue, 10.0f);
EXPECT_EQ(col_batch[1].size(), 1);
}
// Loop over the batches and assert the data is as expected
for (auto col_batch : dmat->GetBatches<xgboost::CSCPage>()) {
for (auto const& col_batch : dmat->GetBatches<xgboost::CSCPage>()) {
EXPECT_EQ(col_batch.Size(), dmat->Info().num_col_);
EXPECT_EQ(col_batch[1][0].fvalue, 10.0f);
EXPECT_EQ(col_batch[1].size(), 1);
@ -223,7 +223,7 @@ TEST(SparsePageDMatrix, FromFile) {
const std::string tmp_file = tempdir.path + "/simple.libsvm";
data::SparsePageDMatrix dmat(
&adapter, std::numeric_limits<float>::quiet_NaN(), -1, tmp_file, 1);
for (auto &batch : dmat.GetBatches<SparsePage>()) {
std::vector<bst_row_t> expected_offset(batch.Size() + 1);
int n = -3;

View File

@ -46,7 +46,7 @@ void CheckObjFunctionImpl(std::unique_ptr<xgboost::ObjFunction> const& obj,
std::vector<xgboost::bst_float> preds,
std::vector<xgboost::bst_float> labels,
std::vector<xgboost::bst_float> weights,
xgboost::MetaInfo info,
xgboost::MetaInfo const& info,
std::vector<xgboost::bst_float> out_grad,
std::vector<xgboost::bst_float> out_hess) {
xgboost::HostDeviceVector<xgboost::bst_float> in_preds(preds);

View File

@ -37,3 +37,15 @@ class TestLoadPickle(unittest.TestCase):
config = json.loads(config)
assert config['learner']['gradient_booster']['gbtree_train_param'][
'predictor'] == 'gpu_predictor'
# Loads the pickled booster and checks that it reports gpu_id '0' and can
# still predict when CUDA_VISIBLE_DEVICES is '0'.
def test_wrap_gpu_id(self):
# Precondition: the caller must launch this test with CUDA_VISIBLE_DEVICES='0'.
assert os.environ['CUDA_VISIBLE_DEVICES'] == '0'
bst = load_pickle(model_path)
config = bst.save_config()
config = json.loads(config)
# The loaded booster's saved configuration must report gpu_id as '0'.
assert config['learner']['generic_param']['gpu_id'] == '0'
x, y = build_dataset()
test_x = xgb.DMatrix(x)
# Prediction on the loaded booster must run and return 10 values.
res = bst.predict(test_x)
assert len(res) == 10

View File

@ -5,6 +5,9 @@ import numpy as np
import subprocess
import os
import json
import pytest
import copy
import xgboost as xgb
from xgboost import XGBClassifier
@ -31,6 +34,12 @@ def load_pickle(path):
class TestPickling(unittest.TestCase):
args_template = [
"pytest",
"--verbose",
"-s",
"--fulltrace"]
def test_pickling(self):
x, y = build_dataset()
train_x = xgb.DMatrix(x, label=y)
@ -61,13 +70,29 @@ class TestPickling(unittest.TestCase):
assert status == 0
os.remove(model_path)
def test_pickled_predictor(self):
args_templae = [
"pytest",
"--verbose",
"-s",
"--fulltrace"]
# Trains a booster with gpu_id=1, pickles it, then runs the loader-side
# test_wrap_gpu_id in a pytest subprocess whose environment sets
# CUDA_VISIBLE_DEVICES to '0'.
# NOTE(review): the mgpu marker presumably restricts this test to multi-GPU
# machines -- confirm against the project's pytest configuration.
@pytest.mark.mgpu
def test_wrap_gpu_id(self):
X, y = build_dataset()
dtrain = xgb.DMatrix(X, y)
# Train on the device with ordinal 1 so the pickled model carries gpu_id=1.
bst = xgb.train({'tree_method': 'gpu_hist',
'gpu_id': 1},
dtrain, num_boost_round=6)
model_path = 'model.pkl'
save_pickle(bst, model_path)
# Copy the current environment and override CUDA_VISIBLE_DEVICES for the
# child process only; os.environ itself is left untouched.
cuda_environment = {'CUDA_VISIBLE_DEVICES': '0'}
env = os.environ.copy()
env.update(cuda_environment)
# Copy the shared pytest argument list so appending does not mutate it.
args = self.args_template.copy()
args.append(
"./tests/python-gpu/"
"load_pickle.py::TestLoadPickle::test_wrap_gpu_id"
)
# A zero exit status means the loader-side test passed.
status = subprocess.call(args, env=env)
assert status == 0
def test_pickled_predictor(self):
x, y = build_dataset()
train_x = xgb.DMatrix(x, label=y)
@ -80,7 +105,7 @@ class TestPickling(unittest.TestCase):
save_pickle(bst, model_path)
args = args_templae.copy()
args = self.args_template.copy()
args.append(
"./tests/python-gpu/"
"load_pickle.py::TestLoadPickle::test_predictor_type_is_auto")
@ -93,7 +118,7 @@ class TestPickling(unittest.TestCase):
status = subprocess.call(args, env=env)
assert status == 0
args = args_templae.copy()
args = self.args_template.copy()
args.append(
"./tests/python-gpu/"
"load_pickle.py::TestLoadPickle::test_predictor_type_is_gpu")
@ -109,7 +134,6 @@ class TestPickling(unittest.TestCase):
kwargs = {'tree_method': 'gpu_hist',
'predictor': 'gpu_predictor',
'verbosity': 1,
'objective': 'binary:logistic',
'n_estimators': 10}