Fix wrapping GPU ID and prevent data copying. (#5160)
* Removed some data copying. * Make sure gpu_id is valid before any configuration is carried out.
This commit is contained in:
@@ -41,7 +41,7 @@ BatchSet<SparsePage> SimpleDMatrix::GetRowBatches() {
|
||||
BatchSet<CSCPage> SimpleDMatrix::GetColumnBatches() {
|
||||
// column page doesn't exist, generate it
|
||||
if (!column_page_) {
|
||||
auto page = dynamic_cast<SimpleCSRSource*>(source_.get())->page_;
|
||||
auto const& page = dynamic_cast<SimpleCSRSource*>(source_.get())->page_;
|
||||
column_page_.reset(new CSCPage(page.GetTranspose(source_->info.num_col_)));
|
||||
}
|
||||
auto begin_iter =
|
||||
@@ -52,7 +52,7 @@ BatchSet<CSCPage> SimpleDMatrix::GetColumnBatches() {
|
||||
BatchSet<SortedCSCPage> SimpleDMatrix::GetSortedColumnBatches() {
|
||||
// Sorted column page doesn't exist, generate it
|
||||
if (!sorted_column_page_) {
|
||||
auto page = dynamic_cast<SimpleCSRSource*>(source_.get())->page_;
|
||||
auto const& page = dynamic_cast<SimpleCSRSource*>(source_.get())->page_;
|
||||
sorted_column_page_.reset(
|
||||
new SortedCSCPage(page.GetTranspose(source_->info.num_col_)));
|
||||
sorted_column_page_->SortRows();
|
||||
|
||||
@@ -354,7 +354,6 @@ class SparsePageSource : public DataSource<T> {
|
||||
writer.Alloc(&page);
|
||||
page->Clear();
|
||||
|
||||
MetaInfo info = src->Info();
|
||||
size_t bytes_write = 0;
|
||||
double tstart = dmlc::GetTime();
|
||||
for (auto& batch : src->GetBatches<SparsePage>()) {
|
||||
|
||||
@@ -275,7 +275,8 @@ class LearnerImpl : public Learner {
|
||||
// `verbosity` in logger is not saved, we should move it into generic_param_.
|
||||
// FIXME(trivialfis): Make eval_metric a training parameter.
|
||||
if (kv.first != "num_feature" && kv.first != "verbosity" &&
|
||||
kv.first != "num_class" && kv.first != kEvalMetric) {
|
||||
kv.first != "num_class" && kv.first != "num_output_group" &&
|
||||
kv.first != kEvalMetric) {
|
||||
provided.push_back(kv.first);
|
||||
}
|
||||
}
|
||||
@@ -399,6 +400,8 @@ class LearnerImpl : public Learner {
|
||||
}
|
||||
|
||||
fromJson(learner_parameters.at("generic_param"), &generic_parameters_);
|
||||
// make sure the GPU ID is valid in new environment before start running configure.
|
||||
generic_parameters_.ConfigureGpuId(false);
|
||||
|
||||
this->need_configuration_ = true;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user