From 24f2e6c97eb8c9562ec3ea7218fcb9b795315d80 Mon Sep 17 00:00:00 2001 From: ShvetsKS <33296480+ShvetsKS@users.noreply.github.com> Date: Wed, 19 Aug 2020 20:37:03 +0300 Subject: [PATCH] Optimize DMatrix build time. (#5877) Co-authored-by: SHVETS, KIRILL --- python-package/xgboost/core.py | 2 +- src/common/group_data.h | 11 ++-- src/data/data.cc | 97 ++++++++++++++++++++++------------ 3 files changed, 70 insertions(+), 40 deletions(-) diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index f2cd880ba..2a69fea51 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -422,7 +422,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes raise TypeError('Input data can not be a list.') self.missing = missing if missing is not None else np.nan - self.nthread = nthread if nthread is not None else 1 + self.nthread = nthread if nthread is not None else -1 self.silent = silent # force into void_p, mac need to pass things in as void_p diff --git a/src/common/group_data.h b/src/common/group_data.h index 0144d8099..476b4925b 100644 --- a/src/common/group_data.h +++ b/src/common/group_data.h @@ -17,6 +17,7 @@ #include #include #include +#include #include "xgboost/base.h" @@ -56,10 +57,10 @@ class ParallelGroupBuilder { void InitBudget(std::size_t max_key, int nthread) { thread_rptr_.resize(nthread); for (std::size_t i = 0; i < thread_rptr_.size(); ++i) { - thread_rptr_[i].resize(max_key - std::min(base_row_offset_, max_key)); - std::fill(thread_rptr_[i].begin(), thread_rptr_[i].end(), 0); + thread_rptr_[i].resize(max_key - std::min(base_row_offset_, max_key), 0); } } + /*! * \brief step 2: add budget to each key * \param key the key @@ -74,6 +75,7 @@ class ParallelGroupBuilder { } trptr[offset_key] += nelem; } + /*! \brief step 3: initialize the necessary storage */ inline void InitStorage() { // set rptr to correct size @@ -101,6 +103,7 @@ class ParallelGroupBuilder { } data_.resize(rptr_.back()); } + /*! * \brief step 4: add data to the allocated space, * the calls to this function should be exactly match previous call to AddBudget @@ -109,10 +112,10 @@ class ParallelGroupBuilder { * \param value The value to be pushed to the group. * \param threadid the id of thread that calls this function */ - void Push(std::size_t key, ValueType value, int threadid) { + void Push(std::size_t key, ValueType&& value, int threadid) { size_t offset_key = key - base_row_offset_; SizeType &rp = thread_rptr_[threadid][offset_key]; - data_[rp++] = value; + data_[rp++] = std::move(value); } private: diff --git a/src/data/data.cc b/src/data/data.cc index 8bd7c76cf..d7d18f189 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -840,10 +840,11 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread // Set number of threads but keep old value so we can reset it after const int nthreadmax = omp_get_max_threads(); if (nthread <= 0) nthread = nthreadmax; - int nthread_original = omp_get_max_threads(); + const int nthread_original = omp_get_max_threads(); omp_set_num_threads(nthread); auto& offset_vec = offset.HostVector(); auto& data_vec = data.HostVector(); + size_t builder_base_row_offset = this->Size(); common::ParallelGroupBuilder< Entry, std::remove_reference::type::value_type> @@ -858,48 +859,74 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread last_line.GetElement(last_line.Size() - 1).row_idx - base_rowid; } } - builder.InitBudget(expected_rows, nthread); - uint64_t max_columns = 0; - - // First-pass over the batch counting valid elements size_t batch_size = batch.Size(); -#pragma omp parallel for schedule(static) - for (omp_ulong i = 0; i < static_cast(batch_size); - ++i) { // NOLINT(*) - int tid = omp_get_thread_num(); - auto line = batch.GetLine(i); - for (auto j = 0ull; j < line.Size(); j++) { - data::COOTuple element = line.GetElement(j); - max_columns = - std::max(max_columns, static_cast(element.column_idx + 1)); - if (!common::CheckNAN(element.value) && element.value != missing) { - size_t key = element.row_idx - base_rowid; - // Adapter row index is absolute, here we want it relative to - // current page - CHECK_GE(key, builder_base_row_offset); - builder.AddBudget(key, tid); - } - } + const size_t thread_size = batch_size/nthread; + builder.InitBudget(expected_rows+1, nthread); + uint64_t max_columns = 0; + if (batch_size == 0) { + omp_set_num_threads(nthread_original); + return max_columns; } + std::vector> max_columns_vector(nthread); + dmlc::OMPException exec; + // First-pass over the batch counting valid elements +#pragma omp parallel num_threads(nthread) + { + exec.Run([&]() { + int tid = omp_get_thread_num(); + size_t begin = tid*thread_size; + size_t end = tid != (nthread-1) ? (tid+1)*thread_size : batch_size; + max_columns_vector[tid].resize(1, 0); + uint64_t& max_columns_local = max_columns_vector[tid][0]; + + for (size_t i = begin; i < end; ++i) { + auto line = batch.GetLine(i); + for (auto j = 0ull; j < line.Size(); j++) { + auto element = line.GetElement(j); + const size_t key = element.row_idx - base_rowid; + CHECK_GE(key, builder_base_row_offset); + max_columns_local = + std::max(max_columns_local, static_cast(element.column_idx + 1)); + + if (!common::CheckNAN(element.value) && element.value != missing) { + // Adapter row index is absolute, here we want it relative to + // current page + builder.AddBudget(key, tid); + } + } + } + }); + } + exec.Rethrow(); + for (const auto & max : max_columns_vector) { + max_columns = std::max(max_columns, max[0]); + } + builder.InitStorage(); // Second pass over batch, placing elements in correct position -#pragma omp parallel for schedule(static) - for (omp_ulong i = 0; i < static_cast(batch_size); - ++i) { // NOLINT(*) - int tid = omp_get_thread_num(); - auto line = batch.GetLine(i); - for (auto j = 0ull; j < line.Size(); j++) { - auto element = line.GetElement(j); - if (!common::CheckNAN(element.value) && element.value != missing) { - size_t key = element.row_idx - - base_rowid; // Adapter row index is absolute, here we want - // it relative to current page - builder.Push(key, Entry(element.column_idx, element.value), tid); + +#pragma omp parallel num_threads(nthread) + { + exec.Run([&]() { + int tid = omp_get_thread_num(); + size_t begin = tid*thread_size; + size_t end = tid != (nthread-1) ? (tid+1)*thread_size : batch_size; + for (size_t i = begin; i < end; ++i) { + auto line = batch.GetLine(i); + for (auto j = 0ull; j < line.Size(); j++) { + auto element = line.GetElement(j); + const size_t key = (element.row_idx - base_rowid); + if (!common::CheckNAN(element.value) && element.value != missing) { + builder.Push(key, Entry(element.column_idx, element.value), tid); + } + } } - } + }); } + exec.Rethrow(); omp_set_num_threads(nthread_original); + return max_columns; }