Fix wrapping GPU ID and prevent data copying. (#5160)

* Removed some data copying.

* Make sure gpu_id is valid before any configuration is carried out.
This commit is contained in:
Jiaming Yuan
2019-12-27 16:51:08 +08:00
committed by GitHub
parent ee81ba8e1f
commit 61286c6e8f
7 changed files with 55 additions and 17 deletions

View File

@@ -41,7 +41,7 @@ BatchSet<SparsePage> SimpleDMatrix::GetRowBatches() {
BatchSet<CSCPage> SimpleDMatrix::GetColumnBatches() {
// column page doesn't exist, generate it
if (!column_page_) {
auto page = dynamic_cast<SimpleCSRSource*>(source_.get())->page_;
auto const& page = dynamic_cast<SimpleCSRSource*>(source_.get())->page_;
column_page_.reset(new CSCPage(page.GetTranspose(source_->info.num_col_)));
}
auto begin_iter =
@@ -52,7 +52,7 @@ BatchSet<CSCPage> SimpleDMatrix::GetColumnBatches() {
BatchSet<SortedCSCPage> SimpleDMatrix::GetSortedColumnBatches() {
// Sorted column page doesn't exist, generate it
if (!sorted_column_page_) {
auto page = dynamic_cast<SimpleCSRSource*>(source_.get())->page_;
auto const& page = dynamic_cast<SimpleCSRSource*>(source_.get())->page_;
sorted_column_page_.reset(
new SortedCSCPage(page.GetTranspose(source_->info.num_col_)));
sorted_column_page_->SortRows();

View File

@@ -354,7 +354,6 @@ class SparsePageSource : public DataSource<T> {
writer.Alloc(&page);
page->Clear();
MetaInfo info = src->Info();
size_t bytes_write = 0;
double tstart = dmlc::GetTime();
for (auto& batch : src->GetBatches<SparsePage>()) {

View File

@@ -275,7 +275,8 @@ class LearnerImpl : public Learner {
// `verbosity` in logger is not saved, we should move it into generic_param_.
// FIXME(trivialfis): Make eval_metric a training parameter.
if (kv.first != "num_feature" && kv.first != "verbosity" &&
kv.first != "num_class" && kv.first != kEvalMetric) {
kv.first != "num_class" && kv.first != "num_output_group" &&
kv.first != kEvalMetric) {
provided.push_back(kv.first);
}
}
@@ -399,6 +400,8 @@ class LearnerImpl : public Learner {
}
fromJson(learner_parameters.at("generic_param"), &generic_parameters_);
// make sure the GPU ID is valid in new environment before start running configure.
generic_parameters_.ConfigureGpuId(false);
this->need_configuration_ = true;
}