Simplify sparse and dense CPU hist kernels (#7029)
* Simplify sparse and dense kernels * Extract row partitioner. Co-authored-by: Kirill Shvets <kirill.shvets@intel.com>
This commit is contained in:
@@ -30,6 +30,8 @@ enum ColumnType {
|
||||
template <typename BinIdxType>
|
||||
class Column {
|
||||
public:
|
||||
static constexpr int32_t kMissingId = -1;
|
||||
|
||||
Column(ColumnType type, common::Span<const BinIdxType> index, const uint32_t index_base)
|
||||
: type_(type),
|
||||
index_(index),
|
||||
@@ -71,6 +73,30 @@ class SparseColumn: public Column<BinIdxType> {
|
||||
|
||||
const size_t* GetRowData() const { return row_ind_.data(); }
|
||||
|
||||
int32_t GetBinIdx(size_t rid, size_t* state) const {
|
||||
const size_t column_size = this->Size();
|
||||
if (!((*state) < column_size)) {
|
||||
return this->kMissingId;
|
||||
}
|
||||
while ((*state) < column_size && GetRowIdx(*state) < rid) {
|
||||
++(*state);
|
||||
}
|
||||
if (((*state) < column_size) && GetRowIdx(*state) == rid) {
|
||||
return this->GetGlobalBinIdx(*state);
|
||||
} else {
|
||||
return this->kMissingId;
|
||||
}
|
||||
}
|
||||
|
||||
size_t GetInitialState(const size_t first_row_id) const {
|
||||
const size_t* row_data = GetRowData();
|
||||
const size_t column_size = this->Size();
|
||||
// search first nonzero row with index >= rid_span.front()
|
||||
const size_t* p = std::lower_bound(row_data, row_data + column_size, first_row_id);
|
||||
// column_size if all messing
|
||||
return p - row_data;
|
||||
}
|
||||
|
||||
size_t GetRowIdx(size_t idx) const {
|
||||
return row_ind_.data()[idx];
|
||||
}
|
||||
@@ -80,7 +106,7 @@ class SparseColumn: public Column<BinIdxType> {
|
||||
common::Span<const size_t> row_ind_;
|
||||
};
|
||||
|
||||
template <typename BinIdxType>
|
||||
template <typename BinIdxType, bool any_missing>
|
||||
class DenseColumn: public Column<BinIdxType> {
|
||||
public:
|
||||
DenseColumn(ColumnType type, common::Span<const BinIdxType> index,
|
||||
@@ -90,6 +116,19 @@ class DenseColumn: public Column<BinIdxType> {
|
||||
missing_flags_(missing_flags),
|
||||
feature_offset_(feature_offset) {}
|
||||
bool IsMissing(size_t idx) const { return missing_flags_[feature_offset_ + idx]; }
|
||||
|
||||
int32_t GetBinIdx(size_t idx, size_t* state) const {
|
||||
if (any_missing) {
|
||||
return IsMissing(idx) ? this->kMissingId : this->GetGlobalBinIdx(idx);
|
||||
} else {
|
||||
return this->GetGlobalBinIdx(idx);
|
||||
}
|
||||
}
|
||||
|
||||
size_t GetInitialState(const size_t first_row_id) const {
|
||||
return 0;
|
||||
}
|
||||
|
||||
private:
|
||||
/* flags for missing values in dense columns */
|
||||
const std::vector<bool>& missing_flags_;
|
||||
@@ -202,7 +241,7 @@ class ColumnMatrix {
|
||||
|
||||
/* Fetch an individual column. This code should be used with type swith
|
||||
to determine type of bin id's */
|
||||
template <typename BinIdxType>
|
||||
template <typename BinIdxType, bool any_missing>
|
||||
std::unique_ptr<const Column<BinIdxType> > GetColumn(unsigned fid) const {
|
||||
CHECK_EQ(sizeof(BinIdxType), bins_type_size_);
|
||||
|
||||
@@ -213,7 +252,8 @@ class ColumnMatrix {
|
||||
column_size };
|
||||
std::unique_ptr<const Column<BinIdxType> > res;
|
||||
if (type_[fid] == ColumnType::kDenseColumn) {
|
||||
res.reset(new DenseColumn<BinIdxType>(type_[fid], bin_index, index_base_[fid],
|
||||
CHECK_EQ(any_missing, any_missing_);
|
||||
res.reset(new DenseColumn<BinIdxType, any_missing>(type_[fid], bin_index, index_base_[fid],
|
||||
missing_flags_, feature_offset));
|
||||
} else {
|
||||
res.reset(new SparseColumn<BinIdxType>(type_[fid], bin_index, index_base_[fid],
|
||||
|
||||
Reference in New Issue
Block a user