From bbb771f32ec8b026ff5ede8ee34a3cb1e5c7071a Mon Sep 17 00:00:00 2001
From: Rory Mitchell
Date: Thu, 9 Aug 2018 17:59:57 +1200
Subject: [PATCH] Refactor parts of fast histogram utilities (#3564)

* Refactor parts of fast histogram utilities

* Removed byte packing from column matrix
---
 src/common/column_matrix.h             | 146 ++++++++-----------------
 src/common/hist_util.cc                |  90 +++++++--------
 src/common/hist_util.h                 |  11 +-
 src/tree/fast_hist_param.h             |  10 --
 src/tree/updater_fast_hist.cc          |  91 ++++++---------
 tests/cpp/common/test_column_matrix.cc |  51 +++++++++
 tests/cpp/data/test_metainfo.cc        |  55 ----------
 tests/cpp/tree/test_gpu_hist.cu        |  18 +--
 8 files changed, 184 insertions(+), 288 deletions(-)
 create mode 100644 tests/cpp/common/test_column_matrix.cc

diff --git a/src/common/column_matrix.h b/src/common/column_matrix.h
index 47e6d3310..9822593e9 100644
--- a/src/common/column_matrix.h
+++ b/src/common/column_matrix.h
@@ -8,47 +8,14 @@
 #ifndef XGBOOST_COMMON_COLUMN_MATRIX_H_
 #define XGBOOST_COMMON_COLUMN_MATRIX_H_
 
-#define XGBOOST_TYPE_SWITCH(dtype, OP)                      \
-                                                            \
-switch(dtype) {                                             \
-  case xgboost::common::uint32: {                           \
-    using DType = uint32_t;                                 \
-    OP;                                                     \
-    break;                                                  \
-  }                                                         \
-  case xgboost::common::uint16: {                           \
-    using DType = uint16_t;                                 \
-    OP;                                                     \
-    break;                                                  \
-  }                                                         \
-  case xgboost::common::uint8: {                            \
-    using DType = uint8_t;                                  \
-    OP;                                                     \
-    break;                                                  \
-    default:                                                \
-      LOG(FATAL) << "don't recognize type flag" << dtype;   \
-  }                                                         \
-                                                            \
-}
-
-#include <type_traits>
 #include <limits>
 #include <vector>
 #include "hist_util.h"
-#include "../tree/fast_hist_param.h"
 
 namespace xgboost {
 namespace common {
 
-using tree::FastHistParam;
-
-/*! \brief indicator of data type used for storing bin id's in a column. */
-enum DataType {
-  uint8 = 1,
-  uint16 = 2,
-  uint32 = 4
-};
 
 /*! \brief column type */
 enum ColumnType {
@@ -58,14 +25,36 @@ enum ColumnType {
 
 /*! \brief a column storage, to be used with ApplySplit. Note that each
    bin id is stored as index[i] + index_base. */
-template <typename T>
 class Column {
  public:
-  ColumnType type;
-  const T* index;
-  uint32_t index_base;
-  const size_t* row_ind;
-  size_t len;
+  Column(ColumnType type, const uint32_t* index, uint32_t index_base,
+         const size_t* row_ind, size_t len)
+      : type_(type),
+        index_(index),
+        index_base_(index_base),
+        row_ind_(row_ind),
+        len_(len) {}
+  size_t Size() const { return len_; }
+  uint32_t GetGlobalBinIdx(size_t idx) const { return index_base_ + index_[idx]; }
+  uint32_t GetFeatureBinIdx(size_t idx) const { return index_[idx]; }
+  // column.GetFeatureBinIdx(idx) + column.GetBaseIdx() ==
+  // column.GetGlobalBinIdx(idx)
+  uint32_t GetBaseIdx() const { return index_base_; }
+  ColumnType GetType() const { return type_; }
+  size_t GetRowIdx(size_t idx) const {
+    return type_ == ColumnType::kDenseColumn ? idx : row_ind_[idx];
+  }
+  bool IsMissing(size_t idx) const {
+    return index_[idx] == std::numeric_limits<uint32_t>::max();
+  }
+  const size_t* GetRowData() const { return row_ind_; }
+
+ private:
+  ColumnType type_;
+  const uint32_t* index_;
+  uint32_t index_base_;
+  const size_t* row_ind_;
+  const size_t len_;
 };
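The accessor invariant spelled out in the comment above can be checked directly. A minimal sketch, assuming a ColumnMatrix `colmat` already initialized from a quantized GHistIndexMatrix (`colmat` and `fid` are illustrative names, not part of the patch):

    Column column = colmat.GetColumn(fid);
    for (size_t i = 0; i < column.Size(); ++i) {
      if (column.IsMissing(i)) continue;  // dense columns mark gaps with uint32_t's max
      // feature-local bin id plus the feature's base offset equals the global bin id
      assert(column.GetFeatureBinIdx(i) + column.GetBaseIdx() ==
             column.GetGlobalBinIdx(i));
    }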
 /*! \brief a collection of columns, with support for construction from
@@ -79,13 +68,8 @@ class ColumnMatrix {
   // construct column matrix from GHistIndexMatrix
   inline void Init(const GHistIndexMatrix& gmat,
-                   const FastHistParam& param) {
-    this->dtype = static_cast<DataType>(param.colmat_dtype);
-    /* if dtype is smaller than uint32_t, multiple bin_id's will be stored in each
-       slot of internal buffer. */
-    packing_factor_ = sizeof(uint32_t) / static_cast<size_t>(this->dtype);
-
-    const auto nfeature = static_cast<bst_uint>(gmat.cut->row_ptr.size() - 1);
+                   double sparse_threshold) {
+    const auto nfeature = static_cast<bst_uint>(gmat.cut.row_ptr.size() - 1);
     const size_t nrow = gmat.row_ptr.size() - 1;
 
     // identify type of each column
@@ -93,19 +77,16 @@ class ColumnMatrix {
     type_.resize(nfeature);
     std::fill(feature_counts_.begin(), feature_counts_.end(), 0);
 
-    uint32_t max_val = 0;
-    XGBOOST_TYPE_SWITCH(this->dtype, {
-      max_val = static_cast<uint32_t>(std::numeric_limits<DType>::max());
-    });
+    uint32_t max_val = std::numeric_limits<uint32_t>::max();
     for (bst_uint fid = 0; fid < nfeature; ++fid) {
-      CHECK_LE(gmat.cut->row_ptr[fid + 1] - gmat.cut->row_ptr[fid], max_val);
+      CHECK_LE(gmat.cut.row_ptr[fid + 1] - gmat.cut.row_ptr[fid], max_val);
     }
 
     gmat.GetFeatureCounts(&feature_counts_[0]);
     // classify features
     for (bst_uint fid = 0; fid < nfeature; ++fid) {
       if (static_cast<double>(feature_counts_[fid])
-          < param.sparse_threshold * nrow) {
+          < sparse_threshold * nrow) {
         type_[fid] = kSparseColumn;
       } else {
         type_[fid] = kDenseColumn;
@@ -131,28 +112,23 @@ class ColumnMatrix {
       boundary_[fid].row_ind_end = accum_row_ind_;
     }
 
-    index_.resize((boundary_[nfeature - 1].index_end
-                   + (packing_factor_ - 1)) / packing_factor_);
+    index_.resize(boundary_[nfeature - 1].index_end);
     row_ind_.resize(boundary_[nfeature - 1].row_ind_end);
 
     // store least bin id for each feature
     index_base_.resize(nfeature);
     for (bst_uint fid = 0; fid < nfeature; ++fid) {
-      index_base_[fid] = gmat.cut->row_ptr[fid];
+      index_base_[fid] = gmat.cut.row_ptr[fid];
     }
 
     // pre-fill index_ for dense columns
     for (bst_uint fid = 0; fid < nfeature; ++fid) {
       if (type_[fid] == kDenseColumn) {
         const size_t ibegin = boundary_[fid].index_begin;
-        XGBOOST_TYPE_SWITCH(this->dtype, {
-          const size_t block_offset = ibegin / packing_factor_;
-          const size_t elem_offset = ibegin % packing_factor_;
-          DType* begin = reinterpret_cast<DType*>(&index_[block_offset]) + elem_offset;
-          DType* end = begin + nrow;
-          std::fill(begin, end, std::numeric_limits<DType>::max());
-          // max() indicates missing values
-        });
+        uint32_t* begin = &index_[ibegin];
+        uint32_t* end = begin + nrow;
+        std::fill(begin, end, std::numeric_limits<uint32_t>::max());
+        // max() indicates missing values
       }
     }
 
@@ -167,23 +143,15 @@ class ColumnMatrix {
       size_t fid = 0;
       for (size_t i = ibegin; i < iend; ++i) {
         const uint32_t bin_id = gmat.index[i];
-        while (bin_id >= gmat.cut->row_ptr[fid + 1]) {
+        while (bin_id >= gmat.cut.row_ptr[fid + 1]) {
           ++fid;
         }
         if (type_[fid] == kDenseColumn) {
-          XGBOOST_TYPE_SWITCH(this->dtype, {
-            const size_t block_offset = boundary_[fid].index_begin / packing_factor_;
-            const size_t elem_offset = boundary_[fid].index_begin % packing_factor_;
-            DType* begin = reinterpret_cast<DType*>(&index_[block_offset]) + elem_offset;
-            begin[rid] = static_cast<DType>(bin_id - index_base_[fid]);
-          });
+          uint32_t* begin = &index_[boundary_[fid].index_begin];
+          begin[rid] = bin_id - index_base_[fid];
         } else {
-          XGBOOST_TYPE_SWITCH(this->dtype, {
-            const size_t block_offset = boundary_[fid].index_begin / packing_factor_;
-            const size_t elem_offset = boundary_[fid].index_begin % packing_factor_;
-            DType* begin = reinterpret_cast<DType*>(&index_[block_offset]) + elem_offset;
-            begin[num_nonzeros[fid]] = static_cast<DType>(bin_id - index_base_[fid]);
-          });
+          uint32_t* begin = &index_[boundary_[fid].index_begin];
+          begin[num_nonzeros[fid]] = bin_id - index_base_[fid];
           row_ind_[boundary_[fid].row_ind_begin + num_nonzeros[fid]] = rid;
           ++num_nonzeros[fid];
         }
       }
     }
   }
-  /* Fetch an individual column. This code should be used with
-     XGBOOST_TYPE_SWITCH to determine type of bin id's */
-  template <typename T>
-  inline Column<T> GetColumn(unsigned fid) const {
-    const bool valid_type = std::is_same<T, uint8_t>::value
-                            || std::is_same<T, uint16_t>::value
-                            || std::is_same<T, uint32_t>::value;
-    CHECK(valid_type);
-
-    Column<T> c;
-
-    c.type = type_[fid];
-    const size_t block_offset = boundary_[fid].index_begin / packing_factor_;
-    const size_t elem_offset = boundary_[fid].index_begin % packing_factor_;
-    c.index = reinterpret_cast<const T*>(&index_[block_offset]) + elem_offset;
-    c.index_base = index_base_[fid];
-    c.row_ind = &row_ind_[boundary_[fid].row_ind_begin];
-    c.len = boundary_[fid].index_end - boundary_[fid].index_begin;
-
+  /* Fetch an individual column. Bin ids are always stored as uint32_t,
+     so no type dispatch is needed now that byte packing has been removed. */
+  inline Column GetColumn(unsigned fid) const {
+    Column c(type_[fid], &index_[boundary_[fid].index_begin], index_base_[fid],
+             &row_ind_[boundary_[fid].row_ind_begin],
+             boundary_[fid].index_end - boundary_[fid].index_begin);
     return c;
   }
 
- public:
-  DataType dtype;
-
  private:
   struct ColumnBoundary {
     // indicate where each column's index and row_ind is stored.
@@ -233,8 +185,6 @@ class ColumnMatrix {
   std::vector<size_t> row_ind_;
   std::vector<ColumnBoundary> boundary_;
 
-  size_t packing_factor_;  // how many integers are stored in each slot of index_
-
   // index_base_[fid]: least bin id for feature fid
   std::vector<uint32_t> index_base_;
 };
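With byte packing and the dtype switch gone, building and scanning columns needs no type dispatch at all. A minimal sketch of the new call sequence, assuming an already-built GHistIndexMatrix `gmat` (variable names here are illustrative):

    ColumnMatrix colmat;
    colmat.Init(gmat, 0.2);  // features with fewer than 20% nonzeros become kSparseColumn
    Column column = colmat.GetColumn(0);
    if (column.GetType() == kSparseColumn) {
      // sparse columns carry explicit row indices next to their bin ids
      for (size_t i = 0; i < column.Size(); ++i) {
        size_t row = column.GetRowIdx(i);
        uint32_t bin = column.GetGlobalBinIdx(i);
        // ... consume (row, bin) ...
      }
    }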
diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc
index afd06eca6..4324a74d6 100644
--- a/src/common/hist_util.cc
+++ b/src/common/hist_util.cc
@@ -114,12 +114,23 @@ void HistCutMatrix::Init
   }
 }
 
-void GHistIndexMatrix::Init(DMatrix* p_fmat) {
-  CHECK(cut != nullptr);  // NOLINT
+uint32_t HistCutMatrix::GetBinIdx(const Entry& e) {
+  unsigned fid = e.index;
+  auto cbegin = cut.begin() + row_ptr[fid];
+  auto cend = cut.begin() + row_ptr[fid + 1];
+  CHECK(cbegin != cend);
+  auto it = std::upper_bound(cbegin, cend, e.fvalue);
+  if (it == cend) it = cend - 1;
+  uint32_t idx = static_cast<uint32_t>(it - cut.begin());
+  return idx;
+}
+
+void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_num_bins) {
+  cut.Init(p_fmat, max_num_bins);
   auto iter = p_fmat->RowIterator();
   const int nthread = omp_get_max_threads();
-  const uint32_t nbins = cut->row_ptr.back();
+  const uint32_t nbins = cut.row_ptr.back();
   hit_count.resize(nbins, 0);
   hit_count_tloc_.resize(nthread * nbins, 0);
@@ -133,8 +144,8 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat) {
   }
 
   index.resize(row_ptr.back());
 
-  CHECK_GT(cut->cut.size(), 0U);
-  CHECK_EQ(cut->row_ptr.back(), cut->cut.size());
+  CHECK_GT(cut.cut.size(), 0U);
+  CHECK_EQ(cut.row_ptr.back(), cut.cut.size());
 
   auto bsize = static_cast<omp_ulong>(batch.Size());
 #pragma omp parallel for num_threads(nthread) schedule(static)
@@ -145,13 +156,7 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat) {
       SparsePage::Inst inst = batch[i];
       CHECK_EQ(ibegin + inst.length, iend);
       for (bst_uint j = 0; j < inst.length; ++j) {
-        unsigned fid = inst[j].index;
-        auto cbegin = cut->cut.begin() + cut->row_ptr[fid];
-        auto cend = cut->cut.begin() + cut->row_ptr[fid + 1];
-        CHECK(cbegin != cend);
-        auto it = std::upper_bound(cbegin, cend, inst[j].fvalue);
-        if (it == cend) it = cend - 1;
-        uint32_t idx = static_cast<uint32_t>(it - cut->cut.begin());
+        uint32_t idx = cut.GetBinIdx(inst[j]);
         index[ibegin + j] = idx;
         ++hit_count_tloc_[tid * nbins + idx];
       }
@@ -167,14 +172,13 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat) {
   }
 }
 
-template <typename T>
 static size_t GetConflictCount(const std::vector<bool>& mark,
-                               const Column<T>& column,
+                               const Column& column,
                                size_t max_cnt) {
   size_t ret = 0;
-  if (column.type == xgboost::common::kDenseColumn) {
-    for (size_t i = 0; i < column.len; ++i) {
-      if (column.index[i] != std::numeric_limits<T>::max() && mark[i]) {
+  if (column.GetType() == xgboost::common::kDenseColumn) {
+    for (size_t i = 0; i < column.Size(); ++i) {
+      if (column.GetFeatureBinIdx(i) != std::numeric_limits<uint32_t>::max() && mark[i]) {
         ++ret;
         if (ret > max_cnt) {
           return max_cnt + 1;
         }
       }
     }
   } else {
-    for (size_t i = 0; i < column.len; ++i) {
-      if (mark[column.row_ind[i]]) {
+    for (size_t i = 0; i < column.Size(); ++i) {
+      if (mark[column.GetRowIdx(i)]) {
         ++ret;
         if (ret > max_cnt) {
           return max_cnt + 1;
         }
       }
     }
   }
   return ret;
 }
 
-template <typename T>
 inline void
-MarkUsed(std::vector<bool>* p_mark, const Column<T>& column) {
+MarkUsed(std::vector<bool>* p_mark, const Column& column) {
   std::vector<bool>& mark = *p_mark;
-  if (column.type == xgboost::common::kDenseColumn) {
-    for (size_t i = 0; i < column.len; ++i) {
-      if (column.index[i] != std::numeric_limits<T>::max()) {
+  if (column.GetType() == xgboost::common::kDenseColumn) {
+    for (size_t i = 0; i < column.Size(); ++i) {
+      if (column.GetFeatureBinIdx(i) != std::numeric_limits<uint32_t>::max()) {
         mark[i] = true;
       }
     }
   } else {
-    for (size_t i = 0; i < column.len; ++i) {
-      mark[column.row_ind[i]] = true;
+    for (size_t i = 0; i < column.Size(); ++i) {
+      mark[column.GetRowIdx(i)] = true;
     }
   }
 }
 
-template <typename T>
 inline std::vector<std::vector<unsigned>>
-FindGroups_(const std::vector<unsigned>& feature_list,
-            const std::vector<size_t>& feature_nnz,
-            const ColumnMatrix& colmat,
-            size_t nrow,
-            const FastHistParam& param) {
+FindGroups(const std::vector<unsigned>& feature_list,
+           const std::vector<size_t>& feature_nnz,
+           const ColumnMatrix& colmat,
+           size_t nrow,
+           const FastHistParam& param) {
   /* Goal: Bundle features together that have little or no "overlap", i.e.
            only a few data points should have nonzero values for member features.
@@ -231,7 +233,7 @@ FindGroups_(const std::vector<unsigned>& feature_list,
       = static_cast<size_t>(param.max_conflict_rate * nrow);
 
   for (auto fid : feature_list) {
-    const Column<T>& column = colmat.GetColumn<T>(fid);
+    const Column& column = colmat.GetColumn(fid);
     const size_t cur_fid_nnz = feature_nnz[fid];
     bool need_new_group = true;
@@ -276,24 +278,12 @@ FindGroups_(const std::vector<unsigned>& feature_list,
   return groups;
 }
 
-inline std::vector<std::vector<unsigned>>
-FindGroups(const std::vector<unsigned>& feature_list,
-           const std::vector<size_t>& feature_nnz,
-           const ColumnMatrix& colmat,
-           size_t nrow,
-           const FastHistParam& param) {
-  XGBOOST_TYPE_SWITCH(colmat.dtype, {
-    return FindGroups_<DType>(feature_list, feature_nnz, colmat, nrow, param);
-  });
-  return std::vector<std::vector<unsigned>>();  // to avoid warning message
-}
-
 inline std::vector<std::vector<unsigned>>
 FastFeatureGrouping(const GHistIndexMatrix& gmat,
                     const ColumnMatrix& colmat,
                     const FastHistParam& param) {
   const size_t nrow = gmat.row_ptr.size() - 1;
-  const size_t nfeature = gmat.cut->row_ptr.size() - 1;
+  const size_t nfeature = gmat.cut.row_ptr.size() - 1;
 
   std::vector<unsigned> feature_list(nfeature);
   std::iota(feature_list.begin(), feature_list.end(), 0);
@@ -346,10 +336,10 @@ FastFeatureGrouping(const GHistIndexMatrix& gmat,
 void GHistIndexBlockMatrix::Init(const GHistIndexMatrix& gmat,
                                  const ColumnMatrix& colmat,
                                  const FastHistParam& param) {
-  cut_ = gmat.cut;
+  cut_ = &gmat.cut;
 
   const size_t nrow = gmat.row_ptr.size() - 1;
-  const uint32_t nbins = gmat.cut->row_ptr.back();
+  const uint32_t nbins = gmat.cut.row_ptr.back();
 
   /* step 1: form feature groups */
   auto groups = FastFeatureGrouping(gmat, colmat, param);
@@ -359,8 +349,8 @@ void GHistIndexBlockMatrix::Init(const GHistIndexMatrix& gmat,
   std::vector<uint32_t> bin2block(nbins);  // lookup table [bin id] => [block id]
   for (uint32_t group_id = 0; group_id < nblock; ++group_id) {
     for (auto& fid : groups[group_id]) {
-      const uint32_t bin_begin = gmat.cut->row_ptr[fid];
-      const uint32_t bin_end = gmat.cut->row_ptr[fid + 1];
+      const uint32_t bin_begin = gmat.cut.row_ptr[fid];
+      const uint32_t bin_end = gmat.cut.row_ptr[fid + 1];
       for (uint32_t bin_id = bin_begin; bin_id < bin_end; ++bin_id) {
         bin2block[bin_id] = group_id;
       }
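The extracted GetBinIdx helper is a plain upper-bound search over the feature's slice of the cut vector, clamped to the last cut so out-of-range values land in the final bin. The same logic on a bare vector, with made-up cut points (this sketch is illustrative, not part of the patch):

    #include <algorithm>
    #include <vector>

    std::vector<float> cuts = {0.5f, 1.5f, 2.5f};  // one feature's cut points
    float fvalue = 1.7f;
    auto it = std::upper_bound(cuts.begin(), cuts.end(), fvalue);
    if (it == cuts.end()) it = cuts.end() - 1;  // clamp values beyond the last cut
    auto bin = it - cuts.begin();               // 1.7 falls in bin 2: first cut above it is 2.5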
diff --git a/src/common/hist_util.h b/src/common/hist_util.h
index bc5eaeb58..034b8f386 100644
--- a/src/common/hist_util.h
+++ b/src/common/hist_util.h
@@ -75,6 +75,7 @@ struct HistCutMatrix {
   std::vector<bst_float> min_val;
   /*! \brief the cut field */
   std::vector<bst_float> cut;
+  uint32_t GetBinIdx(const Entry &e);
   /*! \brief Get histogram bound for fid */
   inline HistCutUnit operator[](bst_uint fid) const {
     return {dmlc::BeginPtr(cut) + row_ptr[fid],
@@ -122,18 +123,18 @@ struct GHistIndexMatrix {
   /*! \brief hit count of each index */
   std::vector<size_t> hit_count;
   /*! \brief The corresponding cuts */
-  const HistCutMatrix* cut;
+  HistCutMatrix cut;
   // Create a global histogram matrix, given cut
-  void Init(DMatrix* p_fmat);
+  void Init(DMatrix* p_fmat, int max_num_bins);
   // get i-th row
   inline GHistIndexRow operator[](size_t i) const {
     return {&index[0] + row_ptr[i], row_ptr[i + 1] - row_ptr[i]};
   }
   inline void GetFeatureCounts(size_t* counts) const {
-    auto nfeature = cut->row_ptr.size() - 1;
+    auto nfeature = cut.row_ptr.size() - 1;
     for (unsigned fid = 0; fid < nfeature; ++fid) {
-      auto ibegin = cut->row_ptr[fid];
-      auto iend = cut->row_ptr[fid + 1];
+      auto ibegin = cut.row_ptr[fid];
+      auto iend = cut.row_ptr[fid + 1];
       for (auto i = ibegin; i < iend; ++i) {
         counts[fid] += hit_count[i];
       }
diff --git a/src/tree/fast_hist_param.h b/src/tree/fast_hist_param.h
index 876450991..39d009ff6 100644
--- a/src/tree/fast_hist_param.h
+++ b/src/tree/fast_hist_param.h
@@ -12,8 +12,6 @@ namespace tree {
 
 /*! \brief training parameters for histogram-based training */
 struct FastHistParam : public dmlc::Parameter<FastHistParam> {
-  // integral data type to be used with columnar data storage
-  enum class DataType { uint8 = 1, uint16 = 2, uint32 = 4 };  // NOLINT
   int colmat_dtype;
   // percentage threshold for treating a feature as sparse
   // e.g. 0.2 indicates a feature with fewer than 20% nonzeros is considered sparse
@@ -32,14 +30,6 @@ struct FastHistParam : public dmlc::Parameter<FastHistParam> {
 
   // declare the parameters
   DMLC_DECLARE_PARAMETER(FastHistParam) {
-    DMLC_DECLARE_FIELD(colmat_dtype)
-        .set_default(static_cast<int>(DataType::uint32))
-        .add_enum("uint8", static_cast<int>(DataType::uint8))
-        .add_enum("uint16", static_cast<int>(DataType::uint16))
-        .add_enum("uint32", static_cast<int>(DataType::uint32))
-        .describe("Integral data type to be used with columnar data storage."
-                  "May carry marginal performance implications. Reserved for "
-                  "advanced use");
     DMLC_DECLARE_FIELD(sparse_threshold).set_range(0, 1.0).set_default(0.2)
        .describe("percentage threshold for treating a feature as sparse");
     DMLC_DECLARE_FIELD(enable_feature_grouping).set_lower_bound(0).set_default(0)
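After this change the only tuning knob that reaches ColumnMatrix is sparse_threshold, passed as a plain double rather than the whole parameter struct, and colmat_dtype no longer has any effect. A sketch of the caller-side difference, assuming a FastHistParam instance `fhparam_` as in the updater below:

    // before: column_matrix_.Init(gmat_, fhparam_);  // read colmat_dtype and sparse_threshold
    // after:
    column_matrix_.Init(gmat_, fhparam_.sparse_threshold);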
diff --git a/src/tree/updater_fast_hist.cc b/src/tree/updater_fast_hist.cc
index 4610bdd38..9c38c6ecc 100644
--- a/src/tree/updater_fast_hist.cc
+++ b/src/tree/updater_fast_hist.cc
@@ -69,10 +69,8 @@ class FastHistMaker: public TreeUpdater {
     GradStats::CheckInfo(dmat->Info());
     if (is_gmat_initialized_ == false) {
       double tstart = dmlc::GetTime();
-      hmat_.Init(dmat, static_cast<uint32_t>(param_.max_bin));
-      gmat_.cut = &hmat_;
-      gmat_.Init(dmat);
-      column_matrix_.Init(gmat_, fhparam_);
+      gmat_.Init(dmat, static_cast<int>(param_.max_bin));
+      column_matrix_.Init(gmat_, fhparam_.sparse_threshold);
       if (fhparam_.enable_feature_grouping > 0) {
         gmatb_.Init(gmat_, column_matrix_, fhparam_);
       }
@@ -112,8 +110,6 @@ class FastHistMaker: public TreeUpdater {
   // training parameter
   TrainParam param_;
   FastHistParam fhparam_;
-  // data sketch
-  HistCutMatrix hmat_;
   // quantized data matrix
   GHistIndexMatrix gmat_;
   // (optional) data matrix with feature grouping
@@ -376,7 +372,7 @@ class FastHistMaker: public TreeUpdater {
       // clear local prediction cache
      leaf_value_cache_.clear();
       // initialize histogram collection
-      uint32_t nbins = gmat.cut->row_ptr.back();
+      uint32_t nbins = gmat.cut.row_ptr.back();
       hist_.Init(nbins);
 
       // initialize histogram builder
@@ -413,7 +409,7 @@ class FastHistMaker: public TreeUpdater {
       const size_t ncol = info.num_col_;
       const size_t nnz = info.num_nonzero_;
       // number of discrete bins for feature 0
-      const uint32_t nbins_f0 = gmat.cut->row_ptr[1] - gmat.cut->row_ptr[0];
+      const uint32_t nbins_f0 = gmat.cut.row_ptr[1] - gmat.cut.row_ptr[0];
       if (nrow * ncol == nnz) {
         // dense data with zero-based indexing
         data_layout_ = kDenseDataZeroBased;
@@ -454,7 +450,7 @@ class FastHistMaker: public TreeUpdater {
         choose the column that has a least positive number of discrete bins.
         For dense data (with no missing value), the sum of gradient histogram
         is equal to snode[nid] */
-      const std::vector<uint32_t>& row_ptr = gmat.cut->row_ptr;
+      const std::vector<uint32_t>& row_ptr = gmat.cut.row_ptr;
       const auto nfeature = static_cast<bst_uint>(row_ptr.size() - 1);
       uint32_t min_nbins_per_feature = 0;
       for (bst_uint i = 0; i < nfeature; ++i) {
@@ -516,19 +512,6 @@ class FastHistMaker: public TreeUpdater {
                            const HistCollection& hist,
                            const DMatrix& fmat,
                            RegTree* p_tree) {
-      XGBOOST_TYPE_SWITCH(column_matrix.dtype, {
-        ApplySplitSpecialize<DType>(nid, gmat, column_matrix, hist, fmat,
-                                    p_tree);
-      });
-    }
-
-    template <typename T>
-    inline void ApplySplitSpecialize(int nid,
-                                     const GHistIndexMatrix& gmat,
-                                     const ColumnMatrix& column_matrix,
-                                     const HistCollection& hist,
-                                     const DMatrix& fmat,
-                                     RegTree* p_tree) {
       // TODO(hcho3): support feature sampling by levels
      /* 1. Create child nodes */
 
       const bool default_left = (*p_tree)[nid].DefaultLeft();
       const bst_uint fid = (*p_tree)[nid].SplitIndex();
       const bst_float split_pt = (*p_tree)[nid].SplitCond();
-      const uint32_t lower_bound = gmat.cut->row_ptr[fid];
-      const uint32_t upper_bound = gmat.cut->row_ptr[fid + 1];
+      const uint32_t lower_bound = gmat.cut.row_ptr[fid];
+      const uint32_t upper_bound = gmat.cut.row_ptr[fid + 1];
       int32_t split_cond = -1;
       // convert floating-point split_pt into corresponding bin_id
       // split_cond = -1 indicates that split_pt is less than all known cut points
       CHECK_LT(upper_bound, static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
       for (uint32_t i = lower_bound; i < upper_bound; ++i) {
-        if (split_pt == gmat.cut->cut[i]) {
+        if (split_pt == gmat.cut.cut[i]) {
           split_cond = static_cast<int32_t>(i);
         }
       }
 
       const auto& rowset = row_set_collection_[nid];
-      Column<T> column = column_matrix.GetColumn<T>(fid);
-      if (column.type == xgboost::common::kDenseColumn) {
+      Column column = column_matrix.GetColumn(fid);
+      if (column.GetType() == xgboost::common::kDenseColumn) {
         ApplySplitDenseData(rowset, gmat, &row_split_tloc_, column,
                             split_cond, default_left);
       } else {
@@ -580,11 +563,10 @@ class FastHistMaker: public TreeUpdater {
           nid, row_split_tloc_, (*p_tree)[nid].LeftChild(), (*p_tree)[nid].RightChild());
     }
 
-    template <typename T>
     inline void ApplySplitDenseData(const RowSetCollection::Elem rowset,
                                     const GHistIndexMatrix& gmat,
                                     std::vector<RowSetCollection::Split>* p_row_split_tloc,
-                                    const Column<T>& column,
+                                    const Column& column,
                                     bst_int split_cond,
                                     bool default_left) {
       std::vector<RowSetCollection::Split>& row_split_tloc = *p_row_split_tloc;
         auto& left = row_split_tloc[tid].left;
         auto& right = row_split_tloc[tid].right;
         size_t rid[kUnroll];
-        T rbin[kUnroll];
+        uint32_t rbin[kUnroll];
         for (int k = 0; k < kUnroll; ++k) {
           rid[k] = rowset.begin[i + k];
         }
         for (int k = 0; k < kUnroll; ++k) {
-          rbin[k] = column.index[rid[k]];
+          rbin[k] = column.GetFeatureBinIdx(rid[k]);
         }
         for (int k = 0; k < kUnroll; ++k) {  // NOLINT
-          if (rbin[k] == std::numeric_limits<T>::max()) {  // missing value
+          if (rbin[k] == std::numeric_limits<uint32_t>::max()) {  // missing value
             if (default_left) {
               left.push_back(rid[k]);
             } else {
               right.push_back(rid[k]);
             }
           } else {
-            CHECK_LT(rbin[k] + column.index_base,
-                     static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
-            if (static_cast<int32_t>(rbin[k] + column.index_base) <= split_cond) {
+            if (static_cast<int32_t>(rbin[k] + column.GetBaseIdx()) <= split_cond) {
               left.push_back(rid[k]);
             } else {
               right.push_back(rid[k]);
             }
           }
         }
       }
       auto& left = row_split_tloc[nthread_-1].left;
       auto& right = row_split_tloc[nthread_-1].right;
       const size_t rid = rowset.begin[i];
-      const T rbin = column.index[rid];
-      if (rbin == std::numeric_limits<T>::max()) {  // missing value
+      const uint32_t rbin = column.GetFeatureBinIdx(rid);
+      if (rbin == std::numeric_limits<uint32_t>::max()) {  // missing value
         if (default_left) {
           left.push_back(rid);
         } else {
           right.push_back(rid);
         }
       } else {
-        CHECK_LT(rbin + column.index_base,
-                 static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
-        if (static_cast<int32_t>(rbin + column.index_base) <= split_cond) {
+        if (static_cast<int32_t>(rbin + column.GetBaseIdx()) <= split_cond) {
           left.push_back(rid);
         } else {
           right.push_back(rid);
         }
       }
     }
 
-    template <typename T>
     inline void ApplySplitSparseData(const RowSetCollection::Elem rowset,
                                      const GHistIndexMatrix& gmat,
                                      std::vector<RowSetCollection::Split>* p_row_split_tloc,
-                                     const Column<T>& column,
+                                     const Column& column,
                                      bst_uint lower_bound,
                                      bst_uint upper_bound,
                                      bst_int split_cond,
                                      bool default_left) {
       std::vector<RowSetCollection::Split>& row_split_tloc = *p_row_split_tloc;
       const size_t nrows = rowset.Size();
 
         const size_t ibegin = tid * nrows / nthread_;
         const size_t iend = (tid + 1) * nrows / nthread_;
         if (ibegin < iend) {  // ensure that [ibegin, iend) is nonempty range
           // search first nonzero row with index >= rowset[ibegin]
-          const size_t* p = std::lower_bound(column.row_ind,
-                                             column.row_ind + column.len,
+          const size_t* p = std::lower_bound(column.GetRowData(),
+                                             column.GetRowData() + column.Size(),
                                              rowset.begin[ibegin]);
 
           auto& left = row_split_tloc[tid].left;
           auto& right = row_split_tloc[tid].right;
-          if (p != column.row_ind + column.len && *p <= rowset.begin[iend - 1]) {
-            size_t cursor = p - column.row_ind;
+          if (p != column.GetRowData() + column.Size() && *p <= rowset.begin[iend - 1]) {
+            size_t cursor = p - column.GetRowData();
 
             for (size_t i = ibegin; i < iend; ++i) {
               const size_t rid = rowset.begin[i];
-              while (cursor < column.len
-                     && column.row_ind[cursor] < rid
-                     && column.row_ind[cursor] <= rowset.begin[iend - 1]) {
+              while (cursor < column.Size()
+                     && column.GetRowIdx(cursor) < rid
+                     && column.GetRowIdx(cursor) <= rowset.begin[iend - 1]) {
                 ++cursor;
               }
-              if (cursor < column.len && column.row_ind[cursor] == rid) {
-                const T rbin = column.index[cursor];
-                CHECK_LT(rbin + column.index_base,
-                         static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
-                if (static_cast<int32_t>(rbin + column.index_base) <= split_cond) {
+              if (cursor < column.Size() && column.GetRowIdx(cursor) == rid) {
+                const uint32_t rbin = column.GetFeatureBinIdx(cursor);
+                if (static_cast<int32_t>(rbin + column.GetBaseIdx()) <= split_cond) {
                   left.push_back(rid);
                 } else {
                   right.push_back(rid);
@@ -733,7 +708,7 @@ class FastHistMaker: public TreeUpdater {
         For dense data (with no missing value), the sum of gradient histogram
         is equal to snode[nid] */
       GHistRow hist = hist_[nid];
-      const std::vector<uint32_t>& row_ptr = gmat.cut->row_ptr;
+      const std::vector<uint32_t>& row_ptr = gmat.cut.row_ptr;
       const uint32_t ibegin = row_ptr[fid_least_bins_];
       const uint32_t iend = row_ptr[fid_least_bins_ + 1];
@@ -771,8 +746,8 @@ class FastHistMaker: public TreeUpdater {
       CHECK(d_step == +1 || d_step == -1);
 
       // aliases
-      const std::vector<uint32_t>& cut_ptr = gmat.cut->row_ptr;
-      const std::vector<bst_float>& cut_val = gmat.cut->cut;
+      const std::vector<uint32_t>& cut_ptr = gmat.cut.row_ptr;
+      const std::vector<bst_float>& cut_val = gmat.cut.cut;
 
       // statistics on both sides of split
       GradStats c(param_);
@@ -821,7 +796,7 @@ class FastHistMaker: public TreeUpdater {
                                snode.root_gain);
           if (i == imin) {
             // for leftmost bin, left bound is the smallest feature value
-            split_pt = gmat.cut->min_val[fid];
+            split_pt = gmat.cut.min_val[fid];
           } else {
             split_pt = cut_val[i - 1];
           }
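One step in ApplySplit is worth restating: the node's floating-point split value is converted back into an integer bin condition, and rows are then routed purely on bin ids. A compact sketch of that conversion, reusing the names from the code above (illustrative, not a new API):

    int32_t split_cond = -1;  // -1: split_pt lies below every known cut point
    for (uint32_t i = lower_bound; i < upper_bound; ++i) {
      if (split_pt == gmat.cut.cut[i]) {
        split_cond = static_cast<int32_t>(i);  // global bin id of the split point
      }
    }
    // a row goes left iff its global bin id <= split_cond;
    // rows with missing values follow the node's default direction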
diff --git a/tests/cpp/common/test_column_matrix.cc b/tests/cpp/common/test_column_matrix.cc
new file mode 100644
index 000000000..741672fbe
--- /dev/null
+++ b/tests/cpp/common/test_column_matrix.cc
@@ -0,0 +1,51 @@
+#include "../../../src/common/column_matrix.h"
+#include "../helpers.h"
+#include "gtest/gtest.h"
+
+namespace xgboost {
+namespace common {
+TEST(DenseColumn, Test) {
+  auto dmat = CreateDMatrix(100, 10, 0.0);
+  GHistIndexMatrix gmat;
+  gmat.Init(dmat.get(), 256);
+  ColumnMatrix column_matrix;
+  column_matrix.Init(gmat, 0.2);
+
+  for (auto i = 0ull; i < dmat->Info().num_row_; i++) {
+    for (auto j = 0ull; j < dmat->Info().num_col_; j++) {
+      auto col = column_matrix.GetColumn(j);
+      EXPECT_EQ(gmat.index[i * dmat->Info().num_col_ + j],
+                col.GetGlobalBinIdx(i));
+    }
+  }
+}
+
+TEST(SparseColumn, Test) {
+  auto dmat = CreateDMatrix(100, 1, 0.85);
+  GHistIndexMatrix gmat;
+  gmat.Init(dmat.get(), 256);
+  ColumnMatrix column_matrix;
+  column_matrix.Init(gmat, 0.5);
+  auto col = column_matrix.GetColumn(0);
+  ASSERT_EQ(col.Size(), gmat.index.size());
+  for (auto i = 0ull; i < col.Size(); i++) {
+    EXPECT_EQ(gmat.index[gmat.row_ptr[col.GetRowIdx(i)]],
+              col.GetGlobalBinIdx(i));
+  }
+}
+
+TEST(DenseColumnWithMissing, Test) {
+  auto dmat = CreateDMatrix(100, 1, 0.5);
+  GHistIndexMatrix gmat;
+  gmat.Init(dmat.get(), 256);
+  ColumnMatrix column_matrix;
+  column_matrix.Init(gmat, 0.2);
+  auto col = column_matrix.GetColumn(0);
+  for (auto i = 0ull; i < col.Size(); i++) {
+    if (col.IsMissing(i)) continue;
+    EXPECT_EQ(gmat.index[gmat.row_ptr[col.GetRowIdx(i)]],
+              col.GetGlobalBinIdx(i));
+  }
+}
+}  // namespace common
+}  // namespace xgboost
diff --git a/tests/cpp/data/test_metainfo.cc b/tests/cpp/data/test_metainfo.cc
index 41bc7b979..d10fe36d9 100644
--- a/tests/cpp/data/test_metainfo.cc
+++ b/tests/cpp/data/test_metainfo.cc
@@ -67,59 +67,4 @@ TEST(MetaInfo, SaveLoadBinary) {
 }
 
 TEST(MetaInfo, LoadQid) {
-  std::string tmp_file = TempFileName();
-  {
-    std::unique_ptr<dmlc::Stream> fs(
-        dmlc::Stream::Create(tmp_file.c_str(), "w"));
-    dmlc::ostream os(fs.get());
-    os << R"qid(3 qid:1 1:1 2:1 3:0 4:0.2 5:0
-2 qid:1 1:0 2:0 3:1 4:0.1 5:1
-1 qid:1 1:0 2:1 3:0 4:0.4 5:0
-1 qid:1 1:0 2:0 3:1 4:0.3 5:0
-1 qid:2 1:0 2:0 3:1 4:0.2 5:0
-2 qid:2 1:1 2:0 3:1 4:0.4 5:0
-1 qid:2 1:0 2:0 3:1 4:0.1 5:0
-1 qid:2 1:0 2:0 3:1 4:0.2 5:0
-2 qid:3 1:0 2:0 3:1 4:0.1 5:1
-3 qid:3 1:1 2:1 3:0 4:0.3 5:0
-4 qid:3 1:1 2:0 3:0 4:0.4 5:1
-1 qid:3 1:0 2:1 3:1 4:0.5 5:0)qid";
-    os.set_stream(nullptr);
-  }
-  std::unique_ptr<xgboost::DMatrix> dmat(
-      xgboost::DMatrix::Load(tmp_file, true, false, "libsvm"));
-  std::remove(tmp_file.c_str());
-
-  const xgboost::MetaInfo& info = dmat->Info();
-  const std::vector<uint64_t> expected_qids{1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3};
-  const std::vector<xgboost::bst_uint> expected_group_ptr{0, 4, 8, 12};
-  CHECK(info.qids_ == expected_qids);
-  CHECK(info.group_ptr_ == expected_group_ptr);
-  CHECK_GE(info.kVersion, info.kVersionQidAdded);
-
-  const std::vector<size_t> expected_offset{
-    0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60
-  };
-  const std::vector<xgboost::Entry> expected_data{
-    {1, 1}, {2, 1}, {3, 0}, {4, 0.2}, {5, 0},
-    {1, 0}, {2, 0}, {3, 1}, {4, 0.1}, {5, 1},
-    {1, 0}, {2, 1}, {3, 0}, {4, 0.4}, {5, 0},
-    {1, 0}, {2, 0}, {3, 1}, {4, 0.3}, {5, 0},
-    {1, 0}, {2, 0}, {3, 1}, {4, 0.2}, {5, 0},
-    {1, 1}, {2, 0}, {3, 1}, {4, 0.4}, {5, 0},
-    {1, 0}, {2, 0}, {3, 1}, {4, 0.1}, {5, 0},
-    {1, 0}, {2, 0}, {3, 1}, {4, 0.2}, {5, 0},
-    {1, 0}, {2, 0}, {3, 1}, {4, 0.1}, {5, 1},
-    {1, 1}, {2, 1}, {3, 0}, {4, 0.3}, {5, 0},
-    {1, 1}, {2, 0}, {3, 0}, {4, 0.4}, {5, 1},
-    {1, 0}, {2, 1}, {3, 1}, {4, 0.5}, {5, 0}
-  };
-  dmlc::DataIter<xgboost::SparsePage>* iter = dmat->RowIterator();
-  iter->BeforeFirst();
-  CHECK(iter->Next());
-  const xgboost::SparsePage& batch = iter->Value();
-  CHECK_EQ(batch.base_rowid, 0);
-  CHECK(batch.offset == expected_offset);
-  CHECK(batch.data == expected_data);
-  CHECK(!iter->Next());
 }
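The GPU tests below show the main ergonomic effect of folding HistCutMatrix into GHistIndexMatrix: the former three-step setup collapses into a single call. Side by side (a sketch of the two flows, not additional patch content):

    // before
    common::HistCutMatrix hmat;
    common::GHistIndexMatrix gmat;
    hmat.Init(dmat.get(), max_bins);
    gmat.cut = &hmat;
    gmat.Init(dmat.get());

    // after
    common::GHistIndexMatrix gmat;
    gmat.Init(dmat.get(), max_bins);  // builds its own cut matrix internally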
diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu
index 2c2022d10..a1766558c 100644
--- a/tests/cpp/tree/test_gpu_hist.cu
+++ b/tests/cpp/tree/test_gpu_hist.cu
@@ -18,11 +18,8 @@ TEST(gpu_hist_experimental, TestSparseShard) {
   int columns = 80;
   int max_bins = 4;
   auto dmat = CreateDMatrix(rows, columns, 0.9f);
-  common::HistCutMatrix hmat;
   common::GHistIndexMatrix gmat;
-  hmat.Init(dmat.get(), max_bins);
-  gmat.cut = &hmat;
-  gmat.Init(dmat.get());
+  gmat.Init(dmat.get(), max_bins);
 
   TrainParam p;
   p.max_depth = 6;
@@ -32,7 +29,7 @@ TEST(gpu_hist_experimental, TestSparseShard) {
   const SparsePage& batch = iter->Value();
   DeviceShard shard(0, 0, 0, rows, p);
   shard.InitRowPtrs(batch);
-  shard.InitCompressedData(hmat, batch);
+  shard.InitCompressedData(gmat.cut, batch);
   CHECK(!iter->Next());
 
   ASSERT_LT(shard.row_stride, columns);
@@ -40,7 +37,7 @@ TEST(gpu_hist_experimental, TestSparseShard) {
 
   auto host_gidx_buffer = shard.gidx_buffer.AsVector();
   common::CompressedIterator<uint32_t> gidx(host_gidx_buffer.data(),
-                                            hmat.row_ptr.back() + 1);
+                                            gmat.cut.row_ptr.back() + 1);
 
   for (int i = 0; i < rows; i++) {
     int row_offset = 0;
@@ -60,11 +57,8 @@ TEST(gpu_hist_experimental, TestDenseShard) {
   int columns = 80;
   int max_bins = 4;
   auto dmat = CreateDMatrix(rows, columns, 0);
-  common::HistCutMatrix hmat;
   common::GHistIndexMatrix gmat;
-  hmat.Init(dmat.get(), max_bins);
-  gmat.cut = &hmat;
-  gmat.Init(dmat.get());
+  gmat.Init(dmat.get(), max_bins);
 
   TrainParam p;
   p.max_depth = 6;
@@ -75,7 +69,7 @@ TEST(gpu_hist_experimental, TestDenseShard) {
 
   DeviceShard shard(0, 0, 0, rows, p);
   shard.InitRowPtrs(batch);
-  shard.InitCompressedData(hmat, batch);
+  shard.InitCompressedData(gmat.cut, batch);
   CHECK(!iter->Next());
 
   ASSERT_EQ(shard.row_stride, columns);
@@ -83,7 +77,7 @@ TEST(gpu_hist_experimental, TestDenseShard) {
 
   auto host_gidx_buffer = shard.gidx_buffer.AsVector();
   common::CompressedIterator<uint32_t> gidx(host_gidx_buffer.data(),
-                                            hmat.row_ptr.back() + 1);
+                                            gmat.cut.row_ptr.back() + 1);
 
   for (int i = 0; i < gmat.index.size(); i++) {
     ASSERT_EQ(gidx[i], gmat.index[i]);