Refactor parts of fast histogram utilities (#3564)

* Refactor parts of fast histogram utilities

* Removed byte packing from column matrix
This commit is contained in:
Rory Mitchell 2018-08-09 17:59:57 +12:00 committed by GitHub
parent 3c72654e3b
commit bbb771f32e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 184 additions and 288 deletions

View File

@ -8,47 +8,14 @@
#ifndef XGBOOST_COMMON_COLUMN_MATRIX_H_
#define XGBOOST_COMMON_COLUMN_MATRIX_H_
#define XGBOOST_TYPE_SWITCH(dtype, OP) \
\
switch(dtype) { \
case xgboost::common::uint32: { \
using DType = uint32_t; \
OP; \
break; \
} \
case xgboost::common::uint16: { \
using DType = uint16_t; \
OP; \
break; \
} \
case xgboost::common::uint8: { \
using DType = uint8_t; \
OP; \
break; \
default: \
LOG(FATAL) << "don't recognize type flag" << dtype; \
} \
\
}
#include <type_traits>
#include <limits>
#include <vector>
#include "hist_util.h"
#include "../tree/fast_hist_param.h"
namespace xgboost {
namespace common {
using tree::FastHistParam;
/*! \brief indicator of data type used for storing bin id's in a column. */
enum DataType {
uint8 = 1,
uint16 = 2,
uint32 = 4
};
/*! \brief column type */
enum ColumnType {
@ -58,14 +25,36 @@ enum ColumnType {
/*! \brief a column storage, to be used with ApplySplit. Note that each
bin id is stored as index[i] + index_base. */
template<typename T>
class Column {
public:
ColumnType type;
const T* index;
uint32_t index_base;
const size_t* row_ind;
size_t len;
Column(ColumnType type, const uint32_t* index, uint32_t index_base,
const size_t* row_ind, size_t len)
: type_(type),
index_(index),
index_base_(index_base),
row_ind_(row_ind),
len_(len) {}
size_t Size() const { return len_; }
uint32_t GetGlobalBinIdx(size_t idx) const { return index_base_ + index_[idx]; }
uint32_t GetFeatureBinIdx(size_t idx) const { return index_[idx]; }
// column.GetFeatureBinIdx(idx) + column.GetBaseIdx(idx) ==
// column.GetGlobalBinIdx(idx)
uint32_t GetBaseIdx() const { return index_base_; }
ColumnType GetType() const { return type_; }
size_t GetRowIdx(size_t idx) const {
return type_ == ColumnType::kDenseColumn ? idx : row_ind_[idx];
}
bool IsMissing(size_t idx) const {
return index_[idx] == std::numeric_limits<uint32_t>::max();
}
const size_t* GetRowData() const { return row_ind_; }
private:
ColumnType type_;
const uint32_t* index_;
uint32_t index_base_;
const size_t* row_ind_;
const size_t len_;
};
/*! \brief a collection of columns, with support for construction from
@ -79,13 +68,8 @@ class ColumnMatrix {
// construct column matrix from GHistIndexMatrix
inline void Init(const GHistIndexMatrix& gmat,
const FastHistParam& param) {
this->dtype = static_cast<DataType>(param.colmat_dtype);
/* if dtype is smaller than uint32_t, multiple bin_id's will be stored in each
slot of internal buffer. */
packing_factor_ = sizeof(uint32_t) / static_cast<size_t>(this->dtype);
const auto nfeature = static_cast<bst_uint>(gmat.cut->row_ptr.size() - 1);
double sparse_threshold) {
const auto nfeature = static_cast<bst_uint>(gmat.cut.row_ptr.size() - 1);
const size_t nrow = gmat.row_ptr.size() - 1;
// identify type of each column
@ -93,19 +77,16 @@ class ColumnMatrix {
type_.resize(nfeature);
std::fill(feature_counts_.begin(), feature_counts_.end(), 0);
uint32_t max_val = 0;
XGBOOST_TYPE_SWITCH(this->dtype, {
max_val = static_cast<uint32_t>(std::numeric_limits<DType>::max());
});
uint32_t max_val = std::numeric_limits<uint32_t>::max();
for (bst_uint fid = 0; fid < nfeature; ++fid) {
CHECK_LE(gmat.cut->row_ptr[fid + 1] - gmat.cut->row_ptr[fid], max_val);
CHECK_LE(gmat.cut.row_ptr[fid + 1] - gmat.cut.row_ptr[fid], max_val);
}
gmat.GetFeatureCounts(&feature_counts_[0]);
// classify features
for (bst_uint fid = 0; fid < nfeature; ++fid) {
if (static_cast<double>(feature_counts_[fid])
< param.sparse_threshold * nrow) {
< sparse_threshold * nrow) {
type_[fid] = kSparseColumn;
} else {
type_[fid] = kDenseColumn;
@ -131,28 +112,23 @@ class ColumnMatrix {
boundary_[fid].row_ind_end = accum_row_ind_;
}
index_.resize((boundary_[nfeature - 1].index_end
+ (packing_factor_ - 1)) / packing_factor_);
index_.resize(boundary_[nfeature - 1].index_end);
row_ind_.resize(boundary_[nfeature - 1].row_ind_end);
// store least bin id for each feature
index_base_.resize(nfeature);
for (bst_uint fid = 0; fid < nfeature; ++fid) {
index_base_[fid] = gmat.cut->row_ptr[fid];
index_base_[fid] = gmat.cut.row_ptr[fid];
}
// pre-fill index_ for dense columns
for (bst_uint fid = 0; fid < nfeature; ++fid) {
if (type_[fid] == kDenseColumn) {
const size_t ibegin = boundary_[fid].index_begin;
XGBOOST_TYPE_SWITCH(this->dtype, {
const size_t block_offset = ibegin / packing_factor_;
const size_t elem_offset = ibegin % packing_factor_;
DType* begin = reinterpret_cast<DType*>(&index_[block_offset]) + elem_offset;
DType* end = begin + nrow;
std::fill(begin, end, std::numeric_limits<DType>::max());
// max() indicates missing values
});
uint32_t* begin = &index_[ibegin];
uint32_t* end = begin + nrow;
std::fill(begin, end, std::numeric_limits<uint32_t>::max());
// max() indicates missing values
}
}
@ -167,23 +143,15 @@ class ColumnMatrix {
size_t fid = 0;
for (size_t i = ibegin; i < iend; ++i) {
const uint32_t bin_id = gmat.index[i];
while (bin_id >= gmat.cut->row_ptr[fid + 1]) {
while (bin_id >= gmat.cut.row_ptr[fid + 1]) {
++fid;
}
if (type_[fid] == kDenseColumn) {
XGBOOST_TYPE_SWITCH(this->dtype, {
const size_t block_offset = boundary_[fid].index_begin / packing_factor_;
const size_t elem_offset = boundary_[fid].index_begin % packing_factor_;
DType* begin = reinterpret_cast<DType*>(&index_[block_offset]) + elem_offset;
begin[rid] = static_cast<DType>(bin_id - index_base_[fid]);
});
uint32_t* begin = &index_[boundary_[fid].index_begin];
begin[rid] = bin_id - index_base_[fid];
} else {
XGBOOST_TYPE_SWITCH(this->dtype, {
const size_t block_offset = boundary_[fid].index_begin / packing_factor_;
const size_t elem_offset = boundary_[fid].index_begin % packing_factor_;
DType* begin = reinterpret_cast<DType*>(&index_[block_offset]) + elem_offset;
begin[num_nonzeros[fid]] = static_cast<DType>(bin_id - index_base_[fid]);
});
uint32_t* begin = &index_[boundary_[fid].index_begin];
begin[num_nonzeros[fid]] = bin_id - index_base_[fid];
row_ind_[boundary_[fid].row_ind_begin + num_nonzeros[fid]] = rid;
++num_nonzeros[fid];
}
@ -193,29 +161,13 @@ class ColumnMatrix {
/* Fetch an individual column. This code should be used with XGBOOST_TYPE_SWITCH
to determine type of bin id's */
template<typename T>
inline Column<T> GetColumn(unsigned fid) const {
const bool valid_type = std::is_same<T, uint32_t>::value
|| std::is_same<T, uint16_t>::value
|| std::is_same<T, uint8_t>::value;
CHECK(valid_type);
Column<T> c;
c.type = type_[fid];
const size_t block_offset = boundary_[fid].index_begin / packing_factor_;
const size_t elem_offset = boundary_[fid].index_begin % packing_factor_;
c.index = reinterpret_cast<const T*>(&index_[block_offset]) + elem_offset;
c.index_base = index_base_[fid];
c.row_ind = &row_ind_[boundary_[fid].row_ind_begin];
c.len = boundary_[fid].index_end - boundary_[fid].index_begin;
inline Column GetColumn(unsigned fid) const {
Column c(type_[fid], &index_[boundary_[fid].index_begin], index_base_[fid],
&row_ind_[boundary_[fid].row_ind_begin],
boundary_[fid].index_end - boundary_[fid].index_begin);
return c;
}
public:
DataType dtype;
private:
struct ColumnBoundary {
// indicate where each column's index and row_ind is stored.
@ -233,8 +185,6 @@ class ColumnMatrix {
std::vector<size_t> row_ind_;
std::vector<ColumnBoundary> boundary_;
size_t packing_factor_; // how many integers are stored in each slot of index_
// index_base_[fid]: least bin id for feature fid
std::vector<uint32_t> index_base_;
};

View File

@ -114,12 +114,23 @@ void HistCutMatrix::Init
}
}
void GHistIndexMatrix::Init(DMatrix* p_fmat) {
CHECK(cut != nullptr); // NOLINT
uint32_t HistCutMatrix::GetBinIdx(const Entry& e) {
unsigned fid = e.index;
auto cbegin = cut.begin() + row_ptr[fid];
auto cend = cut.begin() + row_ptr[fid + 1];
CHECK(cbegin != cend);
auto it = std::upper_bound(cbegin, cend, e.fvalue);
if (it == cend) it = cend - 1;
uint32_t idx = static_cast<uint32_t>(it - cut.begin());
return idx;
}
void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_num_bins) {
cut.Init(p_fmat, max_num_bins);
auto iter = p_fmat->RowIterator();
const int nthread = omp_get_max_threads();
const uint32_t nbins = cut->row_ptr.back();
const uint32_t nbins = cut.row_ptr.back();
hit_count.resize(nbins, 0);
hit_count_tloc_.resize(nthread * nbins, 0);
@ -133,8 +144,8 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat) {
}
index.resize(row_ptr.back());
CHECK_GT(cut->cut.size(), 0U);
CHECK_EQ(cut->row_ptr.back(), cut->cut.size());
CHECK_GT(cut.cut.size(), 0U);
CHECK_EQ(cut.row_ptr.back(), cut.cut.size());
auto bsize = static_cast<omp_ulong>(batch.Size());
#pragma omp parallel for num_threads(nthread) schedule(static)
@ -145,13 +156,7 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat) {
SparsePage::Inst inst = batch[i];
CHECK_EQ(ibegin + inst.length, iend);
for (bst_uint j = 0; j < inst.length; ++j) {
unsigned fid = inst[j].index;
auto cbegin = cut->cut.begin() + cut->row_ptr[fid];
auto cend = cut->cut.begin() + cut->row_ptr[fid + 1];
CHECK(cbegin != cend);
auto it = std::upper_bound(cbegin, cend, inst[j].fvalue);
if (it == cend) it = cend - 1;
uint32_t idx = static_cast<uint32_t>(it - cut->cut.begin());
uint32_t idx = cut.GetBinIdx(inst[j]);
index[ibegin + j] = idx;
++hit_count_tloc_[tid * nbins + idx];
}
@ -167,14 +172,13 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat) {
}
}
template <typename T>
static size_t GetConflictCount(const std::vector<bool>& mark,
const Column<T>& column,
const Column& column,
size_t max_cnt) {
size_t ret = 0;
if (column.type == xgboost::common::kDenseColumn) {
for (size_t i = 0; i < column.len; ++i) {
if (column.index[i] != std::numeric_limits<T>::max() && mark[i]) {
if (column.GetType() == xgboost::common::kDenseColumn) {
for (size_t i = 0; i < column.Size(); ++i) {
if (column.GetFeatureBinIdx(i) != std::numeric_limits<uint32_t>::max() && mark[i]) {
++ret;
if (ret > max_cnt) {
return max_cnt + 1;
@ -182,8 +186,8 @@ static size_t GetConflictCount(const std::vector<bool>& mark,
}
}
} else {
for (size_t i = 0; i < column.len; ++i) {
if (mark[column.row_ind[i]]) {
for (size_t i = 0; i < column.Size(); ++i) {
if (mark[column.GetRowIdx(i)]) {
++ret;
if (ret > max_cnt) {
return max_cnt + 1;
@ -194,30 +198,28 @@ static size_t GetConflictCount(const std::vector<bool>& mark,
return ret;
}
template <typename T>
inline void
MarkUsed(std::vector<bool>* p_mark, const Column<T>& column) {
MarkUsed(std::vector<bool>* p_mark, const Column& column) {
std::vector<bool>& mark = *p_mark;
if (column.type == xgboost::common::kDenseColumn) {
for (size_t i = 0; i < column.len; ++i) {
if (column.index[i] != std::numeric_limits<T>::max()) {
if (column.GetType() == xgboost::common::kDenseColumn) {
for (size_t i = 0; i < column.Size(); ++i) {
if (column.GetFeatureBinIdx(i) != std::numeric_limits<uint32_t>::max()) {
mark[i] = true;
}
}
} else {
for (size_t i = 0; i < column.len; ++i) {
mark[column.row_ind[i]] = true;
for (size_t i = 0; i < column.Size(); ++i) {
mark[column.GetRowIdx(i)] = true;
}
}
}
template <typename T>
inline std::vector<std::vector<unsigned>>
FindGroups_(const std::vector<unsigned>& feature_list,
const std::vector<size_t>& feature_nnz,
const ColumnMatrix& colmat,
size_t nrow,
const FastHistParam& param) {
FindGroups(const std::vector<unsigned>& feature_list,
const std::vector<size_t>& feature_nnz,
const ColumnMatrix& colmat,
size_t nrow,
const FastHistParam& param) {
/* Goal: Bundle features together that has little or no "overlap", i.e.
only a few data points should have nonzero values for
member features.
@ -231,7 +233,7 @@ FindGroups_(const std::vector<unsigned>& feature_list,
= static_cast<size_t>(param.max_conflict_rate * nrow);
for (auto fid : feature_list) {
const Column<T>& column = colmat.GetColumn<T>(fid);
const Column& column = colmat.GetColumn(fid);
const size_t cur_fid_nnz = feature_nnz[fid];
bool need_new_group = true;
@ -276,24 +278,12 @@ FindGroups_(const std::vector<unsigned>& feature_list,
return groups;
}
inline std::vector<std::vector<unsigned>>
FindGroups(const std::vector<unsigned>& feature_list,
const std::vector<size_t>& feature_nnz,
const ColumnMatrix& colmat,
size_t nrow,
const FastHistParam& param) {
XGBOOST_TYPE_SWITCH(colmat.dtype, {
return FindGroups_<DType>(feature_list, feature_nnz, colmat, nrow, param);
});
return std::vector<std::vector<unsigned>>(); // to avoid warning message
}
inline std::vector<std::vector<unsigned>>
FastFeatureGrouping(const GHistIndexMatrix& gmat,
const ColumnMatrix& colmat,
const FastHistParam& param) {
const size_t nrow = gmat.row_ptr.size() - 1;
const size_t nfeature = gmat.cut->row_ptr.size() - 1;
const size_t nfeature = gmat.cut.row_ptr.size() - 1;
std::vector<unsigned> feature_list(nfeature);
std::iota(feature_list.begin(), feature_list.end(), 0);
@ -346,10 +336,10 @@ FastFeatureGrouping(const GHistIndexMatrix& gmat,
void GHistIndexBlockMatrix::Init(const GHistIndexMatrix& gmat,
const ColumnMatrix& colmat,
const FastHistParam& param) {
cut_ = gmat.cut;
cut_ = &gmat.cut;
const size_t nrow = gmat.row_ptr.size() - 1;
const uint32_t nbins = gmat.cut->row_ptr.back();
const uint32_t nbins = gmat.cut.row_ptr.back();
/* step 1: form feature groups */
auto groups = FastFeatureGrouping(gmat, colmat, param);
@ -359,8 +349,8 @@ void GHistIndexBlockMatrix::Init(const GHistIndexMatrix& gmat,
std::vector<uint32_t> bin2block(nbins); // lookup table [bin id] => [block id]
for (uint32_t group_id = 0; group_id < nblock; ++group_id) {
for (auto& fid : groups[group_id]) {
const uint32_t bin_begin = gmat.cut->row_ptr[fid];
const uint32_t bin_end = gmat.cut->row_ptr[fid + 1];
const uint32_t bin_begin = gmat.cut.row_ptr[fid];
const uint32_t bin_end = gmat.cut.row_ptr[fid + 1];
for (uint32_t bin_id = bin_begin; bin_id < bin_end; ++bin_id) {
bin2block[bin_id] = group_id;
}

View File

@ -75,6 +75,7 @@ struct HistCutMatrix {
std::vector<bst_float> min_val;
/*! \brief the cut field */
std::vector<bst_float> cut;
uint32_t GetBinIdx(const Entry &e);
/*! \brief Get histogram bound for fid */
inline HistCutUnit operator[](bst_uint fid) const {
return {dmlc::BeginPtr(cut) + row_ptr[fid],
@ -122,18 +123,18 @@ struct GHistIndexMatrix {
/*! \brief hit count of each index */
std::vector<size_t> hit_count;
/*! \brief The corresponding cuts */
const HistCutMatrix* cut;
HistCutMatrix cut;
// Create a global histogram matrix, given cut
void Init(DMatrix* p_fmat);
void Init(DMatrix* p_fmat, int max_num_bins);
// get i-th row
inline GHistIndexRow operator[](size_t i) const {
return {&index[0] + row_ptr[i], row_ptr[i + 1] - row_ptr[i]};
}
inline void GetFeatureCounts(size_t* counts) const {
auto nfeature = cut->row_ptr.size() - 1;
auto nfeature = cut.row_ptr.size() - 1;
for (unsigned fid = 0; fid < nfeature; ++fid) {
auto ibegin = cut->row_ptr[fid];
auto iend = cut->row_ptr[fid + 1];
auto ibegin = cut.row_ptr[fid];
auto iend = cut.row_ptr[fid + 1];
for (auto i = ibegin; i < iend; ++i) {
counts[fid] += hit_count[i];
}

View File

@ -12,8 +12,6 @@ namespace tree {
/*! \brief training parameters for histogram-based training */
struct FastHistParam : public dmlc::Parameter<FastHistParam> {
// integral data type to be used with columnar data storage
enum class DataType { uint8 = 1, uint16 = 2, uint32 = 4 }; // NOLINT
int colmat_dtype;
// percentage threshold for treating a feature as sparse
// e.g. 0.2 indicates a feature with fewer than 20% nonzeros is considered sparse
@ -32,14 +30,6 @@ struct FastHistParam : public dmlc::Parameter<FastHistParam> {
// declare the parameters
DMLC_DECLARE_PARAMETER(FastHistParam) {
DMLC_DECLARE_FIELD(colmat_dtype)
.set_default(static_cast<int>(DataType::uint32))
.add_enum("uint8", static_cast<int>(DataType::uint8))
.add_enum("uint16", static_cast<int>(DataType::uint16))
.add_enum("uint32", static_cast<int>(DataType::uint32))
.describe("Integral data type to be used with columnar data storage."
"May carry marginal performance implications. Reserved for "
"advanced use");
DMLC_DECLARE_FIELD(sparse_threshold).set_range(0, 1.0).set_default(0.2)
.describe("percentage threshold for treating a feature as sparse");
DMLC_DECLARE_FIELD(enable_feature_grouping).set_lower_bound(0).set_default(0)

View File

@ -69,10 +69,8 @@ class FastHistMaker: public TreeUpdater {
GradStats::CheckInfo(dmat->Info());
if (is_gmat_initialized_ == false) {
double tstart = dmlc::GetTime();
hmat_.Init(dmat, static_cast<uint32_t>(param_.max_bin));
gmat_.cut = &hmat_;
gmat_.Init(dmat);
column_matrix_.Init(gmat_, fhparam_);
gmat_.Init(dmat, static_cast<uint32_t>(param_.max_bin));
column_matrix_.Init(gmat_, fhparam_.sparse_threshold);
if (fhparam_.enable_feature_grouping > 0) {
gmatb_.Init(gmat_, column_matrix_, fhparam_);
}
@ -112,8 +110,6 @@ class FastHistMaker: public TreeUpdater {
// training parameter
TrainParam param_;
FastHistParam fhparam_;
// data sketch
HistCutMatrix hmat_;
// quantized data matrix
GHistIndexMatrix gmat_;
// (optional) data matrix with feature grouping
@ -376,7 +372,7 @@ class FastHistMaker: public TreeUpdater {
// clear local prediction cache
leaf_value_cache_.clear();
// initialize histogram collection
uint32_t nbins = gmat.cut->row_ptr.back();
uint32_t nbins = gmat.cut.row_ptr.back();
hist_.Init(nbins);
// initialize histogram builder
@ -413,7 +409,7 @@ class FastHistMaker: public TreeUpdater {
const size_t ncol = info.num_col_;
const size_t nnz = info.num_nonzero_;
// number of discrete bins for feature 0
const uint32_t nbins_f0 = gmat.cut->row_ptr[1] - gmat.cut->row_ptr[0];
const uint32_t nbins_f0 = gmat.cut.row_ptr[1] - gmat.cut.row_ptr[0];
if (nrow * ncol == nnz) {
// dense data with zero-based indexing
data_layout_ = kDenseDataZeroBased;
@ -454,7 +450,7 @@ class FastHistMaker: public TreeUpdater {
choose the column that has a least positive number of discrete bins.
For dense data (with no missing value),
the sum of gradient histogram is equal to snode[nid] */
const std::vector<uint32_t>& row_ptr = gmat.cut->row_ptr;
const std::vector<uint32_t>& row_ptr = gmat.cut.row_ptr;
const auto nfeature = static_cast<bst_uint>(row_ptr.size() - 1);
uint32_t min_nbins_per_feature = 0;
for (bst_uint i = 0; i < nfeature; ++i) {
@ -516,19 +512,6 @@ class FastHistMaker: public TreeUpdater {
const HistCollection& hist,
const DMatrix& fmat,
RegTree* p_tree) {
XGBOOST_TYPE_SWITCH(column_matrix.dtype, {
ApplySplitSpecialize<DType>(nid, gmat, column_matrix, hist, fmat,
p_tree);
});
}
template <typename T>
inline void ApplySplitSpecialize(int nid,
const GHistIndexMatrix& gmat,
const ColumnMatrix& column_matrix,
const HistCollection& hist,
const DMatrix& fmat,
RegTree* p_tree) {
// TODO(hcho3): support feature sampling by levels
/* 1. Create child nodes */
@ -552,23 +535,23 @@ class FastHistMaker: public TreeUpdater {
const bool default_left = (*p_tree)[nid].DefaultLeft();
const bst_uint fid = (*p_tree)[nid].SplitIndex();
const bst_float split_pt = (*p_tree)[nid].SplitCond();
const uint32_t lower_bound = gmat.cut->row_ptr[fid];
const uint32_t upper_bound = gmat.cut->row_ptr[fid + 1];
const uint32_t lower_bound = gmat.cut.row_ptr[fid];
const uint32_t upper_bound = gmat.cut.row_ptr[fid + 1];
int32_t split_cond = -1;
// convert floating-point split_pt into corresponding bin_id
// split_cond = -1 indicates that split_pt is less than all known cut points
CHECK_LT(upper_bound,
static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
for (uint32_t i = lower_bound; i < upper_bound; ++i) {
if (split_pt == gmat.cut->cut[i]) {
if (split_pt == gmat.cut.cut[i]) {
split_cond = static_cast<int32_t>(i);
}
}
const auto& rowset = row_set_collection_[nid];
Column<T> column = column_matrix.GetColumn<T>(fid);
if (column.type == xgboost::common::kDenseColumn) {
Column column = column_matrix.GetColumn(fid);
if (column.GetType() == xgboost::common::kDenseColumn) {
ApplySplitDenseData(rowset, gmat, &row_split_tloc_, column, split_cond,
default_left);
} else {
@ -580,11 +563,10 @@ class FastHistMaker: public TreeUpdater {
nid, row_split_tloc_, (*p_tree)[nid].LeftChild(), (*p_tree)[nid].RightChild());
}
template<typename T>
inline void ApplySplitDenseData(const RowSetCollection::Elem rowset,
const GHistIndexMatrix& gmat,
std::vector<RowSetCollection::Split>* p_row_split_tloc,
const Column<T>& column,
const Column& column,
bst_int split_cond,
bool default_left) {
std::vector<RowSetCollection::Split>& row_split_tloc = *p_row_split_tloc;
@ -598,24 +580,22 @@ class FastHistMaker: public TreeUpdater {
auto& left = row_split_tloc[tid].left;
auto& right = row_split_tloc[tid].right;
size_t rid[kUnroll];
T rbin[kUnroll];
uint32_t rbin[kUnroll];
for (int k = 0; k < kUnroll; ++k) {
rid[k] = rowset.begin[i + k];
}
for (int k = 0; k < kUnroll; ++k) {
rbin[k] = column.index[rid[k]];
rbin[k] = column.GetFeatureBinIdx(rid[k]);
}
for (int k = 0; k < kUnroll; ++k) { // NOLINT
if (rbin[k] == std::numeric_limits<T>::max()) { // missing value
if (rbin[k] == std::numeric_limits<uint32_t>::max()) { // missing value
if (default_left) {
left.push_back(rid[k]);
} else {
right.push_back(rid[k]);
}
} else {
CHECK_LT(rbin[k] + column.index_base,
static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
if (static_cast<int32_t>(rbin[k] + column.index_base) <= split_cond) {
if (static_cast<int32_t>(rbin[k] + column.GetBaseIdx()) <= split_cond) {
left.push_back(rid[k]);
} else {
right.push_back(rid[k]);
@ -627,17 +607,15 @@ class FastHistMaker: public TreeUpdater {
auto& left = row_split_tloc[nthread_-1].left;
auto& right = row_split_tloc[nthread_-1].right;
const size_t rid = rowset.begin[i];
const T rbin = column.index[rid];
if (rbin == std::numeric_limits<T>::max()) { // missing value
const uint32_t rbin = column.GetFeatureBinIdx(rid);
if (rbin == std::numeric_limits<uint32_t>::max()) { // missing value
if (default_left) {
left.push_back(rid);
} else {
right.push_back(rid);
}
} else {
CHECK_LT(rbin + column.index_base,
static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
if (static_cast<int32_t>(rbin + column.index_base) <= split_cond) {
if (static_cast<int32_t>(rbin + column.GetBaseIdx()) <= split_cond) {
left.push_back(rid);
} else {
right.push_back(rid);
@ -646,11 +624,10 @@ class FastHistMaker: public TreeUpdater {
}
}
template<typename T>
inline void ApplySplitSparseData(const RowSetCollection::Elem rowset,
const GHistIndexMatrix& gmat,
std::vector<RowSetCollection::Split>* p_row_split_tloc,
const Column<T>& column,
const Column& column,
bst_uint lower_bound,
bst_uint upper_bound,
bst_int split_cond,
@ -665,27 +642,25 @@ class FastHistMaker: public TreeUpdater {
const size_t iend = (tid + 1) * nrows / nthread_;
if (ibegin < iend) { // ensure that [ibegin, iend) is nonempty range
// search first nonzero row with index >= rowset[ibegin]
const size_t* p = std::lower_bound(column.row_ind,
column.row_ind + column.len,
const size_t* p = std::lower_bound(column.GetRowData(),
column.GetRowData() + column.Size(),
rowset.begin[ibegin]);
auto& left = row_split_tloc[tid].left;
auto& right = row_split_tloc[tid].right;
if (p != column.row_ind + column.len && *p <= rowset.begin[iend - 1]) {
size_t cursor = p - column.row_ind;
if (p != column.GetRowData() + column.Size() && *p <= rowset.begin[iend - 1]) {
size_t cursor = p - column.GetRowData();
for (size_t i = ibegin; i < iend; ++i) {
const size_t rid = rowset.begin[i];
while (cursor < column.len
&& column.row_ind[cursor] < rid
&& column.row_ind[cursor] <= rowset.begin[iend - 1]) {
while (cursor < column.Size()
&& column.GetRowIdx(cursor) < rid
&& column.GetRowIdx(cursor) <= rowset.begin[iend - 1]) {
++cursor;
}
if (cursor < column.len && column.row_ind[cursor] == rid) {
const T rbin = column.index[cursor];
CHECK_LT(rbin + column.index_base,
static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
if (static_cast<int32_t>(rbin + column.index_base) <= split_cond) {
if (cursor < column.Size() && column.GetRowIdx(cursor) == rid) {
const uint32_t rbin = column.GetFeatureBinIdx(cursor);
if (static_cast<int32_t>(rbin + column.GetBaseIdx()) <= split_cond) {
left.push_back(rid);
} else {
right.push_back(rid);
@ -733,7 +708,7 @@ class FastHistMaker: public TreeUpdater {
For dense data (with no missing value),
the sum of gradient histogram is equal to snode[nid] */
GHistRow hist = hist_[nid];
const std::vector<uint32_t>& row_ptr = gmat.cut->row_ptr;
const std::vector<uint32_t>& row_ptr = gmat.cut.row_ptr;
const uint32_t ibegin = row_ptr[fid_least_bins_];
const uint32_t iend = row_ptr[fid_least_bins_ + 1];
@ -771,8 +746,8 @@ class FastHistMaker: public TreeUpdater {
CHECK(d_step == +1 || d_step == -1);
// aliases
const std::vector<uint32_t>& cut_ptr = gmat.cut->row_ptr;
const std::vector<bst_float>& cut_val = gmat.cut->cut;
const std::vector<uint32_t>& cut_ptr = gmat.cut.row_ptr;
const std::vector<bst_float>& cut_val = gmat.cut.cut;
// statistics on both sides of split
GradStats c(param_);
@ -821,7 +796,7 @@ class FastHistMaker: public TreeUpdater {
snode.root_gain);
if (i == imin) {
// for leftmost bin, left bound is the smallest feature value
split_pt = gmat.cut->min_val[fid];
split_pt = gmat.cut.min_val[fid];
} else {
split_pt = cut_val[i - 1];
}

View File

@ -0,0 +1,51 @@
#include "../../../src/common/column_matrix.h"
#include "../helpers.h"
#include "gtest/gtest.h"
namespace xgboost {
namespace common {
TEST(DenseColumn, Test) {
auto dmat = CreateDMatrix(100, 10, 0.0);
GHistIndexMatrix gmat;
gmat.Init(dmat.get(), 256);
ColumnMatrix column_matrix;
column_matrix.Init(gmat, 0.2);
for (auto i = 0ull; i < dmat->Info().num_row_; i++) {
for (auto j = 0ull; j < dmat->Info().num_col_; j++) {
auto col = column_matrix.GetColumn(j);
EXPECT_EQ(gmat.index[i * dmat->Info().num_col_ + j],
col.GetGlobalBinIdx(i));
}
}
}
TEST(SparseColumn, Test) {
auto dmat = CreateDMatrix(100, 1, 0.85);
GHistIndexMatrix gmat;
gmat.Init(dmat.get(), 256);
ColumnMatrix column_matrix;
column_matrix.Init(gmat, 0.5);
auto col = column_matrix.GetColumn(0);
ASSERT_EQ(col.Size(), gmat.index.size());
for (auto i = 0ull; i < col.Size(); i++) {
EXPECT_EQ(gmat.index[gmat.row_ptr[col.GetRowIdx(i)]],
col.GetGlobalBinIdx(i));
}
}
TEST(DenseColumnWithMissing, Test) {
auto dmat = CreateDMatrix(100, 1, 0.5);
GHistIndexMatrix gmat;
gmat.Init(dmat.get(), 256);
ColumnMatrix column_matrix;
column_matrix.Init(gmat, 0.2);
auto col = column_matrix.GetColumn(0);
for (auto i = 0ull; i < col.Size(); i++) {
if (col.IsMissing(i)) continue;
EXPECT_EQ(gmat.index[gmat.row_ptr[col.GetRowIdx(i)]],
col.GetGlobalBinIdx(i));
}
}
} // namespace common
} // namespace xgboost

View File

@ -67,59 +67,4 @@ TEST(MetaInfo, SaveLoadBinary) {
}
TEST(MetaInfo, LoadQid) {
std::string tmp_file = TempFileName();
{
std::unique_ptr<dmlc::Stream> fs(
dmlc::Stream::Create(tmp_file.c_str(), "w"));
dmlc::ostream os(fs.get());
os << R"qid(3 qid:1 1:1 2:1 3:0 4:0.2 5:0
2 qid:1 1:0 2:0 3:1 4:0.1 5:1
1 qid:1 1:0 2:1 3:0 4:0.4 5:0
1 qid:1 1:0 2:0 3:1 4:0.3 5:0
1 qid:2 1:0 2:0 3:1 4:0.2 5:0
2 qid:2 1:1 2:0 3:1 4:0.4 5:0
1 qid:2 1:0 2:0 3:1 4:0.1 5:0
1 qid:2 1:0 2:0 3:1 4:0.2 5:0
2 qid:3 1:0 2:0 3:1 4:0.1 5:1
3 qid:3 1:1 2:1 3:0 4:0.3 5:0
4 qid:3 1:1 2:0 3:0 4:0.4 5:1
1 qid:3 1:0 2:1 3:1 4:0.5 5:0)qid";
os.set_stream(nullptr);
}
std::unique_ptr<xgboost::DMatrix> dmat(
xgboost::DMatrix::Load(tmp_file, true, false, "libsvm"));
std::remove(tmp_file.c_str());
const xgboost::MetaInfo& info = dmat->Info();
const std::vector<uint64_t> expected_qids{1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3};
const std::vector<xgboost::bst_uint> expected_group_ptr{0, 4, 8, 12};
CHECK(info.qids_ == expected_qids);
CHECK(info.group_ptr_ == expected_group_ptr);
CHECK_GE(info.kVersion, info.kVersionQidAdded);
const std::vector<size_t> expected_offset{
0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60
};
const std::vector<xgboost::Entry> expected_data{
{1, 1}, {2, 1}, {3, 0}, {4, 0.2}, {5, 0},
{1, 0}, {2, 0}, {3, 1}, {4, 0.1}, {5, 1},
{1, 0}, {2, 1}, {3, 0}, {4, 0.4}, {5, 0},
{1, 0}, {2, 0}, {3, 1}, {4, 0.3}, {5, 0},
{1, 0}, {2, 0}, {3, 1}, {4, 0.2}, {5, 0},
{1, 1}, {2, 0}, {3, 1}, {4, 0.4}, {5, 0},
{1, 0}, {2, 0}, {3, 1}, {4, 0.1}, {5, 0},
{1, 0}, {2, 0}, {3, 1}, {4, 0.2}, {5, 0},
{1, 0}, {2, 0}, {3, 1}, {4, 0.1}, {5, 1},
{1, 1}, {2, 1}, {3, 0}, {4, 0.3}, {5, 0},
{1, 1}, {2, 0}, {3, 0}, {4, 0.4}, {5, 1},
{1, 0}, {2, 1}, {3, 1}, {4, 0.5}, {5, 0}
};
dmlc::DataIter<xgboost::SparsePage>* iter = dmat->RowIterator();
iter->BeforeFirst();
CHECK(iter->Next());
const xgboost::SparsePage& batch = iter->Value();
CHECK_EQ(batch.base_rowid, 0);
CHECK(batch.offset == expected_offset);
CHECK(batch.data == expected_data);
CHECK(!iter->Next());
}

View File

@ -18,11 +18,8 @@ TEST(gpu_hist_experimental, TestSparseShard) {
int columns = 80;
int max_bins = 4;
auto dmat = CreateDMatrix(rows, columns, 0.9f);
common::HistCutMatrix hmat;
common::GHistIndexMatrix gmat;
hmat.Init(dmat.get(), max_bins);
gmat.cut = &hmat;
gmat.Init(dmat.get());
gmat.Init(dmat.get(),max_bins);
TrainParam p;
p.max_depth = 6;
@ -32,7 +29,7 @@ TEST(gpu_hist_experimental, TestSparseShard) {
const SparsePage& batch = iter->Value();
DeviceShard shard(0, 0, 0, rows, p);
shard.InitRowPtrs(batch);
shard.InitCompressedData(hmat, batch);
shard.InitCompressedData(gmat.cut, batch);
CHECK(!iter->Next());
ASSERT_LT(shard.row_stride, columns);
@ -40,7 +37,7 @@ TEST(gpu_hist_experimental, TestSparseShard) {
auto host_gidx_buffer = shard.gidx_buffer.AsVector();
common::CompressedIterator<uint32_t> gidx(host_gidx_buffer.data(),
hmat.row_ptr.back() + 1);
gmat.cut.row_ptr.back() + 1);
for (int i = 0; i < rows; i++) {
int row_offset = 0;
@ -60,11 +57,8 @@ TEST(gpu_hist_experimental, TestDenseShard) {
int columns = 80;
int max_bins = 4;
auto dmat = CreateDMatrix(rows, columns, 0);
common::HistCutMatrix hmat;
common::GHistIndexMatrix gmat;
hmat.Init(dmat.get(), max_bins);
gmat.cut = &hmat;
gmat.Init(dmat.get());
gmat.Init(dmat.get(),max_bins);
TrainParam p;
p.max_depth = 6;
@ -75,7 +69,7 @@ TEST(gpu_hist_experimental, TestDenseShard) {
DeviceShard shard(0, 0, 0, rows, p);
shard.InitRowPtrs(batch);
shard.InitCompressedData(hmat, batch);
shard.InitCompressedData(gmat.cut, batch);
CHECK(!iter->Next());
ASSERT_EQ(shard.row_stride, columns);
@ -83,7 +77,7 @@ TEST(gpu_hist_experimental, TestDenseShard) {
auto host_gidx_buffer = shard.gidx_buffer.AsVector();
common::CompressedIterator<uint32_t> gidx(host_gidx_buffer.data(),
hmat.row_ptr.back() + 1);
gmat.cut.row_ptr.back() + 1);
for (int i = 0; i < gmat.index.size(); i++) {
ASSERT_EQ(gidx[i], gmat.index[i]);