diff --git a/.clang-tidy b/.clang-tidy index ecc265f8d..dbc0cf292 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -1,4 +1,4 @@ -Checks: 'modernize-*,-modernize-make-*,-modernize-use-auto,-modernize-raw-string-literal,google-*,-google-default-arguments,-clang-diagnostic-#pragma-messages,readability-identifier-naming' +Checks: 'modernize-*,-modernize-make-*,-modernize-use-auto,-modernize-raw-string-literal,-modernize-avoid-c-arrays,google-*,-google-default-arguments,-clang-diagnostic-#pragma-messages,readability-identifier-naming' CheckOptions: - { key: readability-identifier-naming.ClassCase, value: CamelCase } - { key: readability-identifier-naming.StructCase, value: CamelCase } diff --git a/include/xgboost/data.h b/include/xgboost/data.h index 311347eaa..45e41a114 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -437,6 +437,7 @@ class DMatrix { bool load_row_split, const std::string& file_format = "auto", const size_t page_size = kPageSize); + /*! * \brief create a new DMatrix, by wrapping a row_iterator, and meta info. * \param source The source iterator of the data, the create function takes ownership of the source. diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index a215fcef5..50291355d 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -59,7 +59,7 @@ if (USE_CUDA) # OpenMP is mandatory for cuda version find_package(OpenMP REQUIRED) - target_compile_options(objxgboost PRIVATE + target_compile_options(objxgboost PRIVATE $<$:-Xcompiler=${OpenMP_CXX_FLAGS}> ) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 8e44b6e61..eeeb0d0be 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -119,10 +119,9 @@ class NativeDataIter : public dmlc::Parser { } bool Next() override { - if ((*next_callback_)( - data_handle_, - XGBoostNativeDataIterSetData, - this) != 0) { + if ((*next_callback_)(data_handle_, + XGBoostNativeDataIterSetData, + this) != 0) { at_first_ = false; return true; } else { diff --git a/src/common/column_matrix.h b/src/common/column_matrix.h index e55e1ef57..19a484109 100644 --- a/src/common/column_matrix.h +++ b/src/common/column_matrix.h @@ -75,7 +75,7 @@ class ColumnMatrix { // construct column matrix from GHistIndexMatrix inline void Init(const GHistIndexMatrix& gmat, double sparse_threshold) { - const int32_t nfeature = static_cast(gmat.cut.row_ptr.size() - 1); + const int32_t nfeature = static_cast(gmat.cut.Ptrs().size() - 1); const size_t nrow = gmat.row_ptr.size() - 1; // identify type of each column @@ -85,7 +85,7 @@ class ColumnMatrix { uint32_t max_val = std::numeric_limits::max(); for (int32_t fid = 0; fid < nfeature; ++fid) { - CHECK_LE(gmat.cut.row_ptr[fid + 1] - gmat.cut.row_ptr[fid], max_val); + CHECK_LE(gmat.cut.Ptrs()[fid + 1] - gmat.cut.Ptrs()[fid], max_val); } gmat.GetFeatureCounts(&feature_counts_[0]); @@ -123,7 +123,7 @@ class ColumnMatrix { // store least bin id for each feature index_base_.resize(nfeature); for (int32_t fid = 0; fid < nfeature; ++fid) { - index_base_[fid] = gmat.cut.row_ptr[fid]; + index_base_[fid] = gmat.cut.Ptrs()[fid]; } // pre-fill index_ for dense columns @@ -150,9 +150,9 @@ class ColumnMatrix { size_t fid = 0; for (size_t i = ibegin; i < iend; ++i) { const uint32_t bin_id = gmat.index[i]; - while (bin_id >= gmat.cut.row_ptr[fid + 1]) { - ++fid; - } + auto iter = std::upper_bound(gmat.cut.Ptrs().cbegin() + fid, + gmat.cut.Ptrs().cend(), bin_id); + fid = std::distance(gmat.cut.Ptrs().cbegin(), iter) - 1; if (type_[fid] == kDenseColumn) { uint32_t* begin = &index_[boundary_[fid].index_begin]; 
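        // Note on the std::upper_bound change above: gmat.cut.Ptrs() is the
        // monotone per-feature CSC pointer array, so the feature that owns a
        // global bin id is the last fid with Ptrs()[fid] <= bin_id;
        // upper_bound returns the first pointer strictly greater than bin_id,
        // hence the trailing `- 1`.  For example (hypothetical values), with
        // Ptrs() = {0, 4, 9} and bin_id = 6, upper_bound lands on 9 and
        // fid = 1.  Starting the search at cbegin() + fid is safe because the
        // replaced linear scan relied on the same non-decreasing ordering of
        // bin ids within a row.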
begin[rid] = bin_id - index_base_[fid]; diff --git a/src/common/common.h b/src/common/common.h index b030a72e0..9eab46add 100644 --- a/src/common/common.h +++ b/src/common/common.h @@ -72,6 +72,11 @@ inline std::string ToString(const T& data) { return os.str(); } +template +XGBOOST_DEVICE T1 DivRoundUp(const T1 a, const T2 b) { + return static_cast(std::ceil(static_cast(a) / b)); +} + /* * Range iterator */ diff --git a/src/common/config.h b/src/common/config.h index 04fcbbdb5..a85fee609 100644 --- a/src/common/config.h +++ b/src/common/config.h @@ -30,12 +30,13 @@ class ConfigParser { * \param path path to configuration file */ explicit ConfigParser(const std::string& path) - : line_comment_regex_("^#"), + : path_(path), + line_comment_regex_("^#"), key_regex_(R"rx(^([^#"'=\r\n\t ]+)[\t ]*=)rx"), key_regex_escaped_(R"rx(^(["'])([^"'=\r\n]+)\1[\t ]*=)rx"), value_regex_(R"rx(^([^#"'=\r\n\t ]+)[\t ]*(?:#.*){0,1}$)rx"), - value_regex_escaped_(R"rx(^(["'])([^"'=\r\n]+)\1[\t ]*(?:#.*){0,1}$)rx"), - path_(path) {} + value_regex_escaped_(R"rx(^(["'])([^"'=\r\n]+)\1[\t ]*(?:#.*){0,1}$)rx") + {} std::string LoadConfigFile(const std::string& path) { std::ifstream fin(path, std::ios_base::in | std::ios_base::binary); @@ -77,8 +78,6 @@ class ConfigParser { content = NormalizeConfigEOL(content); std::stringstream ss { content }; std::vector> results; - char delimiter = '='; - char comment = '#'; std::string line; std::string key, value; // Loop over every line of the configuration file diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index c1ff749d0..9a678b029 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -1,5 +1,5 @@ /*! - * Copyright 2017 XGBoost contributors + * Copyright 2017-2019 XGBoost contributors */ #pragma once #include @@ -183,11 +183,6 @@ __device__ void BlockFill(IterT begin, size_t n, ValueT value) { * Kernel launcher */ -template -T1 DivRoundUp(const T1 a, const T2 b) { - return static_cast(ceil(static_cast(a) / b)); -} - template __global__ void LaunchNKernel(size_t begin, size_t end, L lambda) { for (auto i : GridStrideRange(begin, end)) { @@ -211,7 +206,7 @@ inline void LaunchN(int device_idx, size_t n, cudaStream_t stream, L lambda) { safe_cuda(cudaSetDevice(device_idx)); const int GRID_SIZE = - static_cast(DivRoundUp(n, ITEMS_PER_THREAD * BLOCK_THREADS)); + static_cast(xgboost::common::DivRoundUp(n, ITEMS_PER_THREAD * BLOCK_THREADS)); LaunchNKernel<<>>(static_cast(0), n, lambda); } @@ -619,7 +614,7 @@ struct CubMemory { if (this->IsAllocated()) { XGBDeviceAllocator allocator; allocator.deallocate(thrust::device_ptr(static_cast(d_temp_storage)), - temp_storage_bytes); + temp_storage_bytes); d_temp_storage = nullptr; temp_storage_bytes = 0; } @@ -738,7 +733,7 @@ void SparseTransformLbs(int device_idx, dh::CubMemory *temp_memory, const int BLOCK_THREADS = 256; const int ITEMS_PER_THREAD = 1; const int TILE_SIZE = BLOCK_THREADS * ITEMS_PER_THREAD; - auto num_tiles = dh::DivRoundUp(count + num_segments, BLOCK_THREADS); + auto num_tiles = xgboost::common::DivRoundUp(count + num_segments, BLOCK_THREADS); CHECK(num_tiles < std::numeric_limits::max()); temp_memory->LazyAllocate(sizeof(CoordinateT) * (num_tiles + 1)); @@ -1158,7 +1153,7 @@ class AllReducer { }; /** - * \brief Synchronizes the device + * \brief Synchronizes the device * * \param device_id Identifier for the device. 
*/ diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc index af420db4d..e2514d856 100644 --- a/src/common/hist_util.cc +++ b/src/common/hist_util.cc @@ -25,25 +25,206 @@ namespace xgboost { namespace common { -HistCutMatrix::HistCutMatrix() { - monitor_.Init("HistCutMatrix"); +HistogramCuts::HistogramCuts() { + monitor_.Init(__FUNCTION__); + cut_ptrs_.emplace_back(0); } -size_t HistCutMatrix::SearchGroupIndFromBaseRow( - std::vector const& group_ptr, size_t const base_rowid) const { - using KIt = std::vector::const_iterator; - KIt res = std::lower_bound(group_ptr.cbegin(), group_ptr.cend() - 1, base_rowid); - // Cannot use CHECK_NE because it will try to print the iterator. - bool const found = res != group_ptr.cend() - 1; - if (!found) { - LOG(FATAL) << "Row " << base_rowid << " does not lie in any group!\n"; +// Dispatch to specific builder. +void HistogramCuts::Build(DMatrix* dmat, uint32_t const max_num_bins) { + auto const& info = dmat->Info(); + size_t const total = info.num_row_ * info.num_col_; + size_t const nnz = info.num_nonzero_; + float const sparsity = static_cast(nnz) / static_cast(total); + // Use a small number to avoid calling `dmat->GetColumnBatches'. + float constexpr kSparsityThreshold = 0.0005; + // FIXME(trivialfis): Distributed environment is not supported. + if (sparsity < kSparsityThreshold && (!rabit::IsDistributed())) { + LOG(INFO) << "Building quantile cut on a sparse dataset."; + SparseCuts cuts(this); + cuts.Build(dmat, max_num_bins); + } else { + LOG(INFO) << "Building quantile cut on a dense dataset or distributed environment."; + DenseCuts cuts(this); + cuts.Build(dmat, max_num_bins); } - size_t group_ind = std::distance(group_ptr.cbegin(), res); - return group_ind; } -void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) { - monitor_.Start("Init"); +bool CutsBuilder::UseGroup(DMatrix* dmat) { + auto& info = dmat->Info(); + size_t const num_groups = info.group_ptr_.size() == 0 ? + 0 : info.group_ptr_.size() - 1; + // Use group index for weights? + bool const use_group_ind = num_groups != 0 && + (info.weights_.Size() != info.num_row_); + return use_group_ind; +} + +void SparseCuts::SingleThreadBuild(SparsePage const& page, MetaInfo const& info, + uint32_t max_num_bins, + bool const use_group_ind, + uint32_t beg_col, uint32_t end_col, + uint32_t thread_id) { + using WXQSketch = common::WXQuantileSketch; + CHECK_GE(end_col, beg_col); + constexpr float kFactor = 8; + + // Data groups, used in ranking. + std::vector const& group_ptr = info.group_ptr_; + p_cuts_->min_vals_.resize(end_col - beg_col, 0); + + for (uint32_t col_id = beg_col; col_id < page.Size() && col_id < end_col; ++col_id) { + // Using a local variable makes things easier, but at the cost of memory trashing. 
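+    // Per-column flow: initialise a weighted quantile sketch with
+    // eps = 1.0 / (n_bins * kFactor), push every (fvalue, weight) pair of this
+    // column, then prune the merged summary down to n_bins candidate cut
+    // points before handing it to AddCutPoint().  kFactor = 8 is the same
+    // over-sampling safety factor used by the dense builder.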
+    WXQSketch sketch;
+    common::Span<Entry const> const column = page[col_id];
+    uint32_t const n_bins = std::min(static_cast<uint32_t>(column.size()),
+                                     max_num_bins);
+    if (n_bins == 0) {
+      // cut_ptrs_ is initialized with a zero, so there's always an element at the back
+      p_cuts_->cut_ptrs_.emplace_back(p_cuts_->cut_ptrs_.back());
+      continue;
+    }
+
+    sketch.Init(info.num_row_, 1.0 / (n_bins * kFactor));
+    for (auto const& entry : column) {
+      uint32_t weight_ind = 0;
+      if (use_group_ind) {
+        auto row_idx = entry.index;
+        uint32_t group_ind =
+            this->SearchGroupIndFromRow(group_ptr, page.base_rowid + row_idx);
+        weight_ind = group_ind;
+      } else {
+        weight_ind = entry.index;
+      }
+      sketch.Push(entry.fvalue, info.GetWeight(weight_ind));
+    }
+
+    WXQSketch::SummaryContainer out_summary;
+    sketch.GetSummary(&out_summary);
+    WXQSketch::SummaryContainer summary;
+    summary.Reserve(n_bins);
+    summary.SetPrune(out_summary, n_bins);
+
+    // Could we use data[1] as the min value so that we don't need to
+    // store another array?
+    float mval = summary.data[0].value;
+    p_cuts_->min_vals_[col_id - beg_col] = mval - (fabs(mval) + 1e-5);
+
+    this->AddCutPoint(summary);
+
+    bst_float cpt = (summary.size > 0) ?
+                    summary.data[summary.size - 1].value :
+                    p_cuts_->min_vals_[col_id - beg_col];
+    cpt += fabs(cpt) + 1e-5;
+    p_cuts_->cut_values_.emplace_back(cpt);
+
+    p_cuts_->cut_ptrs_.emplace_back(p_cuts_->cut_values_.size());
+  }
+}
+
+std::vector<size_t> SparseCuts::LoadBalance(SparsePage const& page,
+                                            size_t const nthreads) {
+  /* Some sparse datasets have their mass concentrated on a small
+   * number of features.  To avoid waiting for a few threads running
+   * forever, we distribute different numbers of columns to different
+   * threads according to their number of entries. */
+  size_t const total_entries = page.data.Size();
+  size_t const entries_per_thread = common::DivRoundUp(total_entries, nthreads);
+
+  std::vector<size_t> cols_ptr(nthreads+1, 0);
+  size_t count {0};
+  size_t current_thread {1};
+
+  for (size_t col_id = 0; col_id < page.Size(); ++col_id) {
+    auto const column = page[col_id];
+    cols_ptr[current_thread]++;  // add one column to thread
+    count += column.size();
+    if (count > entries_per_thread + 1) {
+      current_thread++;
+      count = 0;
+      cols_ptr[current_thread] = cols_ptr[current_thread-1];
+    }
+  }
+  // Idle threads.
+  for (; current_thread < cols_ptr.size() - 1; ++current_thread) {
+    cols_ptr[current_thread+1] = cols_ptr[current_thread];
+  }
+
+  return cols_ptr;
+}
+
+void SparseCuts::Build(DMatrix* dmat, uint32_t const max_num_bins) {
+  monitor_.Start(__FUNCTION__);
+  // Use group index for weights?
+  auto use_group = UseGroup(dmat);
+  uint32_t nthreads = omp_get_max_threads();
+  CHECK_GT(nthreads, 0);
+  std::vector<HistogramCuts> cuts_containers(nthreads);
+  std::vector<std::unique_ptr<SparseCuts>> sparse_cuts(nthreads);
+  for (size_t i = 0; i < nthreads; ++i) {
+    sparse_cuts[i].reset(new SparseCuts(&cuts_containers[i]));
+  }
+
+  for (auto const& page : dmat->GetColumnBatches()) {
+    CHECK_LE(page.Size(), dmat->Info().num_col_);
+    monitor_.Start("Load balance");
+    std::vector<size_t> col_ptr = LoadBalance(page, nthreads);
+    monitor_.Stop("Load balance");
+    // Here we decouple the build logic from the parallelization
+    // to simplify things a bit.
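+    // Illustration with hypothetical numbers: for nthreads = 4 a returned
+    // col_ptr of {0, 3, 7, 12, 15} means thread i sketches the half-open
+    // column range [col_ptr[i], col_ptr[i+1]), i.e. thread 0 gets columns
+    // 0-2, thread 1 gets 3-6, and so on; columns with many entries simply
+    // end up in shorter ranges.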
+#pragma omp parallel for num_threads(nthreads) schedule(static) + for (omp_ulong i = 0; i < nthreads; ++i) { + common::Monitor t_monitor; + t_monitor.Init("SingleThreadBuild: " + std::to_string(i)); + t_monitor.Start(std::to_string(i)); + sparse_cuts[i]->SingleThreadBuild(page, dmat->Info(), max_num_bins, use_group, + col_ptr[i], col_ptr[i+1], i); + t_monitor.Stop(std::to_string(i)); + } + + this->Concat(sparse_cuts, dmat->Info().num_col_); + } + + monitor_.Stop(__FUNCTION__); +} + +void SparseCuts::Concat( + std::vector> const& cuts, uint32_t n_cols) { + monitor_.Start(__FUNCTION__); + uint32_t nthreads = omp_get_max_threads(); + p_cuts_->min_vals_.resize(n_cols, std::numeric_limits::max()); + size_t min_vals_tail = 0; + + for (uint32_t t = 0; t < nthreads; ++t) { + // concat csc pointers. + size_t const old_ptr_size = p_cuts_->cut_ptrs_.size(); + p_cuts_->cut_ptrs_.resize( + cuts[t]->p_cuts_->cut_ptrs_.size() + p_cuts_->cut_ptrs_.size() - 1); + size_t const new_icp_size = p_cuts_->cut_ptrs_.size(); + auto tail = p_cuts_->cut_ptrs_[old_ptr_size-1]; + for (size_t j = old_ptr_size; j < new_icp_size; ++j) { + p_cuts_->cut_ptrs_[j] = tail + cuts[t]->p_cuts_->cut_ptrs_[j-old_ptr_size+1]; + } + // concat csc values + size_t const old_iv_size = p_cuts_->cut_values_.size(); + p_cuts_->cut_values_.resize( + cuts[t]->p_cuts_->cut_values_.size() + p_cuts_->cut_values_.size()); + size_t const new_iv_size = p_cuts_->cut_values_.size(); + for (size_t j = old_iv_size; j < new_iv_size; ++j) { + p_cuts_->cut_values_[j] = cuts[t]->p_cuts_->cut_values_[j-old_iv_size]; + } + // merge min values + for (size_t j = 0; j < cuts[t]->p_cuts_->min_vals_.size(); ++j) { + p_cuts_->min_vals_.at(min_vals_tail + j) = + std::min(p_cuts_->min_vals_.at(min_vals_tail + j), cuts.at(t)->p_cuts_->min_vals_.at(j)); + } + min_vals_tail += cuts[t]->p_cuts_->min_vals_.size(); + } + monitor_.Stop(__FUNCTION__); +} + +void DenseCuts::Build(DMatrix* p_fmat, uint32_t max_num_bins) { + monitor_.Start(__FUNCTION__); const MetaInfo& info = p_fmat->Info(); // safe factor for better accuracy @@ -60,20 +241,18 @@ void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) { s.Init(info.num_row_, 1.0 / (max_num_bins * kFactor)); } - const auto& weights = info.weights_.HostVector(); - // Data groups, used in ranking. std::vector const& group_ptr = info.group_ptr_; size_t const num_groups = group_ptr.size() == 0 ? 0 : group_ptr.size() - 1; // Use group index for weights? 
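   // UseGroup() centralises the check that used to live here (and is shared
   // with SparseCuts): weights are treated as per-group rather than per-row
   // whenever group_ptr_ is non-empty and the number of weights differs from
   // the number of rows.  For instance (hypothetical sizes), 100 rows split
   // into 10 ranking groups with 10 weights means GetWeight() is indexed by
   // group id instead of row id.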
- bool const use_group_ind = num_groups != 0 && weights.size() != info.num_row_; + bool const use_group = UseGroup(p_fmat); for (const auto &batch : p_fmat->GetRowBatches()) { size_t group_ind = 0; - if (use_group_ind) { - group_ind = this->SearchGroupIndFromBaseRow(group_ptr, batch.base_rowid); + if (use_group) { + group_ind = this->SearchGroupIndFromRow(group_ptr, batch.base_rowid); } -#pragma omp parallel num_threads(nthread) firstprivate(group_ind, use_group_ind) +#pragma omp parallel num_threads(nthread) firstprivate(group_ind, use_group) { CHECK_EQ(nthread, omp_get_num_threads()); auto tid = static_cast(omp_get_thread_num()); @@ -85,7 +264,7 @@ void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) { for (size_t i = 0; i < batch.Size(); ++i) { // NOLINT(*) size_t const ridx = batch.base_rowid + i; SparsePage::Inst const inst = batch[i]; - if (use_group_ind && + if (use_group && group_ptr[group_ind] == ridx && // maximum equals to weights.size() - 1 group_ind < num_groups - 1) { @@ -94,7 +273,7 @@ void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) { } for (auto const& entry : inst) { if (entry.index >= begin && entry.index < end) { - size_t w_idx = use_group_ind ? group_ind : ridx; + size_t w_idx = use_group ? group_ind : ridx; sketchs[entry.index].Push(entry.fvalue, info.GetWeight(w_idx)); } } @@ -104,10 +283,10 @@ void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) { } Init(&sketchs, max_num_bins); - monitor_.Stop("Init"); + monitor_.Stop(__FUNCTION__); } -void HistCutMatrix::Init +void DenseCuts::Init (std::vector* in_sketchs, uint32_t max_num_bins) { std::vector& sketchs = *in_sketchs; constexpr int kFactor = 8; @@ -124,62 +303,34 @@ void HistCutMatrix::Init CHECK_EQ(summary_array.size(), in_sketchs->size()); size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_num_bins * kFactor); sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size()); - this->min_val.resize(sketchs.size()); - row_ptr.push_back(0); + p_cuts_->min_vals_.resize(sketchs.size()); + for (size_t fid = 0; fid < summary_array.size(); ++fid) { WXQSketch::SummaryContainer a; a.Reserve(max_num_bins); a.SetPrune(summary_array[fid], max_num_bins); const bst_float mval = a.data[0].value; - this->min_val[fid] = mval - (fabs(mval) + 1e-5); - if (a.size > 1 && a.size <= 16) { - /* specialized code categorial / ordinal data -- use midpoints */ - for (size_t i = 1; i < a.size; ++i) { - bst_float cpt = (a.data[i].value + a.data[i - 1].value) / 2.0f; - if (i == 1 || cpt > cut.back()) { - cut.push_back(cpt); - } - } - } else { - for (size_t i = 2; i < a.size; ++i) { - bst_float cpt = a.data[i - 1].value; - if (i == 2 || cpt > cut.back()) { - cut.push_back(cpt); - } - } - } + p_cuts_->min_vals_[fid] = mval - (fabs(mval) + 1e-5); + AddCutPoint(a); // push a value that is greater than anything const bst_float cpt - = (a.size > 0) ? a.data[a.size - 1].value : this->min_val[fid]; + = (a.size > 0) ? 
a.data[a.size - 1].value : p_cuts_->min_vals_[fid]; // this must be bigger than last value in a scale const bst_float last = cpt + (fabs(cpt) + 1e-5); - cut.push_back(last); + p_cuts_->cut_values_.push_back(last); // Ensure that every feature gets at least one quantile point - CHECK_LE(cut.size(), std::numeric_limits::max()); - auto cut_size = static_cast(cut.size()); - CHECK_GT(cut_size, row_ptr.back()); - row_ptr.push_back(cut_size); + CHECK_LE(p_cuts_->cut_values_.size(), std::numeric_limits::max()); + auto cut_size = static_cast(p_cuts_->cut_values_.size()); + CHECK_GT(cut_size, p_cuts_->cut_ptrs_.back()); + p_cuts_->cut_ptrs_.push_back(cut_size); } } -uint32_t HistCutMatrix::GetBinIdx(const Entry& e) { - unsigned fid = e.index; - auto cbegin = cut.begin() + row_ptr[fid]; - auto cend = cut.begin() + row_ptr[fid + 1]; - CHECK(cbegin != cend); - auto it = std::upper_bound(cbegin, cend, e.fvalue); - if (it == cend) { - it = cend - 1; - } - uint32_t idx = static_cast(it - cut.begin()); - return idx; -} - void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_num_bins) { - cut.Init(p_fmat, max_num_bins); + cut.Build(p_fmat, max_num_bins); const int32_t nthread = omp_get_max_threads(); - const uint32_t nbins = cut.row_ptr.back(); + const uint32_t nbins = cut.Ptrs().back(); hit_count.resize(nbins, 0); hit_count_tloc_.resize(nthread * nbins, 0); @@ -208,7 +359,7 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_num_bins) { #pragma omp parallel num_threads(batch_threads) { #pragma omp for - for (int32_t tid = 0; tid < batch_threads; ++tid) { + for (omp_ulong tid = 0; tid < batch_threads; ++tid) { size_t ibegin = block_size * tid; size_t iend = (tid == (batch_threads-1) ? batch.Size() : (block_size * (tid+1))); @@ -222,13 +373,13 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_num_bins) { #pragma omp single { p_part[0] = prev_sum; - for (int32_t i = 1; i < batch_threads; ++i) { + for (size_t i = 1; i < batch_threads; ++i) { p_part[i] = p_part[i - 1] + row_ptr[rbegin + i*block_size]; } } #pragma omp for - for (int32_t tid = 0; tid < batch_threads; ++tid) { + for (omp_ulong tid = 0; tid < batch_threads; ++tid) { size_t ibegin = block_size * tid; size_t iend = (tid == (batch_threads-1) ? 
batch.Size() : (block_size * (tid+1))); @@ -240,7 +391,7 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_num_bins) { index.resize(row_ptr[rbegin + batch.Size()]); - CHECK_GT(cut.cut.size(), 0U); + CHECK_GT(cut.Values().size(), 0U); #pragma omp parallel for num_threads(batch_threads) schedule(static) for (omp_ulong i = 0; i < batch.Size(); ++i) { // NOLINT(*) @@ -251,7 +402,7 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_num_bins) { CHECK_EQ(ibegin + inst.size(), iend); for (bst_uint j = 0; j < inst.size(); ++j) { - uint32_t idx = cut.GetBinIdx(inst[j]); + uint32_t idx = cut.SearchBin(inst[j]); index[ibegin + j] = idx; ++hit_count_tloc_[tid * nbins + idx]; @@ -382,7 +533,7 @@ FastFeatureGrouping(const GHistIndexMatrix& gmat, const ColumnMatrix& colmat, const tree::TrainParam& param) { const size_t nrow = gmat.row_ptr.size() - 1; - const size_t nfeature = gmat.cut.row_ptr.size() - 1; + const size_t nfeature = gmat.cut.Ptrs().size() - 1; std::vector feature_list(nfeature); std::iota(feature_list.begin(), feature_list.end(), 0); @@ -438,7 +589,7 @@ void GHistIndexBlockMatrix::Init(const GHistIndexMatrix& gmat, cut_ = &gmat.cut; const size_t nrow = gmat.row_ptr.size() - 1; - const uint32_t nbins = gmat.cut.row_ptr.back(); + const uint32_t nbins = gmat.cut.Ptrs().back(); /* step 1: form feature groups */ auto groups = FastFeatureGrouping(gmat, colmat, param); @@ -448,8 +599,8 @@ void GHistIndexBlockMatrix::Init(const GHistIndexMatrix& gmat, std::vector bin2block(nbins); // lookup table [bin id] => [block id] for (uint32_t group_id = 0; group_id < nblock; ++group_id) { for (auto& fid : groups[group_id]) { - const uint32_t bin_begin = gmat.cut.row_ptr[fid]; - const uint32_t bin_end = gmat.cut.row_ptr[fid + 1]; + const uint32_t bin_begin = gmat.cut.Ptrs()[fid]; + const uint32_t bin_end = gmat.cut.Ptrs()[fid + 1]; for (uint32_t bin_id = bin_begin; bin_id < bin_end; ++bin_id) { bin2block[bin_id] = group_id; } @@ -627,8 +778,8 @@ void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent) { const size_t block_size = 1024; // aproximatly 1024 values per block size_t n_blocks = size/block_size + !!(size%block_size); - #pragma omp parallel for - for (int iblock = 0; iblock < n_blocks; ++iblock) { +#pragma omp parallel for + for (omp_ulong iblock = 0; iblock < n_blocks; ++iblock) { const size_t ibegin = iblock*block_size; const size_t iend = (((iblock+1)*block_size > size) ? size : ibegin + block_size); for (bst_omp_uint bin_id = ibegin; bin_id < iend; bin_id++) { diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu index 6c703ff50..586ba0e7e 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -3,6 +3,7 @@ */ #include "./hist_util.h" +#include #include #include @@ -24,7 +25,7 @@ namespace xgboost { namespace common { -using WXQSketch = HistCutMatrix::WXQSketch; +using WXQSketch = DenseCuts::WXQSketch; __global__ void FindCutsK (WXQSketch::Entry* __restrict__ cuts, const bst_float* __restrict__ data, @@ -92,7 +93,7 @@ __global__ void UnpackFeaturesK * across distinct rows. 
*/ struct SketchContainer { - std::vector sketches_; // NOLINT + std::vector sketches_; // NOLINT std::vector col_locks_; // NOLINT static constexpr int kOmpNumColsParallelizeLimit = 1000; @@ -300,7 +301,7 @@ struct GPUSketcher { } else if (n_cuts_cur_[icol] > 0) { // if more elements than cuts: use binary search on cumulative weights int block = 256; - FindCutsK<<>> + FindCutsK<<>> (cuts_d_.data().get() + icol * n_cuts_, fvalues_cur_.data().get(), weights2_.data().get(), n_unique, n_cuts_cur_[icol]); dh::safe_cuda(cudaGetLastError()); // NOLINT @@ -342,8 +343,8 @@ struct GPUSketcher { dim3 block3(16, 64, 1); // NOTE: This will typically support ~ 4M features - 64K*64 - dim3 grid3(dh::DivRoundUp(batch_nrows, block3.x), - dh::DivRoundUp(num_cols_, block3.y), 1); + dim3 grid3(common::DivRoundUp(batch_nrows, block3.x), + common::DivRoundUp(num_cols_, block3.y), 1); UnpackFeaturesK<<>> (fvalues_.data().get(), has_weights_ ? feature_weights_.data().get() : nullptr, row_ptrs_.data().get() + batch_row_begin, @@ -392,7 +393,7 @@ struct GPUSketcher { row_ptrs_.resize(n_rows_ + 1); thrust::copy(offset_vec.data() + row_begin_, offset_vec.data() + row_end_ + 1, row_ptrs_.begin()); - size_t gpu_nbatches = dh::DivRoundUp(n_rows_, gpu_batch_nrows_); + size_t gpu_nbatches = common::DivRoundUp(n_rows_, gpu_batch_nrows_); for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) { SketchBatch(row_batch, info, gpu_batch); } @@ -434,7 +435,7 @@ struct GPUSketcher { /* Builds the sketches on the GPU for the dmatrix and returns the row stride * for the entire dataset */ - size_t Sketch(DMatrix *dmat, HistCutMatrix *hmat) { + size_t Sketch(DMatrix *dmat, DenseCuts *hmat) { const MetaInfo &info = dmat->Info(); row_stride_ = 0; @@ -459,9 +460,13 @@ struct GPUSketcher { size_t DeviceSketch (const tree::TrainParam ¶m, const LearnerTrainParam &learner_param, int gpu_batch_nrows, - DMatrix *dmat, HistCutMatrix *hmat) { + DMatrix *dmat, HistogramCuts *hmat) { GPUSketcher sketcher(param, learner_param, gpu_batch_nrows); - return sketcher.Sketch(dmat, hmat); + // We only need to return the result in HistogramCuts container, so it is safe to + // use a pointer of local HistogramCutsDense + DenseCuts dense_cuts(hmat); + auto res = sketcher.Sketch(dmat, &dense_cuts); + return res; } } // namespace common diff --git a/src/common/hist_util.h b/src/common/hist_util.h index 0cef1878e..d3b0c0a03 100644 --- a/src/common/hist_util.h +++ b/src/common/hist_util.h @@ -12,18 +12,21 @@ #include #include #include +#include #include + #include "row_set.h" #include "../tree/param.h" #include "./quantile.h" #include "./timer.h" -#include "../include/rabit/rabit.h" #include "random.h" namespace xgboost { /*! - * \brief A C-style array with in-stack allocation. As long as the array is smaller than MaxStackSize, it will be allocated inside the stack. Otherwise, it will be heap-allocated. + * \brief A C-style array with in-stack allocation. As long as the array is smaller than + * MaxStackSize, it will be allocated inside the stack. Otherwise, it will be + * heap-allocated. */ template class MemStackAllocator { @@ -122,47 +125,175 @@ struct SimpleArray { size_t n_ = 0; }; -/*! \brief Cut configuration for all the features. */ -struct HistCutMatrix { - /*! \brief Unit pointer to rows by element position */ - std::vector row_ptr; - /*! \brief minimum value of each feature */ - std::vector min_val; - /*! \brief the cut field */ - std::vector cut; - uint32_t GetBinIdx(const Entry &e); +/*! + * \brief A single row in global histogram index. 
+ * Directly represent the global index in the histogram entry. + */ +using GHistIndexRow = Span; - using WXQSketch = common::WXQuantileSketch; - - // create histogram cut matrix given statistics from data - // using approximate quantile sketch approach - void Init(DMatrix* p_fmat, uint32_t max_num_bins); - - void Init(std::vector* sketchs, uint32_t max_num_bins); - - HistCutMatrix(); - size_t NumBins() const { return row_ptr.back(); } +// A CSC matrix representing histogram cuts, used in CPU quantile hist. +class HistogramCuts { + // Using friends to avoid creating a virtual class, since HistogramCuts is used as value + // object in many places. + friend class SparseCuts; + friend class DenseCuts; + friend class CutsBuilder; protected: - virtual size_t SearchGroupIndFromBaseRow( - std::vector const& group_ptr, size_t const base_rowid) const; + using BinIdx = uint32_t; + common::Monitor monitor_; - Monitor monitor_; + std::vector cut_values_; + std::vector cut_ptrs_; + std::vector min_vals_; // storing minimum value in a sketch set. + + public: + HistogramCuts(); + HistogramCuts(HistogramCuts const& that) = delete; + HistogramCuts(HistogramCuts&& that) noexcept(true) { + *this = std::forward(that); + } + HistogramCuts& operator=(HistogramCuts const& that) = delete; + HistogramCuts& operator=(HistogramCuts&& that) noexcept(true) { + monitor_ = std::move(that.monitor_); + cut_ptrs_ = std::move(that.cut_ptrs_); + cut_values_ = std::move(that.cut_values_); + min_vals_ = std::move(that.min_vals_); + return *this; + } + + /* \brief Build histogram cuts. */ + void Build(DMatrix* dmat, uint32_t const max_num_bins); + /* \brief How many bins a feature has. */ + uint32_t FeatureBins(uint32_t feature) const { + return cut_ptrs_.at(feature+1) - cut_ptrs_[feature]; + } + + // Getters. Cuts should be of no use after building histogram indices, but currently + // it's deeply linked with quantile_hist, gpu sketcher and gpu_hist. So we preserve + // these for now. + std::vector const& Ptrs() const { return cut_ptrs_; } + std::vector const& Values() const { return cut_values_; } + std::vector const& MinValues() const { return min_vals_; } + + size_t TotalBins() const { return cut_ptrs_.back(); } + + BinIdx SearchBin(float value, uint32_t column_id) { + auto beg = cut_ptrs_.at(column_id); + auto end = cut_ptrs_.at(column_id + 1); + auto it = std::upper_bound(cut_values_.cbegin() + beg, cut_values_.cbegin() + end, value); + if (it == cut_values_.cend()) { + it = cut_values_.cend() - 1; + } + BinIdx idx = it - cut_values_.cbegin(); + return idx; + } + + BinIdx SearchBin(Entry const& e) { + return SearchBin(e.fvalue, e.index); + } }; +/* \brief An interface for building quantile cuts. + * + * `DenseCuts' always assumes there are `max_bins` for each feature, which makes it not + * suitable for sparse dataset. On the other hand `SparseCuts' uses `GetColumnBatches', + * which doubles the memory usage, hence can not be applied to dense dataset. + */ +class CutsBuilder { + public: + using WXQSketch = common::WXQuantileSketch; + + protected: + HistogramCuts* p_cuts_; + /* \brief return whether group for ranking is used. 
+   */
+  static bool UseGroup(DMatrix* dmat);
+
+ public:
+  explicit CutsBuilder(HistogramCuts* p_cuts) : p_cuts_{p_cuts} {}
+  virtual ~CutsBuilder() = default;
+
+  static uint32_t SearchGroupIndFromRow(
+      std::vector<bst_uint> const& group_ptr, size_t const base_rowid) {
+    using KIt = std::vector<bst_uint>::const_iterator;
+    KIt res = std::lower_bound(group_ptr.cbegin(), group_ptr.cend() - 1, base_rowid);
+    // Cannot use CHECK_NE because it will try to print the iterator.
+    bool const found = res != group_ptr.cend() - 1;
+    if (!found) {
+      LOG(FATAL) << "Row " << base_rowid << " does not lie in any group!";
+    }
+    uint32_t group_ind = std::distance(group_ptr.cbegin(), res);
+    return group_ind;
+  }
+
+  void AddCutPoint(WXQSketch::SummaryContainer const& summary) {
+    if (summary.size > 1 && summary.size <= 16) {
+      /* specialized code for categorical / ordinal data -- use midpoints */
+      for (size_t i = 1; i < summary.size; ++i) {
+        bst_float cpt = (summary.data[i].value + summary.data[i - 1].value) / 2.0f;
+        if (i == 1 || cpt > p_cuts_->cut_values_.back()) {
+          p_cuts_->cut_values_.push_back(cpt);
+        }
+      }
+    } else {
+      for (size_t i = 2; i < summary.size; ++i) {
+        bst_float cpt = summary.data[i - 1].value;
+        if (i == 2 || cpt > p_cuts_->cut_values_.back()) {
+          p_cuts_->cut_values_.push_back(cpt);
+        }
+      }
+    }
+  }
+
+  /* \brief Build quantile cuts. */
+  virtual void Build(DMatrix* dmat, uint32_t const max_num_bins) = 0;
+};
+
+/*! \brief Cut configuration for sparse datasets. */
+class SparseCuts : public CutsBuilder {
+  /* \brief Distribute columns to each thread according to number of entries. */
+  static std::vector<size_t> LoadBalance(SparsePage const& page, size_t const nthreads);
+  Monitor monitor_;
+
+ public:
+  explicit SparseCuts(HistogramCuts* container) :
+      CutsBuilder(container) {
+    monitor_.Init(__FUNCTION__);
+  }
+
+  /* \brief Concatenate the cuts built by each thread. */
+  void Concat(std::vector<std::unique_ptr<SparseCuts>> const& cuts, uint32_t n_cols);
+  /* \brief Build the cuts in a single thread. */
+  void SingleThreadBuild(SparsePage const& page, MetaInfo const& info,
+                         uint32_t max_num_bins,
+                         bool const use_group_ind,
+                         uint32_t beg, uint32_t end, uint32_t thread_id);
+  void Build(DMatrix* dmat, uint32_t const max_num_bins) override;
+};
+
+/*! \brief Cut configuration for dense datasets. */
+class DenseCuts : public CutsBuilder {
+ protected:
+  Monitor monitor_;
+
+ public:
+  explicit DenseCuts(HistogramCuts* container) :
+      CutsBuilder(container) {
+    monitor_.Init(__FUNCTION__);
+  }
+  void Init(std::vector<WXQSketch>* sketchs, uint32_t max_num_bins);
+  void Build(DMatrix* p_fmat, uint32_t max_num_bins) override;
+};
+
+// FIXME(trivialfis): Merge this into generic cut builder.
 /*! \brief Builds the cut matrix on the GPU.
  *
  * \return The row stride across the entire dataset.
  */
 size_t DeviceSketch
   (const tree::TrainParam& param, const LearnerTrainParam &learner_param,
    int gpu_batch_nrows,
-   DMatrix* dmat, HistCutMatrix* hmat);
+   DMatrix* dmat, HistogramCuts* hmat);
 
-/*!
- * \brief A single row in global histogram index.
- *  Directly represent the global index in the histogram entry.
- */
-using GHistIndexRow = Span<const uint32_t>;
 
 /*!
  * \brief preprocessed global index matrix, in CSR format
@@ -178,7 +309,7 @@ struct GHistIndexMatrix {
   /*! \brief hit count of each index */
   std::vector<size_t> hit_count;
   /*!
\brief The corresponding cuts */ - HistCutMatrix cut; + HistogramCuts cut; // Create a global histogram matrix, given cut void Init(DMatrix* p_fmat, int max_num_bins); // get i-th row @@ -188,10 +319,10 @@ struct GHistIndexMatrix { row_ptr[i + 1] - row_ptr[i])}; } inline void GetFeatureCounts(size_t* counts) const { - auto nfeature = cut.row_ptr.size() - 1; + auto nfeature = cut.Ptrs().size() - 1; for (unsigned fid = 0; fid < nfeature; ++fid) { - auto ibegin = cut.row_ptr[fid]; - auto iend = cut.row_ptr[fid + 1]; + auto ibegin = cut.Ptrs()[fid]; + auto iend = cut.Ptrs()[fid + 1]; for (auto i = ibegin; i < iend; ++i) { counts[fid] += hit_count[i]; } @@ -234,7 +365,7 @@ class GHistIndexBlockMatrix { private: std::vector row_ptr_; std::vector index_; - const HistCutMatrix* cut_; + const HistogramCuts* cut_; struct Block { const size_t* row_ptr_begin; const size_t* row_ptr_end; diff --git a/src/common/span.h b/src/common/span.h index f33c2eb89..96c7200ac 100644 --- a/src/common/span.h +++ b/src/common/span.h @@ -549,7 +549,7 @@ class Span { detail::ExtentValue::value> { SPAN_CHECK(Offset >= 0 && (Offset < size() || size() == 0)); SPAN_CHECK(Count == dynamic_extent || - Count >= 0 && Offset + Count <= size()); + (Count >= 0 && Offset + Count <= size())); return {data() + Offset, Count == dynamic_extent ? size() - Offset : Count}; } diff --git a/src/common/transform.h b/src/common/transform.h index 62ef433ef..b1bc55322 100644 --- a/src/common/transform.h +++ b/src/common/transform.h @@ -60,7 +60,7 @@ class Transform { Evaluator(Functor func, Range range, GPUSet devices, bool shard) : func_(func), range_{std::move(range)}, shard_{shard}, - distribution_{std::move(GPUDistribution::Block(devices))} {} + distribution_{GPUDistribution::Block(devices)} {} Evaluator(Functor func, Range range, GPUDistribution dist, bool shard) : func_(func), range_{std::move(range)}, shard_{shard}, @@ -142,7 +142,7 @@ class Transform { Range shard_range {0, static_cast(shard_size)}; dh::safe_cuda(cudaSetDevice(device)); const int GRID_SIZE = - static_cast(dh::DivRoundUp(*(range_.end()), kBlockThreads)); + static_cast(DivRoundUp(*(range_.end()), kBlockThreads)); detail::LaunchCUDAKernel<<>>( _func, shard_range, UnpackHDV(_vectors, device)...); } diff --git a/src/data/sparse_page_source.h b/src/data/sparse_page_source.h index 8a742d32d..45894aeec 100644 --- a/src/data/sparse_page_source.h +++ b/src/data/sparse_page_source.h @@ -52,14 +52,14 @@ class SparsePageSource : public DataSource { * \param page_size Page size for external memory. */ static void CreateRowPage(dmlc::Parser* src, - const std::string& cache_info, - const size_t page_size = DMatrix::kPageSize); + const std::string& cache_info, + const size_t page_size = DMatrix::kPageSize); /*! * \brief Create source cache by copy content from DMatrix. * \param cache_info The cache_info of cache file location. */ static void CreateRowPage(DMatrix* src, - const std::string& cache_info); + const std::string& cache_info); /*! * \brief Create source cache by copy content from DMatrix. Creates transposed column page, may be sorted or not. @@ -67,7 +67,7 @@ class SparsePageSource : public DataSource { * \param sorted Whether columns should be pre-sorted */ static void CreateColumnPage(DMatrix* src, - const std::string& cache_info, bool sorted); + const std::string& cache_info, bool sorted); /*! * \brief Check if the cache file already exists. * \param cache_info The cache prefix of files. 
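The diffs above replace HistCutMatrix's public row_ptr/min_val/cut fields with the HistogramCuts accessors Ptrs()/MinValues()/Values() and SearchBin(). Below is a minimal, self-contained sketch of how that CSC cut layout answers a bin lookup; the ToyCuts struct and the toy values are hypothetical stand-ins rather than part of the patch, but SearchBin mirrors the logic added in hist_util.h.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for common::HistogramCuts, reduced to the two CSC
// arrays and the SearchBin logic from hist_util.h.
struct ToyCuts {
  std::vector<float> cut_values_;    // concatenated upper bounds for all features
  std::vector<uint32_t> cut_ptrs_;   // cut_ptrs_[f]..cut_ptrs_[f+1] delimits feature f
  uint32_t SearchBin(float value, uint32_t column_id) const {
    auto beg = cut_ptrs_.at(column_id);
    auto end = cut_ptrs_.at(column_id + 1);
    auto it = std::upper_bound(cut_values_.cbegin() + beg,
                               cut_values_.cbegin() + end, value);
    if (it == cut_values_.cend()) { it = cut_values_.cend() - 1; }
    return static_cast<uint32_t>(it - cut_values_.cbegin());
  }
};

int main() {
  ToyCuts cuts;
  cuts.cut_values_ = {0.5f, 1.5f, 3.0f};  // feature 0: {0.5, 1.5}; feature 1: {3.0}
  cuts.cut_ptrs_   = {0, 2, 3};           // feature 0 owns bins [0, 2), feature 1 owns [2, 3)
  assert(cuts.SearchBin(1.0f, 0) == 1);   // 1.0 falls in feature 0's second bin
  assert(cuts.SearchBin(2.5f, 1) == 2);   // 2.5 falls in feature 1's only bin
  return 0;
}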
diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index a817fdccb..d67c2963c 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -238,7 +238,7 @@ class GPUPredictor : public xgboost::Predictor { auto& offsets = *out_offsets; size_t n_shards = devices_.Size(); offsets.resize(n_shards + 2); - size_t rows_per_shard = dh::DivRoundUp(batch_size, n_shards); + size_t rows_per_shard = common::DivRoundUp(batch_size, n_shards); for (size_t shard = 0; shard < devices_.Size(); ++shard) { size_t n_rows = std::min(batch_size, shard * rows_per_shard); offsets[shard] = batch_offset + n_rows * n_classes; @@ -284,7 +284,7 @@ class GPUPredictor : public xgboost::Predictor { dh::safe_cuda(cudaSetDevice(device_)); const int BLOCK_THREADS = 128; size_t num_rows = batch.offset.DeviceSize(device_) - 1; - const int GRID_SIZE = static_cast(dh::DivRoundUp(num_rows, BLOCK_THREADS)); + const int GRID_SIZE = static_cast(common::DivRoundUp(num_rows, BLOCK_THREADS)); int shared_memory_bytes = static_cast (sizeof(float) * num_features * BLOCK_THREADS); diff --git a/src/tree/constraints.cu b/src/tree/constraints.cu index 6adeff5ba..2000298bd 100644 --- a/src/tree/constraints.cu +++ b/src/tree/constraints.cu @@ -170,7 +170,7 @@ void FeatureInteractionConstraint::ClearBuffers() { CHECK_LE(feature_buffer_.Size(), output_buffer_bits_.Size()); int constexpr kBlockThreads = 256; const int n_grids = static_cast( - dh::DivRoundUp(input_buffer_bits_.Size(), kBlockThreads)); + common::DivRoundUp(input_buffer_bits_.Size(), kBlockThreads)); ClearBuffersKernel<<>>( output_buffer_bits_, input_buffer_bits_); } @@ -227,7 +227,7 @@ common::Span FeatureInteractionConstraint::Query( int constexpr kBlockThreads = 256; const int n_grids = static_cast( - dh::DivRoundUp(output_buffer_bits_.Size(), kBlockThreads)); + common::DivRoundUp(output_buffer_bits_.Size(), kBlockThreads)); SetInputBufferKernel<<>>(feature_list, input_buffer_bits_); QueryFeatureListKernel<<>>( @@ -328,8 +328,8 @@ void FeatureInteractionConstraint::Split( BitField right = s_node_constraints_[right_id]; dim3 const block3(16, 64, 1); - dim3 const grid3(dh::DivRoundUp(n_sets_, 16), - dh::DivRoundUp(s_fconstraints_.size(), 64)); + dim3 const grid3(common::DivRoundUp(n_sets_, 16), + common::DivRoundUp(s_fconstraints_.size(), 64)); RestoreFeatureListFromSetsKernel<<>> (feature_buffer_, feature_id, @@ -339,7 +339,7 @@ void FeatureInteractionConstraint::Split( s_sets_ptr_); int constexpr kBlockThreads = 256; - const int n_grids = static_cast(dh::DivRoundUp(node.Size(), kBlockThreads)); + const int n_grids = static_cast(common::DivRoundUp(node.Size(), kBlockThreads)); InteractionConstraintSplitKernel<<>> (feature_buffer_, feature_id, diff --git a/src/tree/updater_gpu.cu b/src/tree/updater_gpu.cu index bbb72f8aa..2088bfee7 100644 --- a/src/tree/updater_gpu.cu +++ b/src/tree/updater_gpu.cu @@ -76,7 +76,7 @@ static const int kNoneKey = -100; */ template int ScanTempBufferSize(int size) { - int num_blocks = dh::DivRoundUp(size, BLKDIM_L1L3); + int num_blocks = common::DivRoundUp(size, BLKDIM_L1L3); return num_blocks; } @@ -250,7 +250,7 @@ void ReduceScanByKey(common::Span sums, common::Span tmpScans, common::Span tmpKeys, common::Span colIds, NodeIdT nodeStart) { - int nBlks = dh::DivRoundUp(size, BLKDIM_L1L3); + int nBlks = common::DivRoundUp(size, BLKDIM_L1L3); cudaMemset(sums.data(), 0, nUniqKeys * nCols * sizeof(GradientPair)); CubScanByKeyL1 <<>>(scans, vals, instIds, tmpScans, tmpKeys, keys, @@ -448,7 +448,7 @@ void 
ArgMaxByKey(common::Span nodeSplits, dh::FillConst( *(devices.begin()), nodeSplits.data(), nUniqKeys, ExactSplitCandidate()); - int nBlks = dh::DivRoundUp(len, ITEMS_PER_THREAD * BLKDIM); + int nBlks = common::DivRoundUp(len, ITEMS_PER_THREAD * BLKDIM); switch (algo) { case kAbkGmem: AtomicArgMaxByKeyGmem<<>>( @@ -793,11 +793,11 @@ class GPUMaker : public TreeUpdater { const int BlkDim = 256; const int ItemsPerThread = 4; // assign default node ids first - int nBlks = dh::DivRoundUp(n_rows_, BlkDim); + int nBlks = common::DivRoundUp(n_rows_, BlkDim); FillDefaultNodeIds<<>>(node_assigns_per_inst_.data(), nodes_.data(), n_rows_); // evaluate the correct child indices of non-missing values next - nBlks = dh::DivRoundUp(n_vals_, BlkDim * ItemsPerThread); + nBlks = common::DivRoundUp(n_vals_, BlkDim * ItemsPerThread); AssignNodeIds<<>>( node_assigns_per_inst_.data(), nodeLocations_.Current(), nodeAssigns_.Current(), instIds_.Current(), nodes_.data(), @@ -823,7 +823,7 @@ class GPUMaker : public TreeUpdater { void MarkLeaves() { const int BlkDim = 128; - int nBlks = dh::DivRoundUp(maxNodes_, BlkDim); + int nBlks = common::DivRoundUp(maxNodes_, BlkDim); MarkLeavesKernel<<>>(nodes_.data(), maxNodes_); } }; diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 714e6258c..bd43e40e5 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -480,8 +480,8 @@ __global__ void CompressBinEllpackKernel( common::CompressedByteT* __restrict__ buffer, // gidx_buffer const size_t* __restrict__ row_ptrs, // row offset of input data const Entry* __restrict__ entries, // One batch of input data - const float* __restrict__ cuts, // HistCutMatrix::cut - const uint32_t* __restrict__ cut_rows, // HistCutMatrix::row_ptrs + const float* __restrict__ cuts, // HistogramCuts::cut + const uint32_t* __restrict__ cut_rows, // HistogramCuts::row_ptrs size_t base_row, // batch_row_begin size_t n_rows, size_t row_stride, @@ -593,7 +593,7 @@ struct DeviceShard { std::unique_ptr row_partitioner; DeviceHistogram hist; - /*! \brief row_ptr form HistCutMatrix. */ + /*! \brief row_ptr form HistogramCuts. */ common::Span feature_segments; /*! \brief minimum value for each feature. 
*/ common::Span min_fvalue; @@ -654,10 +654,10 @@ struct DeviceShard { } void InitCompressedData( - const common::HistCutMatrix& hmat, size_t row_stride, bool is_dense); + const common::HistogramCuts& hmat, size_t row_stride, bool is_dense); void CreateHistIndices( - const SparsePage &row_batch, const common::HistCutMatrix &hmat, + const SparsePage &row_batch, const common::HistogramCuts &hmat, const RowStateOnDevice &device_row_state, int rows_per_batch); ~DeviceShard() { @@ -718,7 +718,7 @@ struct DeviceShard { // Work out cub temporary memory requirement GPUTrainingParam gpu_param(param); DeviceSplitCandidateReduceOp op(gpu_param); - size_t temp_storage_bytes; + size_t temp_storage_bytes = 0; DeviceSplitCandidate*dummy = nullptr; cub::DeviceReduce::Reduce( nullptr, temp_storage_bytes, dummy, @@ -806,7 +806,7 @@ struct DeviceShard { const int items_per_thread = 8; const int block_threads = 256; const int grid_size = static_cast( - dh::DivRoundUp(n_elements, items_per_thread * block_threads)); + common::DivRoundUp(n_elements, items_per_thread * block_threads)); if (grid_size <= 0) { return; } @@ -1106,9 +1106,9 @@ struct DeviceShard { template inline void DeviceShard::InitCompressedData( - const common::HistCutMatrix &hmat, size_t row_stride, bool is_dense) { - n_bins = hmat.row_ptr.back(); - int null_gidx_value = hmat.row_ptr.back(); + const common::HistogramCuts &hmat, size_t row_stride, bool is_dense) { + n_bins = hmat.Ptrs().back(); + int null_gidx_value = hmat.Ptrs().back(); CHECK(!(param.max_leaves == 0 && param.max_depth == 0)) << "Max leaves and max depth cannot both be unconstrained for " @@ -1121,14 +1121,14 @@ inline void DeviceShard::InitCompressedData( &gpair, n_rows, &prediction_cache, n_rows, &node_sum_gradients_d, max_nodes, - &feature_segments, hmat.row_ptr.size(), - &gidx_fvalue_map, hmat.cut.size(), - &min_fvalue, hmat.min_val.size(), + &feature_segments, hmat.Ptrs().size(), + &gidx_fvalue_map, hmat.Values().size(), + &min_fvalue, hmat.MinValues().size(), &monotone_constraints, param.monotone_constraints.size()); - dh::CopyVectorToDeviceSpan(gidx_fvalue_map, hmat.cut); - dh::CopyVectorToDeviceSpan(min_fvalue, hmat.min_val); - dh::CopyVectorToDeviceSpan(feature_segments, hmat.row_ptr); + dh::CopyVectorToDeviceSpan(gidx_fvalue_map, hmat.Values()); + dh::CopyVectorToDeviceSpan(min_fvalue, hmat.MinValues()); + dh::CopyVectorToDeviceSpan(feature_segments, hmat.Ptrs()); dh::CopyVectorToDeviceSpan(monotone_constraints, param.monotone_constraints); node_sum_gradients.resize(max_nodes); @@ -1153,26 +1153,26 @@ inline void DeviceShard::InitCompressedData( // check if we can use shared memory for building histograms // (assuming atleast we need 2 CTAs per SM to maintain decent latency // hiding) - auto histogram_size = sizeof(GradientSumT) * hmat.row_ptr.back(); + auto histogram_size = sizeof(GradientSumT) * hmat.Ptrs().back(); auto max_smem = dh::MaxSharedMemory(device_id); if (histogram_size <= max_smem) { use_shared_memory_histograms = true; } // Init histogram - hist.Init(device_id, hmat.NumBins()); + hist.Init(device_id, hmat.Ptrs().back()); } template inline void DeviceShard::CreateHistIndices( const SparsePage &row_batch, - const common::HistCutMatrix &hmat, + const common::HistogramCuts &hmat, const RowStateOnDevice &device_row_state, int rows_per_batch) { // Has any been allocated for me in this batch? 
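  // (The guard just below skips shards that received no rows from this batch.)
  // The remaining rows are processed in sub-batches of gpu_batch_nrows and
  // compressed into the ELLPACK gidx buffer; hmat.Ptrs().back() -- the total
  // number of bins -- is used as null_gidx_value, which appears to serve as
  // the out-of-range sentinel for missing entries so that every row occupies
  // exactly row_stride slots.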
if (!device_row_state.rows_to_process_from_batch) return; - unsigned int null_gidx_value = hmat.row_ptr.back(); + unsigned int null_gidx_value = hmat.Ptrs().back(); size_t row_stride = this->ellpack_matrix.row_stride; const auto &offset_vec = row_batch.offset.ConstHostVector(); @@ -1184,8 +1184,8 @@ inline void DeviceShard::CreateHistIndices( static_cast(device_row_state.rows_to_process_from_batch)); const std::vector& data_vec = row_batch.data.ConstHostVector(); - size_t gpu_nbatches = dh::DivRoundUp(device_row_state.rows_to_process_from_batch, - gpu_batch_nrows); + size_t gpu_nbatches = common::DivRoundUp(device_row_state.rows_to_process_from_batch, + gpu_batch_nrows); for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) { size_t batch_row_begin = gpu_batch * gpu_batch_nrows; @@ -1216,8 +1216,8 @@ inline void DeviceShard::CreateHistIndices( (entries_d.data().get(), data_vec.data() + ent_cnt_begin, n_entries * sizeof(Entry), cudaMemcpyDefault)); const dim3 block3(32, 8, 1); // 256 threads - const dim3 grid3(dh::DivRoundUp(batch_nrows, block3.x), - dh::DivRoundUp(row_stride, block3.y), 1); + const dim3 grid3(common::DivRoundUp(batch_nrows, block3.x), + common::DivRoundUp(row_stride, block3.y), 1); CompressBinEllpackKernel<<>> (common::CompressedBufferWriter(num_symbols), gidx_buffer.data(), @@ -1361,13 +1361,13 @@ class GPUHistMakerSpecialised { }); monitor_.StartCuda("Quantiles"); - // Create the quantile sketches for the dmatrix and initialize HistCutMatrix + // Create the quantile sketches for the dmatrix and initialize HistogramCuts size_t row_stride = common::DeviceSketch(param_, *learner_param_, hist_maker_param_.gpu_batch_nrows, dmat, &hmat_); monitor_.StopCuda("Quantiles"); - n_bins_ = hmat_.row_ptr.back(); + n_bins_ = hmat_.Ptrs().back(); auto is_dense = info_->num_nonzero_ == info_->num_row_ * info_->num_col_; @@ -1475,9 +1475,9 @@ class GPUHistMakerSpecialised { return true; } - TrainParam param_; // NOLINT - common::HistCutMatrix hmat_; // NOLINT - MetaInfo* info_; // NOLINT + TrainParam param_; // NOLINT + common::HistogramCuts hmat_; // NOLINT + MetaInfo* info_; // NOLINT std::vector>> shards_; // NOLINT diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 9899ea61d..f67cf2d39 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -247,15 +247,15 @@ int32_t QuantileHistMaker::Builder::FindSplitCond(int32_t nid, // Categorize member rows const bst_uint fid = node.SplitIndex(); const bst_float split_pt = node.SplitCond(); - const uint32_t lower_bound = gmat.cut.row_ptr[fid]; - const uint32_t upper_bound = gmat.cut.row_ptr[fid + 1]; + const uint32_t lower_bound = gmat.cut.Ptrs()[fid]; + const uint32_t upper_bound = gmat.cut.Ptrs()[fid + 1]; int32_t split_cond = -1; // convert floating-point split_pt into corresponding bin_id // split_cond = -1 indicates that split_pt is less than all known cut points CHECK_LT(upper_bound, static_cast(std::numeric_limits::max())); for (uint32_t i = lower_bound; i < upper_bound; ++i) { - if (split_pt == gmat.cut.cut[i]) { + if (split_pt == gmat.cut.Values()[i]) { split_cond = static_cast(i); } } @@ -533,7 +533,7 @@ void QuantileHistMaker::Builder::BuildHistsBatch(const std::vector& perf_monitor.TickStart(); const size_t block_size_rows = 256; const size_t nthread = static_cast(this->nthread_); - const size_t nbins = gmat.cut.row_ptr.back(); + const size_t nbins = gmat.cut.Ptrs().back(); const size_t hist_size = 2 * nbins; hist_buffers->resize(nodes.size()); @@ 
-856,8 +856,8 @@ bool QuantileHistMaker::Builder::UpdatePredictionCache( } } - #pragma omp parallel for schedule(guided) - for (int32_t k = 0; k < tasks_elem.size(); ++k) { +#pragma omp parallel for schedule(guided) + for (omp_ulong k = 0; k < tasks_elem.size(); ++k) { const RowSetCollection::Elem rowset = tasks_elem[k]; if (rowset.begin != nullptr && rowset.end != nullptr && rowset.node_id != -1) { const size_t nrows = rowset.Size(); @@ -909,7 +909,7 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat, // clear local prediction cache leaf_value_cache_.clear(); // initialize histogram collection - uint32_t nbins = gmat.cut.row_ptr.back(); + uint32_t nbins = gmat.cut.Ptrs().back(); hist_.Init(nbins); hist_buff_.Init(nbins); @@ -999,7 +999,7 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat, const size_t ncol = info.num_col_; const size_t nnz = info.num_nonzero_; // number of discrete bins for feature 0 - const uint32_t nbins_f0 = gmat.cut.row_ptr[1] - gmat.cut.row_ptr[0]; + const uint32_t nbins_f0 = gmat.cut.Ptrs()[1] - gmat.cut.Ptrs()[0]; if (nrow * ncol == nnz) { // dense data with zero-based indexing data_layout_ = kDenseDataZeroBased; @@ -1029,7 +1029,7 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat, choose the column that has a least positive number of discrete bins. For dense data (with no missing value), the sum of gradient histogram is equal to snode[nid] */ - const std::vector& row_ptr = gmat.cut.row_ptr; + const std::vector& row_ptr = gmat.cut.Ptrs(); const auto nfeature = static_cast(row_ptr.size() - 1); uint32_t min_nbins_per_feature = 0; for (bst_uint i = 0; i < nfeature; ++i) { @@ -1079,8 +1079,8 @@ void QuantileHistMaker::Builder::EvaluateSplitsBatch( // partial results std::vector> splits(tasks.size()); // parallel enumeration - #pragma omp parallel for schedule(guided) - for (int32_t i = 0; i < tasks.size(); ++i) { +#pragma omp parallel for schedule(guided) + for (omp_ulong i = 0; i < tasks.size(); ++i) { // node_idx : offset within `nodes` list const int32_t node_idx = tasks[i].first; const size_t fid = tasks[i].second; @@ -1098,7 +1098,7 @@ void QuantileHistMaker::Builder::EvaluateSplitsBatch( // reduce needed part of a hist here to have it in cache before enumeration if (!rabit::IsDistributed()) { - const std::vector& cut_ptr = gmat.cut.row_ptr; + const std::vector& cut_ptr = gmat.cut.Ptrs(); const size_t ibegin = 2 * cut_ptr[fid]; const size_t iend = 2 * cut_ptr[fid + 1]; ReduceHistograms(hist_data, sibling_hist_data, parent_hist_data, ibegin, iend, node_idx, @@ -1179,8 +1179,8 @@ bool QuantileHistMaker::Builder::EnumerateSplit(int d_step, CHECK(d_step == +1 || d_step == -1); // aliases - const std::vector& cut_ptr = gmat.cut.row_ptr; - const std::vector& cut_val = gmat.cut.cut; + const std::vector& cut_ptr = gmat.cut.Ptrs(); + const std::vector& cut_val = gmat.cut.Values(); // statistics on both sides of split GradStats c; @@ -1239,7 +1239,7 @@ bool QuantileHistMaker::Builder::EnumerateSplit(int d_step, if (i == imin) { // for leftmost bin, left bound is the smallest feature value - split_pt = gmat.cut.min_val[fid]; + split_pt = gmat.cut.MinValues()[fid]; } else { split_pt = cut_val[i - 1]; } diff --git a/src/tree/updater_quantile_hist.h b/src/tree/updater_quantile_hist.h index 592257321..ff814657d 100644 --- a/src/tree/updater_quantile_hist.h +++ b/src/tree/updater_quantile_hist.h @@ -33,7 +33,6 @@ namespace common { } namespace tree { -using xgboost::common::HistCutMatrix; using 
xgboost::common::GHistIndexMatrix;
using xgboost::common::GHistIndexBlockMatrix;
using xgboost::common::GHistIndexRow;
diff --git a/tests/cpp/c_api/test_c_api.cc b/tests/cpp/c_api/test_c_api.cc
index 5942dc030..629518b01 100644
--- a/tests/cpp/c_api/test_c_api.cc
+++ b/tests/cpp/c_api/test_c_api.cc
@@ -53,10 +53,10 @@ TEST(c_api, XGDMatrixCreateFromMat_omp) {
   ASSERT_EQ(info.num_nonzero_, num_cols * row - num_missing);
   for (const auto &batch : (*dmat)->GetRowBatches()) {
-    for (int i = 0; i < batch.Size(); i++) {
+    for (size_t i = 0; i < batch.Size(); i++) {
       auto inst = batch[i];
-      for (int j = 0; i < inst.size(); i++) {
-        ASSERT_EQ(inst[j].fvalue, 1.5);
+      for (auto e : inst) {
+        ASSERT_EQ(e.fvalue, 1.5);
       }
     }
   }
diff --git a/tests/cpp/common/test_column_matrix.cc b/tests/cpp/common/test_column_matrix.cc
index 25f2688b2..8b6f69a6a 100644
--- a/tests/cpp/common/test_column_matrix.cc
+++ b/tests/cpp/common/test_column_matrix.cc
@@ -7,6 +7,7 @@
 namespace xgboost {
 namespace common {
+
 TEST(DenseColumn, Test) {
   auto dmat = CreateDMatrix(100, 10, 0.0);
   GHistIndexMatrix gmat;
@@ -17,7 +18,7 @@ TEST(DenseColumn, Test) {
   for (auto i = 0ull; i < (*dmat)->Info().num_row_; i++) {
     for (auto j = 0ull; j < (*dmat)->Info().num_col_; j++) {
       auto col = column_matrix.GetColumn(j);
-      EXPECT_EQ(gmat.index[i * (*dmat)->Info().num_col_ + j],
+      ASSERT_EQ(gmat.index[i * (*dmat)->Info().num_col_ + j],
                 col.GetGlobalBinIdx(i));
     }
   }
@@ -33,7 +34,7 @@ TEST(SparseColumn, Test) {
   auto col = column_matrix.GetColumn(0);
   ASSERT_EQ(col.Size(), gmat.index.size());
   for (auto i = 0ull; i < col.Size(); i++) {
-    EXPECT_EQ(gmat.index[gmat.row_ptr[col.GetRowIdx(i)]],
+    ASSERT_EQ(gmat.index[gmat.row_ptr[col.GetRowIdx(i)]],
              col.GetGlobalBinIdx(i));
   }
   delete dmat;
diff --git a/tests/cpp/common/test_compressed_iterator.cc b/tests/cpp/common/test_compressed_iterator.cc
index fe52a2be2..93243c0b3 100644
--- a/tests/cpp/common/test_compressed_iterator.cc
+++ b/tests/cpp/common/test_compressed_iterator.cc
@@ -28,7 +28,7 @@ TEST(CompressedIterator, Test) {
     CompressedIterator ci(buffer.data(), alphabet_size);
     std::vector output(input.size());
-    for (int i = 0; i < input.size(); i++) {
+    for (size_t i = 0; i < input.size(); i++) {
       output[i] = ci[i];
     }
@@ -38,12 +38,12 @@ TEST(CompressedIterator, Test) {
     std::vector buffer2(
        CompressedBufferWriter::CalculateBufferSize(input.size(), alphabet_size));
-    for (int i = 0; i < input.size(); i++) {
+    for (size_t i = 0; i < input.size(); i++) {
       cbw.WriteSymbol(buffer2.data(), input[i], i);
     }
     CompressedIterator ci2(buffer.data(), alphabet_size);
     std::vector output2(input.size());
-    for (int i = 0; i < input.size(); i++) {
+    for (size_t i = 0; i < input.size(); i++) {
       output2[i] = ci2[i];
     }
     ASSERT_TRUE(input == output2);
diff --git a/tests/cpp/common/test_gpu_hist_util.cu b/tests/cpp/common/test_gpu_hist_util.cu
index 4ade9ab80..a63cc08e0 100644
--- a/tests/cpp/common/test_gpu_hist_util.cu
+++ b/tests/cpp/common/test_gpu_hist_util.cu
@@ -48,11 +48,11 @@ void TestDeviceSketch(const GPUSet& devices, bool use_external_memory) {
   int gpu_batch_nrows = 0;
   // find quantiles on the CPU
-  HistCutMatrix hmat_cpu;
-  hmat_cpu.Init((*dmat).get(), p.max_bin);
+  HistogramCuts hmat_cpu;
+  hmat_cpu.Build((*dmat).get(), p.max_bin);
   // find the cuts on the GPU
-  HistCutMatrix hmat_gpu;
+  HistogramCuts hmat_gpu;
   size_t row_stride = DeviceSketch(p, CreateEmptyGenericParam(0, devices.Size()),
                                    gpu_batch_nrows, dmat->get(), &hmat_gpu);
@@ -69,12 +69,12 @@ void TestDeviceSketch(const GPUSet& devices, bool use_external_memory) {
   // compare the cuts
   double eps = 1e-2;
-  ASSERT_EQ(hmat_gpu.min_val.size(), num_cols);
-  ASSERT_EQ(hmat_gpu.row_ptr.size(), num_cols + 1);
-  ASSERT_EQ(hmat_gpu.cut.size(), hmat_cpu.cut.size());
-  ASSERT_LT(fabs(hmat_cpu.min_val[0] - hmat_gpu.min_val[0]), eps * nrows);
-  for (int i = 0; i < hmat_gpu.cut.size(); ++i) {
-    ASSERT_LT(fabs(hmat_cpu.cut[i] - hmat_gpu.cut[i]), eps * nrows);
+  ASSERT_EQ(hmat_gpu.MinValues().size(), num_cols);
+  ASSERT_EQ(hmat_gpu.Ptrs().size(), num_cols + 1);
+  ASSERT_EQ(hmat_gpu.Values().size(), hmat_cpu.Values().size());
+  ASSERT_LT(fabs(hmat_cpu.MinValues()[0] - hmat_gpu.MinValues()[0]), eps * nrows);
+  for (int i = 0; i < hmat_gpu.Values().size(); ++i) {
+    ASSERT_LT(fabs(hmat_cpu.Values()[i] - hmat_gpu.Values()[i]), eps * nrows);
   }
   delete dmat;
diff --git a/tests/cpp/common/test_hist_util.cc b/tests/cpp/common/test_hist_util.cc
index d959f486d..842c333f3 100644
--- a/tests/cpp/common/test_hist_util.cc
+++ b/tests/cpp/common/test_hist_util.cc
@@ -9,15 +9,7 @@
 namespace xgboost {
 namespace common {
-class HistCutMatrixMock : public HistCutMatrix {
- public:
-  size_t SearchGroupIndFromBaseRow(
-      std::vector const& group_ptr, size_t const base_rowid) {
-    return HistCutMatrix::SearchGroupIndFromBaseRow(group_ptr, base_rowid);
-  }
-};
-
-TEST(HistCutMatrix, SearchGroupInd) {
+TEST(CutsBuilder, SearchGroupInd) {
   size_t constexpr kNumGroups = 4;
   size_t constexpr kNumRows = 17;
   size_t constexpr kNumCols = 15;
@@ -34,18 +26,102 @@ TEST(HistCutMatrix, SearchGroupInd) {
   p_mat->Info().SetInfo(
       "group", group.data(), DataType::kUInt32, kNumGroups);
-  HistCutMatrixMock hmat;
+  HistogramCuts hmat;
-  size_t group_ind = hmat.SearchGroupIndFromBaseRow(p_mat->Info().group_ptr_, 0);
+  size_t group_ind = CutsBuilder::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 0);
   ASSERT_EQ(group_ind, 0);
-  group_ind = hmat.SearchGroupIndFromBaseRow(p_mat->Info().group_ptr_, 5);
+  group_ind = CutsBuilder::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 5);
   ASSERT_EQ(group_ind, 2);
-  EXPECT_ANY_THROW(hmat.SearchGroupIndFromBaseRow(p_mat->Info().group_ptr_, 17));
+  EXPECT_ANY_THROW(CutsBuilder::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 17));
   delete pp_mat;
 }
+namespace {
+class SparseCutsWrapper : public SparseCuts {
+ public:
+  std::vector const& ColPtrs() const { return p_cuts_->Ptrs(); }
+  std::vector const& ColValues() const { return p_cuts_->Values(); }
+};
+}  // anonymous namespace
+
+TEST(SparseCuts, SingleThreadedBuild) {
+  size_t constexpr kRows = 267;
+  size_t constexpr kCols = 31;
+  size_t constexpr kBins = 256;
+
+  // Dense matrix.
+  auto pp_mat = CreateDMatrix(kRows, kCols, 0);
+  DMatrix* p_fmat = (*pp_mat).get();
+
+  common::GHistIndexMatrix hmat;
+  hmat.Init(p_fmat, kBins);
+
+  HistogramCuts cuts;
+  SparseCuts indices(&cuts);
+  auto const& page = *(p_fmat->GetColumnBatches().begin());
+  indices.SingleThreadBuild(page, p_fmat->Info(), kBins, false, 0, page.Size(), 0);
+
+  ASSERT_EQ(hmat.cut.Ptrs().size(), cuts.Ptrs().size());
+  ASSERT_EQ(hmat.cut.Ptrs(), cuts.Ptrs());
+  ASSERT_EQ(hmat.cut.Values(), cuts.Values());
+  ASSERT_EQ(hmat.cut.MinValues(), cuts.MinValues());
+
+  delete pp_mat;
+}
+
+TEST(SparseCuts, MultiThreadedBuild) {
+  size_t constexpr kRows = 17;
+  size_t constexpr kCols = 15;
+  size_t constexpr kBins = 255;
+
+  omp_ulong ori_nthreads = omp_get_max_threads();
+  omp_set_num_threads(16);
+
+  auto Compare =
+#if defined(_MSC_VER)  // msvc fails to capture
+      [kBins](DMatrix* p_fmat) {
+#else
+      [](DMatrix* p_fmat) {
+#endif
+        HistogramCuts threaded_container;
+        SparseCuts threaded_indices(&threaded_container);
+        threaded_indices.Build(p_fmat, kBins);
+
+        HistogramCuts container;
+        SparseCuts indices(&container);
+        auto const& page = *(p_fmat->GetColumnBatches().begin());
+        indices.SingleThreadBuild(page, p_fmat->Info(), kBins, false, 0, page.Size(), 0);
+
+        ASSERT_EQ(container.Ptrs().size(), threaded_container.Ptrs().size());
+        ASSERT_EQ(container.Values().size(), threaded_container.Values().size());
+
+        for (uint32_t i = 0; i < container.Ptrs().size(); ++i) {
+          ASSERT_EQ(container.Ptrs()[i], threaded_container.Ptrs()[i]);
+        }
+        for (uint32_t i = 0; i < container.Values().size(); ++i) {
+          ASSERT_EQ(container.Values()[i], threaded_container.Values()[i]);
+        }
+      };
+
+  {
+    auto pp_mat = CreateDMatrix(kRows, kCols, 0);
+    DMatrix* p_fmat = (*pp_mat).get();
+    Compare(p_fmat);
+    delete pp_mat;
+  }
+
+  {
+    auto pp_mat = CreateDMatrix(kRows, kCols, 0.0001);
+    DMatrix* p_fmat = (*pp_mat).get();
+    Compare(p_fmat);
+    delete pp_mat;
+  }
+
+  omp_set_num_threads(ori_nthreads);
+}
+
 }  // namespace common
 }  // namespace xgboost
diff --git a/tests/cpp/common/test_random.cc b/tests/cpp/common/test_random.cc
index e408f11bc..128b0fd8c 100644
--- a/tests/cpp/common/test_random.cc
+++ b/tests/cpp/common/test_random.cc
@@ -53,8 +53,8 @@ TEST(ColumnSampler, Test) {
 TEST(ColumnSampler, ThreadSynchronisation) {
   const int64_t num_threads = 100;
   int n = 128;
-  int iterations = 10;
-  int levels = 5;
+  size_t iterations = 10;
+  size_t levels = 5;
   std::vector reference_result;
   bool success = true;  // Cannot use google test asserts in multithreaded region
diff --git a/tests/cpp/common/test_span.cc b/tests/cpp/common/test_span.cc
index f29ce2af6..d91bdb9b5 100644
--- a/tests/cpp/common/test_span.cc
+++ b/tests/cpp/common/test_span.cc
@@ -310,7 +310,7 @@ TEST(Span, FirstLast) {
   ASSERT_EQ(first.size(), 4);
   ASSERT_EQ(first.data(), arr);
-  for (size_t i = 0; i < first.size(); ++i) {
+  for (int64_t i = 0; i < first.size(); ++i) {
     ASSERT_EQ(first[i], arr[i]);
   }
@@ -329,7 +329,7 @@ TEST(Span, FirstLast) {
   ASSERT_EQ(last.size(), 4);
   ASSERT_EQ(last.data(), arr + 12);
-  for (size_t i = 0; i < last.size(); ++i) {
+  for (int64_t i = 0; i < last.size(); ++i) {
     ASSERT_EQ(last[i], arr[i+12]);
   }
@@ -348,7 +348,7 @@ TEST(Span, FirstLast) {
   ASSERT_EQ(first.size(), 4);
   ASSERT_EQ(first.data(), s.data());
-  for (size_t i = 0; i < first.size(); ++i) {
+  for (int64_t i = 0; i < first.size(); ++i) {
     ASSERT_EQ(first[i], s[i]);
   }
@@ -368,7 +368,7 @@ TEST(Span, FirstLast) {
   ASSERT_EQ(last.size(), 4);
   ASSERT_EQ(last.data(), s.data() + 12);
-  for (size_t i = 0; i < last.size(); ++i) {
+  for (int64_t i = 0; i < last.size(); ++i) {
     ASSERT_EQ(s[12 + i], last[i]);
   }
diff --git a/tests/cpp/data/test_data.cc b/tests/cpp/data/test_data.cc
index 622351bbc..22b558efd 100644
--- a/tests/cpp/data/test_data.cc
+++ b/tests/cpp/data/test_data.cc
@@ -50,7 +50,7 @@ TEST(SparsePage, PushCSC) {
   inst = page[1];
   ASSERT_EQ(inst.size(), 6);
   std::vector indices_sol {1, 2, 3};
-  for (size_t i = 0; i < inst.size(); ++i) {
+  for (int64_t i = 0; i < inst.size(); ++i) {
     ASSERT_EQ(inst[i].index, indices_sol[i % 3]);
   }
 }
diff --git a/tests/cpp/predictor/test_cpu_predictor.cc b/tests/cpp/predictor/test_cpu_predictor.cc
index 0db5f66d0..3c9142526 100644
--- a/tests/cpp/predictor/test_cpu_predictor.cc
+++ b/tests/cpp/predictor/test_cpu_predictor.cc
@@ -21,13 +21,13 @@ TEST(cpu_predictor, Test) {
   HostDeviceVector out_predictions;
   cpu_predictor->PredictBatch((*dmat).get(), &out_predictions, model, 0);
   std::vector& out_predictions_h = out_predictions.HostVector();
-  for (int i = 0; i < out_predictions.Size(); i++) {
+  for (size_t i = 0; i < out_predictions.Size(); i++) {
    ASSERT_EQ(out_predictions_h[i], 1.5);
   }
   // Test predict instance
   auto &batch = *(*dmat)->GetRowBatches().begin();
-  for (int i = 0; i < batch.Size(); i++) {
+  for (size_t i = 0; i < batch.Size(); i++) {
     std::vector instance_out_predictions;
     cpu_predictor->PredictInstance(batch[i], &instance_out_predictions, model);
     ASSERT_EQ(instance_out_predictions[0], 1.5);
diff --git a/tests/cpp/tree/gpu_hist/test_row_partitioner.cu b/tests/cpp/tree/gpu_hist/test_row_partitioner.cu
index 1106d8486..172fd899b 100644
--- a/tests/cpp/tree/gpu_hist/test_row_partitioner.cu
+++ b/tests/cpp/tree/gpu_hist/test_row_partitioner.cu
@@ -94,7 +94,7 @@ void TestUpdatePosition() {
 }
 TEST(RowPartitioner, Basic) { TestUpdatePosition(); }
-
+
 void TestFinalise() {
   const int kNumRows = 10;
   RowPartitioner rp(0, kNumRows);
diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu
index 39dd91165..2c108e38c 100644
--- a/tests/cpp/tree/test_gpu_hist.cu
+++ b/tests/cpp/tree/test_gpu_hist.cu
@@ -53,27 +53,43 @@ TEST(GpuHist, DeviceHistogram) {
     }
   }
 };
-
 }
+namespace {
+class HistogramCutsWrapper : public common::HistogramCuts {
+ public:
+  using SuperT = common::HistogramCuts;
+  void SetValues(std::vector cuts) {
+    SuperT::cut_values_ = cuts;
+  }
+  void SetPtrs(std::vector ptrs) {
+    SuperT::cut_ptrs_ = ptrs;
+  }
+  void SetMins(std::vector mins) {
+    SuperT::min_vals_ = mins;
+  }
+};
+}  // anonymous namespace
+
+
 template
 void BuildGidx(DeviceShard* shard, int n_rows, int n_cols,
                bst_float sparsity=0) {
   auto dmat = CreateDMatrix(n_rows, n_cols, sparsity, 3);
   const SparsePage& batch = *(*dmat)->GetRowBatches().begin();
-  common::HistCutMatrix cmat;
-  cmat.row_ptr = {0, 3, 6, 9, 12, 15, 18, 21, 24};
-  cmat.min_val = {0.1f, 0.2f, 0.3f, 0.1f, 0.2f, 0.3f, 0.2f, 0.2f};
+  HistogramCutsWrapper cmat;
+  cmat.SetPtrs({0, 3, 6, 9, 12, 15, 18, 21, 24});
   // 24 cut fields, 3 cut fields for each feature (column).
-  cmat.cut = {0.30f, 0.67f, 1.64f,
-              0.32f, 0.77f, 1.95f,
-              0.29f, 0.70f, 1.80f,
-              0.32f, 0.75f, 1.85f,
-              0.18f, 0.59f, 1.69f,
-              0.25f, 0.74f, 2.00f,
-              0.26f, 0.74f, 1.98f,
-              0.26f, 0.71f, 1.83f};
+  cmat.SetValues({0.30f, 0.67f, 1.64f,
+                  0.32f, 0.77f, 1.95f,
+                  0.29f, 0.70f, 1.80f,
+                  0.32f, 0.75f, 1.85f,
+                  0.18f, 0.59f, 1.69f,
+                  0.25f, 0.74f, 2.00f,
+                  0.26f, 0.74f, 1.98f,
+                  0.26f, 0.71f, 1.83f});
+  cmat.SetMins({0.1f, 0.2f, 0.3f, 0.1f, 0.2f, 0.3f, 0.2f, 0.2f});
   auto is_dense = (*dmat)->Info().num_nonzero_ ==
                   (*dmat)->Info().num_row_ * (*dmat)->Info().num_col_;
@@ -241,20 +257,20 @@ TEST(GpuHist, BuildHistSharedMem) {
   TestBuildHist(true);
 }
-common::HistCutMatrix GetHostCutMatrix () {
-  common::HistCutMatrix cmat;
-  cmat.row_ptr = {0, 3, 6, 9, 12, 15, 18, 21, 24};
-  cmat.min_val = {0.1f, 0.2f, 0.3f, 0.1f, 0.2f, 0.3f, 0.2f, 0.2f};
+HistogramCutsWrapper GetHostCutMatrix () {
+  HistogramCutsWrapper cmat;
+  cmat.SetPtrs({0, 3, 6, 9, 12, 15, 18, 21, 24});
+  cmat.SetMins({0.1f, 0.2f, 0.3f, 0.1f, 0.2f, 0.3f, 0.2f, 0.2f});
   // 24 cut fields, 3 cut fields for each feature (column).
   // Each row of the cut represents the cuts for a data column.
-  cmat.cut = {0.30f, 0.67f, 1.64f,
+  cmat.SetValues({0.30f, 0.67f, 1.64f,
               0.32f, 0.77f, 1.95f,
               0.29f, 0.70f, 1.80f,
               0.32f, 0.75f, 1.85f,
               0.18f, 0.59f, 1.69f,
               0.25f, 0.74f, 2.00f,
               0.26f, 0.74f, 1.98f,
-              0.26f, 0.71f, 1.83f};
+              0.26f, 0.71f, 1.83f});
   return cmat;
 }
@@ -293,21 +309,21 @@ TEST(GpuHist, EvaluateSplits) {
   shard->node_sum_gradients = {{6.4f, 12.8f}};
   // Initialize DeviceShard::cut
-  common::HistCutMatrix cmat = GetHostCutMatrix();
+  auto cmat = GetHostCutMatrix();
   // Copy cut matrix to device.
   shard->ba.Allocate(0,
-                     &(shard->feature_segments), cmat.row_ptr.size(),
-                     &(shard->min_fvalue), cmat.min_val.size(),
+                     &(shard->feature_segments), cmat.Ptrs().size(),
+                     &(shard->min_fvalue), cmat.MinValues().size(),
                      &(shard->gidx_fvalue_map), 24,
                      &(shard->monotone_constraints), kNCols);
-  dh::CopyVectorToDeviceSpan(shard->feature_segments, cmat.row_ptr);
-  dh::CopyVectorToDeviceSpan(shard->gidx_fvalue_map, cmat.cut);
+  dh::CopyVectorToDeviceSpan(shard->feature_segments, cmat.Ptrs());
+  dh::CopyVectorToDeviceSpan(shard->gidx_fvalue_map, cmat.Values());
   dh::CopyVectorToDeviceSpan(shard->monotone_constraints, param.monotone_constraints);
   shard->ellpack_matrix.feature_segments = shard->feature_segments;
   shard->ellpack_matrix.gidx_fvalue_map = shard->gidx_fvalue_map;
-  dh::CopyVectorToDeviceSpan(shard->min_fvalue, cmat.min_val);
+  dh::CopyVectorToDeviceSpan(shard->min_fvalue, cmat.MinValues());
   shard->ellpack_matrix.min_fvalue = shard->min_fvalue;
   // Initialize DeviceShard::hist
diff --git a/tests/cpp/tree/test_prune.cc b/tests/cpp/tree/test_prune.cc
index 3436f1e7a..8a40fda38 100644
--- a/tests/cpp/tree/test_prune.cc
+++ b/tests/cpp/tree/test_prune.cc
@@ -13,7 +13,7 @@ namespace xgboost {
 namespace tree {
 TEST(Updater, Prune) {
-  int constexpr kNRows = 32, kNCols = 16;
+  int constexpr kNCols = 16;
   std::vector> cfg;
   cfg.emplace_back(std::pair(
diff --git a/tests/cpp/tree/test_quantile_hist.cc b/tests/cpp/tree/test_quantile_hist.cc
index 3d0c09e6a..b1c05c41d 100644
--- a/tests/cpp/tree/test_quantile_hist.cc
+++ b/tests/cpp/tree/test_quantile_hist.cc
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2018 by Contributors
+ * Copyright 2018-2019 by Contributors
  */
 #include "../helpers.h"
 #include "../../../src/tree/param.h"
@@ -46,23 +46,25 @@ class QuantileHistMock : public QuantileHistMaker {
     const size_t num_row = p_fmat->Info().num_row_;
     const size_t num_col = p_fmat->Info().num_col_;
     /* Validate HistCutMatrix */
-    ASSERT_EQ(gmat.cut.row_ptr.size(), num_col + 1);
+    ASSERT_EQ(gmat.cut.Ptrs().size(), num_col + 1);
     for (size_t fid = 0; fid < num_col; ++fid) {
-      // Each feature must have at least one quantile point (cut)
-      const size_t ibegin = gmat.cut.row_ptr[fid];
-      const size_t iend = gmat.cut.row_ptr[fid + 1];
-      ASSERT_LT(ibegin, iend);
+      const size_t ibegin = gmat.cut.Ptrs()[fid];
+      const size_t iend = gmat.cut.Ptrs()[fid + 1];
+      // Ordered, but empty feature is allowed.
+      ASSERT_LE(ibegin, iend);
       for (size_t i = ibegin; i < iend - 1; ++i) {
         // Quantile points must be sorted in ascending order
         // No duplicates allowed
-        ASSERT_LT(gmat.cut.cut[i], gmat.cut.cut[i + 1]);
+        ASSERT_LT(gmat.cut.Values()[i], gmat.cut.Values()[i + 1])
+            << "ibegin: " << ibegin << ", "
+            << "iend: " << iend;
       }
     }
     /* Validate GHistIndexMatrix */
     ASSERT_EQ(gmat.row_ptr.size(), num_row + 1);
     ASSERT_LT(*std::max_element(gmat.index.begin(), gmat.index.end()),
-              gmat.cut.row_ptr.back());
+              gmat.cut.Ptrs().back());
     for (const auto& batch : p_fmat->GetRowBatches()) {
       for (size_t i = 0; i < batch.Size(); ++i) {
         const size_t rid = batch.base_rowid + i;
@@ -71,20 +73,20 @@ class QuantileHistMock : public QuantileHistMaker {
         ASSERT_LT(gmat_row_offset, gmat.index.size());
         SparsePage::Inst inst = batch[i];
         ASSERT_EQ(gmat.row_ptr[rid] + inst.size(), gmat.row_ptr[rid + 1]);
-        for (size_t j = 0; j < inst.size(); ++j) {
+        for (int64_t j = 0; j < inst.size(); ++j) {
           // Each entry of GHistIndexMatrix represents a bin ID
           const size_t bin_id = gmat.index[gmat_row_offset + j];
           const size_t fid = inst[j].index;
           // The bin ID must correspond to correct feature
-          ASSERT_GE(bin_id, gmat.cut.row_ptr[fid]);
-          ASSERT_LT(bin_id, gmat.cut.row_ptr[fid + 1]);
+          ASSERT_GE(bin_id, gmat.cut.Ptrs()[fid]);
+          ASSERT_LT(bin_id, gmat.cut.Ptrs()[fid + 1]);
           // The bin ID must correspond to a region between two
           // suitable quantile points
-          ASSERT_LT(inst[j].fvalue, gmat.cut.cut[bin_id]);
-          if (bin_id > gmat.cut.row_ptr[fid]) {
-            ASSERT_GE(inst[j].fvalue, gmat.cut.cut[bin_id - 1]);
+          ASSERT_LT(inst[j].fvalue, gmat.cut.Values()[bin_id]);
+          if (bin_id > gmat.cut.Ptrs()[fid]) {
+            ASSERT_GE(inst[j].fvalue, gmat.cut.Values()[bin_id - 1]);
           } else {
-            ASSERT_GE(inst[j].fvalue, gmat.cut.min_val[fid]);
+            ASSERT_GE(inst[j].fvalue, gmat.cut.MinValues()[fid]);
           }
         }
       }
@@ -106,11 +108,12 @@ class QuantileHistMock : public QuantileHistMaker {
     std::vector> hist_is_init;
     std::vector nodes = {ExpandEntry(nid, -1, -1, tree.GetDepth(0), 0.0, 0)};
     BuildHistsBatch(nodes, const_cast(&tree), gmat, gpair, &hist_buffers, &hist_is_init);
-    RealImpl::InitNewNode(nid, gmat, gpair, fmat, const_cast(&tree), &snode_[0], tree[0].Parent());
+    RealImpl::InitNewNode(nid, gmat, gpair, fmat,
+                          const_cast(&tree), &snode_[0], tree[0].Parent());
     EvaluateSplitsBatch(nodes, gmat, fmat, hist_is_init, hist_buffers);
     // Check if number of histogram bins is correct
-    ASSERT_EQ(hist_[nid].size(), gmat.cut.row_ptr.back());
+    ASSERT_EQ(hist_[nid].size(), gmat.cut.Ptrs().back());
     std::vector histogram_expected(hist_[nid].size());
     // Compute the correct histogram (histogram_expected)
@@ -126,7 +129,7 @@ class QuantileHistMock : public QuantileHistMaker {
     }
     // Now validate the computed histogram returned by BuildHist
-    for (size_t i = 0; i < hist_[nid].size(); ++i) {
+    for (int64_t i = 0; i < hist_[nid].size(); ++i) {
       GradientPairPrecise sol = histogram_expected[i];
       ASSERT_NEAR(sol.GetGrad(), hist_[nid][i].GetGrad(), kEps);
       ASSERT_NEAR(sol.GetHess(), hist_[nid][i].GetHess(), kEps);
@@ -140,7 +143,7 @@ class QuantileHistMock : public QuantileHistMaker {
       {0.27f, 0.29f}, {0.37f, 0.39f}, {-0.47f, 0.49f}, {0.57f, 0.59f}
     };
     size_t constexpr kMaxBins = 4;
     auto dmat = CreateDMatrix(kNRows, kNCols, 0, 3);
-    // dense, no missing values
+    // dense, no missing values
     common::GHistIndexMatrix gmat;
     gmat.Init((*dmat).get(), kMaxBins);
@@ -152,7 +155,8 @@ class QuantileHistMock : public QuantileHistMaker {
     std::vector> hist_buffers;
     std::vector> hist_is_init;
     BuildHistsBatch(nodes, const_cast(&tree), gmat, row_gpairs, &hist_buffers, &hist_is_init);
-    RealImpl::InitNewNode(0, gmat, row_gpairs, *(*dmat), const_cast(&tree), &snode_[0], tree[0].Parent());
+    RealImpl::InitNewNode(0, gmat, row_gpairs, *(*dmat),
+                          const_cast(&tree), &snode_[0], tree[0].Parent());
     EvaluateSplitsBatch(nodes, gmat, **dmat, hist_is_init, hist_buffers);
     /* Compute correct split (best_split) using the computed histogram */
@@ -178,8 +182,8 @@ class QuantileHistMock : public QuantileHistMaker {
     size_t best_split_feature = std::numeric_limits::max();
     // Enumerate all features
     for (size_t fid = 0; fid < num_feature; ++fid) {
-      const size_t bin_id_min = gmat.cut.row_ptr[fid];
-      const size_t bin_id_max = gmat.cut.row_ptr[fid + 1];
+      const size_t bin_id_min = gmat.cut.Ptrs()[fid];
+      const size_t bin_id_max = gmat.cut.Ptrs()[fid + 1];
       // Enumerate all bin ID in [bin_id_min, bin_id_max), i.e. every possible
       // choice of thresholds for feature fid
       for (size_t split_thresh = bin_id_min;
@@ -217,7 +221,7 @@ class QuantileHistMock : public QuantileHistMaker {
     EvaluateSplitsBatch(nodes, gmat, **dmat, hist_is_init, hist_buffers);
     ASSERT_EQ(snode_[0].best.SplitIndex(), best_split_feature);
-    ASSERT_EQ(snode_[0].best.split_value, gmat.cut.cut[best_split_threshold]);
+    ASSERT_EQ(snode_[0].best.split_value, gmat.cut.Values()[best_split_threshold]);
     delete dmat;
   }
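
Note on the API migration exercised by the hunks above: the old HistCutMatrix fields row_ptr, cut and min_val are replaced throughout by the HistogramCuts accessors Ptrs(), Values() and MinValues(). The sketch below is not part of this patch and does not use the real xgboost classes; ToyCuts and SearchBin are hypothetical stand-ins, shown only to illustrate, under that assumption, the CSC-style layout those accessors expose and the bin/value relationship the quantile_hist test asserts: for feature fid the cuts occupy Values()[Ptrs()[fid] .. Ptrs()[fid+1]), and a value falls into the first bin whose cut is strictly greater than it.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Minimal stand-in mirroring the accessors used in the tests above
// (Ptrs / Values / MinValues).  This is NOT the real xgboost class; it only
// models the layout the tests assert on: for feature fid the candidate cut
// points live in Values()[Ptrs()[fid], Ptrs()[fid + 1]).
struct ToyCuts {
  std::vector<uint32_t> cut_ptrs_;
  std::vector<float> cut_values_;
  std::vector<float> min_vals_;
  const std::vector<uint32_t>& Ptrs() const { return cut_ptrs_; }
  const std::vector<float>& Values() const { return cut_values_; }
  const std::vector<float>& MinValues() const { return min_vals_; }
};

// Hypothetical helper: map (feature id, feature value) to a global bin id
// such that Values()[bin - 1] <= v < Values()[bin] within the feature's
// slice, which is the relationship the quantile_hist test verifies.
uint32_t SearchBin(const ToyCuts& cuts, uint32_t fid, float v) {
  auto begin = cuts.Values().cbegin() + cuts.Ptrs()[fid];
  auto end   = cuts.Values().cbegin() + cuts.Ptrs()[fid + 1];
  auto it = std::upper_bound(begin, end, v);
  if (it == end) --it;  // clamp values beyond the last cut into the last bin
  return static_cast<uint32_t>(it - cuts.Values().cbegin());
}

int main() {
  // Same shape as the fixture in test_gpu_hist.cu: 8 features, 3 cuts each.
  ToyCuts cuts;
  cuts.cut_ptrs_ = {0, 3, 6, 9, 12, 15, 18, 21, 24};
  cuts.cut_values_ = {0.30f, 0.67f, 1.64f,  0.32f, 0.77f, 1.95f,
                      0.29f, 0.70f, 1.80f,  0.32f, 0.75f, 1.85f,
                      0.18f, 0.59f, 1.69f,  0.25f, 0.74f, 2.00f,
                      0.26f, 0.74f, 1.98f,  0.26f, 0.71f, 1.83f};
  cuts.min_vals_ = {0.1f, 0.2f, 0.3f, 0.1f, 0.2f, 0.3f, 0.2f, 0.2f};

  uint32_t bin = SearchBin(cuts, /*fid=*/1, /*v=*/0.5f);
  // The bin must fall inside feature 1's slice and bound the value from above.
  assert(bin >= cuts.Ptrs()[1] && bin < cuts.Ptrs()[2]);
  assert(0.5f < cuts.Values()[bin]);
  return 0;
}

Because the fixture values match the 8-feature, 3-cuts-per-feature matrix used in test_gpu_hist.cu, the example can be checked by hand against the same numbers (0.5 for feature 1 lands in the bin bounded above by 0.77).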