Simplify sparse and dense CPU hist kernels (#7029)

* Simplify sparse and dense kernels
* Extract row partitioner.

Co-authored-by: Kirill Shvets <kirill.shvets@intel.com>
This commit is contained in:
ShvetsKS
2021-06-11 13:26:30 +03:00
committed by GitHub
parent 1faad825f4
commit 2567404ab6
10 changed files with 369 additions and 434 deletions

View File

@@ -30,6 +30,8 @@ enum ColumnType {
template <typename BinIdxType>
class Column {
public:
static constexpr int32_t kMissingId = -1;
Column(ColumnType type, common::Span<const BinIdxType> index, const uint32_t index_base)
: type_(type),
index_(index),
@@ -71,6 +73,30 @@ class SparseColumn: public Column<BinIdxType> {
const size_t* GetRowData() const { return row_ind_.data(); }
// Looks up the bin stored for row `rid` in this sparse column.
//
// `state` points at a cursor into the column's entries. The cursor is moved
// forward past every entry whose row index is smaller than `rid` and is left
// there for the next call. Returns the global bin index of the entry under the
// cursor when its row index equals `rid`; otherwise (including when the cursor
// has reached the end of the column) returns kMissingId.
int32_t GetBinIdx(size_t rid, size_t* state) const {
  const size_t n_entries = this->Size();
  size_t& cursor = *state;
  // Skip entries that belong to rows located before the requested one.
  while (cursor < n_entries && GetRowIdx(cursor) < rid) {
    ++cursor;
  }
  // Either the column is exhausted or the next stored row is past `rid`.
  if (cursor >= n_entries || GetRowIdx(cursor) != rid) {
    return this->kMissingId;
  }
  return this->GetGlobalBinIdx(cursor);
}
// Computes the starting cursor for a scan that begins at row `first_row_id`:
// the position of the first stored entry whose row index is not less than
// `first_row_id`. Returns the column size when no such entry exists.
// NOTE(review): std::lower_bound presumes the row indices are stored in
// ascending order - confirm against the code that fills the column.
size_t GetInitialState(const size_t first_row_id) const {
  const size_t* const rows_begin = GetRowData();
  const size_t* const rows_end = rows_begin + this->Size();
  const size_t* const first_hit = std::lower_bound(rows_begin, rows_end, first_row_id);
  return static_cast<size_t>(first_hit - rows_begin);
}
// Returns the row index recorded for the idx-th stored entry of this column.
// Reads through the raw pointer of row_ind_; no bounds check is done here.
size_t GetRowIdx(size_t idx) const {
  const size_t* const rows = row_ind_.data();
  return rows[idx];
}
@@ -80,7 +106,7 @@ class SparseColumn: public Column<BinIdxType> {
common::Span<const size_t> row_ind_;
};
template <typename BinIdxType>
template <typename BinIdxType, bool any_missing>
class DenseColumn: public Column<BinIdxType> {
public:
DenseColumn(ColumnType type, common::Span<const BinIdxType> index,
@@ -90,6 +116,19 @@ class DenseColumn: public Column<BinIdxType> {
missing_flags_(missing_flags),
feature_offset_(feature_offset) {}
// Tells whether the value at position `idx` of this column is flagged as
// missing. The flag is read from missing_flags_ at feature_offset_ + idx.
bool IsMissing(size_t idx) const {
  const auto flag_pos = feature_offset_ + idx;
  return missing_flags_[flag_pos];
}
// Returns the global bin index for position `idx`, or kMissingId when the
// class was instantiated with any_missing == true and the value is flagged as
// missing. With any_missing == false the missing flags are never consulted.
// `state` is not used by this function.
int32_t GetBinIdx(size_t idx, size_t* state) const {
  if (any_missing && IsMissing(idx)) {
    return this->kMissingId;
  }
  return this->GetGlobalBinIdx(idx);
}
// Returns the initial traversal state for a dense column. The result is
// always 0 and does not depend on `first_row_id`; GetBinIdx of this class
// does not read its `state` argument.
size_t GetInitialState(const size_t first_row_id) const {
return 0;
}
private:
/* flags for missing values in dense columns */
const std::vector<bool>& missing_flags_;
@@ -202,7 +241,7 @@ class ColumnMatrix {
/* Fetch an individual column. This code should be used with a type switch
to determine the type of the bin ids */
template <typename BinIdxType>
template <typename BinIdxType, bool any_missing>
std::unique_ptr<const Column<BinIdxType> > GetColumn(unsigned fid) const {
CHECK_EQ(sizeof(BinIdxType), bins_type_size_);
@@ -213,7 +252,8 @@ class ColumnMatrix {
column_size };
std::unique_ptr<const Column<BinIdxType> > res;
if (type_[fid] == ColumnType::kDenseColumn) {
res.reset(new DenseColumn<BinIdxType>(type_[fid], bin_index, index_base_[fid],
CHECK_EQ(any_missing, any_missing_);
res.reset(new DenseColumn<BinIdxType, any_missing>(type_[fid], bin_index, index_base_[fid],
missing_flags_, feature_offset));
} else {
res.reset(new SparseColumn<BinIdxType>(type_[fid], bin_index, index_base_[fid],