Fix wrapping GPU ID and prevent data copying. (#5160)
* Removed some data copying. * Make sure gpu_id is valid before any configuration is carried out.
This commit is contained in:
parent
ee81ba8e1f
commit
61286c6e8f
@ -41,7 +41,7 @@ BatchSet<SparsePage> SimpleDMatrix::GetRowBatches() {
|
|||||||
BatchSet<CSCPage> SimpleDMatrix::GetColumnBatches() {
|
BatchSet<CSCPage> SimpleDMatrix::GetColumnBatches() {
|
||||||
// column page doesn't exist, generate it
|
// column page doesn't exist, generate it
|
||||||
if (!column_page_) {
|
if (!column_page_) {
|
||||||
auto page = dynamic_cast<SimpleCSRSource*>(source_.get())->page_;
|
auto const& page = dynamic_cast<SimpleCSRSource*>(source_.get())->page_;
|
||||||
column_page_.reset(new CSCPage(page.GetTranspose(source_->info.num_col_)));
|
column_page_.reset(new CSCPage(page.GetTranspose(source_->info.num_col_)));
|
||||||
}
|
}
|
||||||
auto begin_iter =
|
auto begin_iter =
|
||||||
@ -52,7 +52,7 @@ BatchSet<CSCPage> SimpleDMatrix::GetColumnBatches() {
|
|||||||
BatchSet<SortedCSCPage> SimpleDMatrix::GetSortedColumnBatches() {
|
BatchSet<SortedCSCPage> SimpleDMatrix::GetSortedColumnBatches() {
|
||||||
// Sorted column page doesn't exist, generate it
|
// Sorted column page doesn't exist, generate it
|
||||||
if (!sorted_column_page_) {
|
if (!sorted_column_page_) {
|
||||||
auto page = dynamic_cast<SimpleCSRSource*>(source_.get())->page_;
|
auto const& page = dynamic_cast<SimpleCSRSource*>(source_.get())->page_;
|
||||||
sorted_column_page_.reset(
|
sorted_column_page_.reset(
|
||||||
new SortedCSCPage(page.GetTranspose(source_->info.num_col_)));
|
new SortedCSCPage(page.GetTranspose(source_->info.num_col_)));
|
||||||
sorted_column_page_->SortRows();
|
sorted_column_page_->SortRows();
|
||||||
|
|||||||
@ -354,7 +354,6 @@ class SparsePageSource : public DataSource<T> {
|
|||||||
writer.Alloc(&page);
|
writer.Alloc(&page);
|
||||||
page->Clear();
|
page->Clear();
|
||||||
|
|
||||||
MetaInfo info = src->Info();
|
|
||||||
size_t bytes_write = 0;
|
size_t bytes_write = 0;
|
||||||
double tstart = dmlc::GetTime();
|
double tstart = dmlc::GetTime();
|
||||||
for (auto& batch : src->GetBatches<SparsePage>()) {
|
for (auto& batch : src->GetBatches<SparsePage>()) {
|
||||||
|
|||||||
@ -275,7 +275,8 @@ class LearnerImpl : public Learner {
|
|||||||
// `verbosity` in logger is not saved, we should move it into generic_param_.
|
// `verbosity` in logger is not saved, we should move it into generic_param_.
|
||||||
// FIXME(trivialfis): Make eval_metric a training parameter.
|
// FIXME(trivialfis): Make eval_metric a training parameter.
|
||||||
if (kv.first != "num_feature" && kv.first != "verbosity" &&
|
if (kv.first != "num_feature" && kv.first != "verbosity" &&
|
||||||
kv.first != "num_class" && kv.first != kEvalMetric) {
|
kv.first != "num_class" && kv.first != "num_output_group" &&
|
||||||
|
kv.first != kEvalMetric) {
|
||||||
provided.push_back(kv.first);
|
provided.push_back(kv.first);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -399,6 +400,8 @@ class LearnerImpl : public Learner {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fromJson(learner_parameters.at("generic_param"), &generic_parameters_);
|
fromJson(learner_parameters.at("generic_param"), &generic_parameters_);
|
||||||
|
// make sure the GPU ID is valid in new environment before start running configure.
|
||||||
|
generic_parameters_.ConfigureGpuId(false);
|
||||||
|
|
||||||
this->need_configuration_ = true;
|
this->need_configuration_ = true;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -51,14 +51,14 @@ TEST(SparsePageDMatrix, ColAccess) {
|
|||||||
EXPECT_EQ(dmat->GetColDensity(1), 0.5);
|
EXPECT_EQ(dmat->GetColDensity(1), 0.5);
|
||||||
|
|
||||||
// Loop over the batches and assert the data is as expected
|
// Loop over the batches and assert the data is as expected
|
||||||
for (auto col_batch : dmat->GetBatches<xgboost::SortedCSCPage>()) {
|
for (auto const& col_batch : dmat->GetBatches<xgboost::SortedCSCPage>()) {
|
||||||
EXPECT_EQ(col_batch.Size(), dmat->Info().num_col_);
|
EXPECT_EQ(col_batch.Size(), dmat->Info().num_col_);
|
||||||
EXPECT_EQ(col_batch[1][0].fvalue, 10.0f);
|
EXPECT_EQ(col_batch[1][0].fvalue, 10.0f);
|
||||||
EXPECT_EQ(col_batch[1].size(), 1);
|
EXPECT_EQ(col_batch[1].size(), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Loop over the batches and assert the data is as expected
|
// Loop over the batches and assert the data is as expected
|
||||||
for (auto col_batch : dmat->GetBatches<xgboost::CSCPage>()) {
|
for (auto const& col_batch : dmat->GetBatches<xgboost::CSCPage>()) {
|
||||||
EXPECT_EQ(col_batch.Size(), dmat->Info().num_col_);
|
EXPECT_EQ(col_batch.Size(), dmat->Info().num_col_);
|
||||||
EXPECT_EQ(col_batch[1][0].fvalue, 10.0f);
|
EXPECT_EQ(col_batch[1][0].fvalue, 10.0f);
|
||||||
EXPECT_EQ(col_batch[1].size(), 1);
|
EXPECT_EQ(col_batch[1].size(), 1);
|
||||||
@ -223,7 +223,7 @@ TEST(SparsePageDMatrix, FromFile) {
|
|||||||
const std::string tmp_file = tempdir.path + "/simple.libsvm";
|
const std::string tmp_file = tempdir.path + "/simple.libsvm";
|
||||||
data::SparsePageDMatrix dmat(
|
data::SparsePageDMatrix dmat(
|
||||||
&adapter, std::numeric_limits<float>::quiet_NaN(), -1, tmp_file, 1);
|
&adapter, std::numeric_limits<float>::quiet_NaN(), -1, tmp_file, 1);
|
||||||
|
|
||||||
for (auto &batch : dmat.GetBatches<SparsePage>()) {
|
for (auto &batch : dmat.GetBatches<SparsePage>()) {
|
||||||
std::vector<bst_row_t> expected_offset(batch.Size() + 1);
|
std::vector<bst_row_t> expected_offset(batch.Size() + 1);
|
||||||
int n = -3;
|
int n = -3;
|
||||||
|
|||||||
@ -46,7 +46,7 @@ void CheckObjFunctionImpl(std::unique_ptr<xgboost::ObjFunction> const& obj,
|
|||||||
std::vector<xgboost::bst_float> preds,
|
std::vector<xgboost::bst_float> preds,
|
||||||
std::vector<xgboost::bst_float> labels,
|
std::vector<xgboost::bst_float> labels,
|
||||||
std::vector<xgboost::bst_float> weights,
|
std::vector<xgboost::bst_float> weights,
|
||||||
xgboost::MetaInfo info,
|
xgboost::MetaInfo const& info,
|
||||||
std::vector<xgboost::bst_float> out_grad,
|
std::vector<xgboost::bst_float> out_grad,
|
||||||
std::vector<xgboost::bst_float> out_hess) {
|
std::vector<xgboost::bst_float> out_hess) {
|
||||||
xgboost::HostDeviceVector<xgboost::bst_float> in_preds(preds);
|
xgboost::HostDeviceVector<xgboost::bst_float> in_preds(preds);
|
||||||
|
|||||||
@ -37,3 +37,15 @@ class TestLoadPickle(unittest.TestCase):
|
|||||||
config = json.loads(config)
|
config = json.loads(config)
|
||||||
assert config['learner']['gradient_booster']['gbtree_train_param'][
|
assert config['learner']['gradient_booster']['gbtree_train_param'][
|
||||||
'predictor'] == 'gpu_predictor'
|
'predictor'] == 'gpu_predictor'
|
||||||
|
|
||||||
|
def test_wrap_gpu_id(self):
|
||||||
|
assert os.environ['CUDA_VISIBLE_DEVICES'] == '0'
|
||||||
|
bst = load_pickle(model_path)
|
||||||
|
config = bst.save_config()
|
||||||
|
config = json.loads(config)
|
||||||
|
assert config['learner']['generic_param']['gpu_id'] == '0'
|
||||||
|
|
||||||
|
x, y = build_dataset()
|
||||||
|
test_x = xgb.DMatrix(x)
|
||||||
|
res = bst.predict(test_x)
|
||||||
|
assert len(res) == 10
|
||||||
|
|||||||
@ -5,6 +5,9 @@ import numpy as np
|
|||||||
import subprocess
|
import subprocess
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
|
import pytest
|
||||||
|
import copy
|
||||||
|
|
||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
from xgboost import XGBClassifier
|
from xgboost import XGBClassifier
|
||||||
|
|
||||||
@ -31,6 +34,12 @@ def load_pickle(path):
|
|||||||
|
|
||||||
|
|
||||||
class TestPickling(unittest.TestCase):
|
class TestPickling(unittest.TestCase):
|
||||||
|
args_template = [
|
||||||
|
"pytest",
|
||||||
|
"--verbose",
|
||||||
|
"-s",
|
||||||
|
"--fulltrace"]
|
||||||
|
|
||||||
def test_pickling(self):
|
def test_pickling(self):
|
||||||
x, y = build_dataset()
|
x, y = build_dataset()
|
||||||
train_x = xgb.DMatrix(x, label=y)
|
train_x = xgb.DMatrix(x, label=y)
|
||||||
@ -61,13 +70,29 @@ class TestPickling(unittest.TestCase):
|
|||||||
assert status == 0
|
assert status == 0
|
||||||
os.remove(model_path)
|
os.remove(model_path)
|
||||||
|
|
||||||
def test_pickled_predictor(self):
|
@pytest.mark.mgpu
|
||||||
args_templae = [
|
def test_wrap_gpu_id(self):
|
||||||
"pytest",
|
X, y = build_dataset()
|
||||||
"--verbose",
|
dtrain = xgb.DMatrix(X, y)
|
||||||
"-s",
|
|
||||||
"--fulltrace"]
|
|
||||||
|
|
||||||
|
bst = xgb.train({'tree_method': 'gpu_hist',
|
||||||
|
'gpu_id': 1},
|
||||||
|
dtrain, num_boost_round=6)
|
||||||
|
|
||||||
|
model_path = 'model.pkl'
|
||||||
|
save_pickle(bst, model_path)
|
||||||
|
cuda_environment = {'CUDA_VISIBLE_DEVICES': '0'}
|
||||||
|
env = os.environ.copy()
|
||||||
|
env.update(cuda_environment)
|
||||||
|
args = self.args_template.copy()
|
||||||
|
args.append(
|
||||||
|
"./tests/python-gpu/"
|
||||||
|
"load_pickle.py::TestLoadPickle::test_wrap_gpu_id"
|
||||||
|
)
|
||||||
|
status = subprocess.call(args, env=env)
|
||||||
|
assert status == 0
|
||||||
|
|
||||||
|
def test_pickled_predictor(self):
|
||||||
x, y = build_dataset()
|
x, y = build_dataset()
|
||||||
train_x = xgb.DMatrix(x, label=y)
|
train_x = xgb.DMatrix(x, label=y)
|
||||||
|
|
||||||
@ -80,7 +105,7 @@ class TestPickling(unittest.TestCase):
|
|||||||
|
|
||||||
save_pickle(bst, model_path)
|
save_pickle(bst, model_path)
|
||||||
|
|
||||||
args = args_templae.copy()
|
args = self.args_template.copy()
|
||||||
args.append(
|
args.append(
|
||||||
"./tests/python-gpu/"
|
"./tests/python-gpu/"
|
||||||
"load_pickle.py::TestLoadPickle::test_predictor_type_is_auto")
|
"load_pickle.py::TestLoadPickle::test_predictor_type_is_auto")
|
||||||
@ -93,7 +118,7 @@ class TestPickling(unittest.TestCase):
|
|||||||
status = subprocess.call(args, env=env)
|
status = subprocess.call(args, env=env)
|
||||||
assert status == 0
|
assert status == 0
|
||||||
|
|
||||||
args = args_templae.copy()
|
args = self.args_template.copy()
|
||||||
args.append(
|
args.append(
|
||||||
"./tests/python-gpu/"
|
"./tests/python-gpu/"
|
||||||
"load_pickle.py::TestLoadPickle::test_predictor_type_is_gpu")
|
"load_pickle.py::TestLoadPickle::test_predictor_type_is_gpu")
|
||||||
@ -109,7 +134,6 @@ class TestPickling(unittest.TestCase):
|
|||||||
|
|
||||||
kwargs = {'tree_method': 'gpu_hist',
|
kwargs = {'tree_method': 'gpu_hist',
|
||||||
'predictor': 'gpu_predictor',
|
'predictor': 'gpu_predictor',
|
||||||
'verbosity': 1,
|
|
||||||
'objective': 'binary:logistic',
|
'objective': 'binary:logistic',
|
||||||
'n_estimators': 10}
|
'n_estimators': 10}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user