Small cleanup to Column. (#7898)

* Define forward iterator to hide the internal state.
This commit is contained in:
Jiaming Yuan 2022-05-15 12:39:10 +08:00 committed by GitHub
parent ee382c4153
commit 1baad8650c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 131 additions and 152 deletions

View File

@ -13,6 +13,7 @@
#include <algorithm>
#include <limits>
#include <memory>
#include <utility> // std::move
#include <vector>
#include "../data/gradient_index.h"
@ -32,101 +33,96 @@ enum ColumnType : uint8_t { kDenseColumn, kSparseColumn };
template <typename BinIdxType>
class Column {
public:
static constexpr int32_t kMissingId = -1;
Column(ColumnType type, common::Span<const BinIdxType> index, const bst_bin_t index_base)
: type_(type), index_(index), index_base_{index_base} {}
static constexpr bst_bin_t kMissingId = -1;
Column(common::Span<const BinIdxType> index, bst_bin_t least_bin_idx)
: index_(index), index_base_(least_bin_idx) {}
virtual ~Column() = default;
uint32_t GetGlobalBinIdx(size_t idx) const {
return index_base_ + static_cast<uint32_t>(index_[idx]);
bst_bin_t GetGlobalBinIdx(size_t idx) const {
return index_base_ + static_cast<bst_bin_t>(index_[idx]);
}
BinIdxType GetFeatureBinIdx(size_t idx) const { return index_[idx]; }
uint32_t GetBaseIdx() const { return index_base_; }
common::Span<const BinIdxType> GetFeatureBinIdxPtr() const { return index_; }
ColumnType GetType() const { return type_; }
/* returns number of elements in column */
size_t Size() const { return index_.size(); }
private:
/* type of column */
ColumnType type_;
/* bin indexes in range [0, max_bins - 1] */
common::Span<const BinIdxType> index_;
/* bin index offset for specific feature */
bst_bin_t const index_base_;
};
template <typename BinIdxType>
class SparseColumn : public Column<BinIdxType> {
public:
SparseColumn(ColumnType type, common::Span<const BinIdxType> index, bst_bin_t index_base,
common::Span<const size_t> row_ind)
: Column<BinIdxType>(type, index, index_base), row_ind_(row_ind) {}
const size_t* GetRowData() const { return row_ind_.data(); }
bst_bin_t GetBinIdx(size_t rid, size_t* state) const {
const size_t column_size = this->Size();
if (!((*state) < column_size)) {
return this->kMissingId;
}
while ((*state) < column_size && GetRowIdx(*state) < rid) {
++(*state);
}
if (((*state) < column_size) && GetRowIdx(*state) == rid) {
return this->GetGlobalBinIdx(*state);
} else {
return this->kMissingId;
}
}
size_t GetInitialState(const size_t first_row_id) const {
const size_t* row_data = GetRowData();
const size_t column_size = this->Size();
// search first nonzero row with index >= rid_span.front()
const size_t* p = std::lower_bound(row_data, row_data + column_size, first_row_id);
// column_size if all messing
return p - row_data;
}
size_t GetRowIdx(size_t idx) const { return row_ind_.data()[idx]; }
/*!
 * \brief Forward-only accessor over the stored (non-missing) entries of a sparse column.
 *
 * The object keeps a cursor into the row-index array. operator[] only ever moves
 * that cursor forward, so rows are expected to be queried in ascending order.
 */
template <typename BinIdxT>
class SparseColumnIter : public Column<BinIdxT> {
 private:
  using Base = Column<BinIdxT>;
  /* row index of every stored entry */
  common::Span<const size_t> row_ind_;
  /* cursor: position in row_ind_ of the next candidate entry */
  size_t idx_;

  size_t const* RowIndices() const { return row_ind_.data(); }

 public:
  /*!
   * \param index         Feature-local bin indices of the stored entries.
   * \param least_bin_idx Offset forwarded to the base class as the bin index base.
   * \param row_ind       Row index of each stored entry.
   * \param first_row_idx First row that will be queried; the cursor starts at the
   *                      first stored entry whose row index is not less than it.
   */
  SparseColumnIter(common::Span<const BinIdxT> index, bst_bin_t least_bin_idx,
                   common::Span<const size_t> row_ind, bst_row_t first_row_idx)
      : Base{index, least_bin_idx}, row_ind_(row_ind) {
    size_t const* const begin = RowIndices();
    size_t const* const end = begin + this->Size();
    // Binary search for the starting entry; this relies on the row indices being
    // sorted in ascending order.
    size_t const* const first = std::lower_bound(begin, end, first_row_idx);
    // Equals Size() when every stored entry precedes first_row_idx.
    idx_ = first - begin;
  }

  SparseColumnIter(SparseColumnIter const&) = delete;
  SparseColumnIter(SparseColumnIter&&) = default;

  /*! \brief Row index of the idx-th stored entry. */
  size_t GetRowIdx(size_t idx) const { return RowIndices()[idx]; }

  /*!
   * \brief Global bin index stored for row `rid`, or kMissingId if the row has no entry.
   *
   * Advances the internal cursor past every entry whose row index is below `rid`.
   */
  bst_bin_t operator[](size_t rid) {
    size_t const n_entries = this->Size();
    // Skip entries that belong to rows before the requested one. Nothing happens
    // when the cursor is already at the end of the column.
    for (; idx_ < n_entries && GetRowIdx(idx_) < rid; ++idx_) {
    }
    if (idx_ < n_entries && GetRowIdx(idx_) == rid) {
      // The requested row has a stored entry.
      return this->GetGlobalBinIdx(idx_);
    }
    // Either the column is exhausted or the next stored entry is for a later row.
    return this->kMissingId;
  }
};
template <typename BinIdxType, bool any_missing>
class DenseColumn : public Column<BinIdxType> {
public:
DenseColumn(ColumnType type, common::Span<const BinIdxType> index, uint32_t index_base,
const std::vector<bool>& missing_flags, size_t feature_offset)
: Column<BinIdxType>(type, index, index_base),
missing_flags_(missing_flags),
feature_offset_(feature_offset) {}
bool IsMissing(size_t idx) const { return missing_flags_[feature_offset_ + idx]; }
int32_t GetBinIdx(size_t idx, size_t* state) const {
if (any_missing) {
return IsMissing(idx) ? this->kMissingId : this->GetGlobalBinIdx(idx);
} else {
return this->GetGlobalBinIdx(idx);
}
}
size_t GetInitialState(const size_t first_row_id) const { return 0; }
template <typename BinIdxT, bool any_missing>
class DenseColumnIter : public Column<BinIdxT> {
private:
using Base = Column<BinIdxT>;
/* flags for missing values in dense columns */
const std::vector<bool>& missing_flags_;
std::vector<bool> const& missing_flags_;
size_t feature_offset_;
public:
explicit DenseColumnIter(common::Span<const BinIdxT> index, bst_bin_t index_base,
std::vector<bool> const& missing_flags, size_t feature_offset)
: Base{index, index_base}, missing_flags_{missing_flags}, feature_offset_{feature_offset} {}
DenseColumnIter(DenseColumnIter const&) = delete;
DenseColumnIter(DenseColumnIter&&) = default;
bool IsMissing(size_t ridx) const { return missing_flags_[feature_offset_ + ridx]; }
bst_bin_t operator[](size_t ridx) const {
if (any_missing) {
return IsMissing(ridx) ? this->kMissingId : this->GetGlobalBinIdx(ridx);
} else {
return this->GetGlobalBinIdx(ridx);
}
}
};
/*! \brief a collection of columns, with support for construction from
@ -234,27 +230,26 @@ class ColumnMatrix {
}
}
/* Fetch an individual column. This code should be used with a type switch
to determine the type of the bin ids */
template <typename BinIdxType, bool any_missing>
std::unique_ptr<const Column<BinIdxType> > GetColumn(unsigned fid) const {
CHECK_EQ(sizeof(BinIdxType), bins_type_size_);
const size_t feature_offset = feature_offsets_[fid]; // to get right place for certain feature
const size_t column_size = feature_offsets_[fid + 1] - feature_offset;
template <typename BinIdxType>
auto SparseColumn(bst_feature_t fidx, bst_row_t first_row_idx) const {
const size_t feature_offset = feature_offsets_[fidx]; // to get right place for certain feature
const size_t column_size = feature_offsets_[fidx + 1] - feature_offset;
common::Span<const BinIdxType> bin_index = {
reinterpret_cast<const BinIdxType*>(&index_[feature_offset * bins_type_size_]),
column_size};
std::unique_ptr<const Column<BinIdxType> > res;
if (type_[fid] == ColumnType::kDenseColumn) {
CHECK_EQ(any_missing, any_missing_);
res.reset(new DenseColumn<BinIdxType, any_missing>(type_[fid], bin_index, index_base_[fid],
missing_flags_, feature_offset));
} else {
res.reset(new SparseColumn<BinIdxType>(type_[fid], bin_index, index_base_[fid],
{&row_ind_[feature_offset], column_size}));
return SparseColumnIter<BinIdxType>(bin_index, index_base_[fidx],
{&row_ind_[feature_offset], column_size}, first_row_idx);
}
return res;
/*!
 * \brief Build an accessor over the dense column of feature `fidx`.
 *
 * \tparam BinIdxType  Integer type the bin indices are read as from the byte buffer.
 * \tparam any_missing Forwarded to DenseColumnIter to select whether missing flags
 *                     are consulted on access.
 * \param fidx Feature index.
 * \return A DenseColumnIter viewing this feature's slice of the index buffer.
 */
template <typename BinIdxType, bool any_missing>
auto DenseColumn(bst_feature_t fidx) const {
  const size_t feature_offset = feature_offsets_[fidx];  // to get right place for certain feature
  const size_t column_size = feature_offsets_[fidx + 1] - feature_offset;
  // index_ is addressed in bytes, hence the scaling by the size of one bin index.
  common::Span<const BinIdxType> bin_index = {
      reinterpret_cast<const BinIdxType*>(&index_[feature_offset * bins_type_size_]),
      column_size};
  // Return the prvalue directly: wrapping it in std::move would only inhibit copy
  // elision, and the sparse counterpart already returns its iterator this way.
  return DenseColumnIter<BinIdxType, any_missing>{
      bin_index, static_cast<bst_bin_t>(index_base_[fidx]), missing_flags_, feature_offset};
}
template <typename T>
@ -342,6 +337,7 @@ class ColumnMatrix {
}
BinTypeSize GetTypeSize() const { return bins_type_size_; }
auto GetColumnType(bst_feature_t fidx) const { return type_[fidx]; }
// This is just a utility function
bool NoMissingValues(const size_t n_elements, const size_t n_row, const size_t n_features) {

View File

@ -52,23 +52,23 @@ class PartitionBuilder {
// Handle dense columns
// Analog of std::stable_partition, but not in-place
template <bool default_left, bool any_missing, typename ColumnType, typename Predicate>
inline std::pair<size_t, size_t> PartitionKernel(const ColumnType& column,
inline std::pair<size_t, size_t> PartitionKernel(ColumnType* p_column,
common::Span<const size_t> row_indices,
common::Span<size_t> left_part,
common::Span<size_t> right_part,
size_t base_rowid, Predicate&& pred) {
auto& column = *p_column;
size_t* p_left_part = left_part.data();
size_t* p_right_part = right_part.data();
size_t nleft_elems = 0;
size_t nright_elems = 0;
auto state = column.GetInitialState(row_indices.front() - base_rowid);
auto p_row_indices = row_indices.data();
auto n_samples = row_indices.size();
for (size_t i = 0; i < n_samples; ++i) {
auto rid = p_row_indices[i];
const int32_t bin_id = column.GetBinIdx(rid - base_rowid, &state);
const int32_t bin_id = column[rid - base_rowid];
if (any_missing && bin_id == ColumnType::kMissingId) {
if (default_left) {
p_left_part[nleft_elems++] = rid;
@ -115,8 +115,6 @@ class PartitionBuilder {
common::Span<size_t> right = GetRightBuffer(node_in_set, range.begin(), range.end());
const bst_uint fid = tree[nid].SplitIndex();
const bool default_left = tree[nid].DefaultLeft();
const auto column_ptr = column_matrix.GetColumn<BinIdxType, any_missing>(fid);
bool is_cat = tree.GetSplitTypes()[nid] == FeatureType::kCategorical;
auto node_cats = tree.NodeCats(nid);
@ -146,25 +144,23 @@ class PartitionBuilder {
};
std::pair<size_t, size_t> child_nodes_sizes;
if (column_ptr->GetType() == xgboost::common::kDenseColumn) {
const common::DenseColumn<BinIdxType, any_missing>& column =
static_cast<const common::DenseColumn<BinIdxType, any_missing>& >(*(column_ptr.get()));
if (column_matrix.GetColumnType(fid) == xgboost::common::kDenseColumn) {
auto column = column_matrix.DenseColumn<BinIdxType, any_missing>(fid);
if (default_left) {
child_nodes_sizes = PartitionKernel<true, any_missing>(column, rid_span, left, right,
child_nodes_sizes = PartitionKernel<true, any_missing>(&column, rid_span, left, right,
gmat.base_rowid, pred);
} else {
child_nodes_sizes = PartitionKernel<false, any_missing>(column, rid_span, left, right,
child_nodes_sizes = PartitionKernel<false, any_missing>(&column, rid_span, left, right,
gmat.base_rowid, pred);
}
} else {
CHECK_EQ(any_missing, true);
const common::SparseColumn<BinIdxType>& column
= static_cast<const common::SparseColumn<BinIdxType>& >(*(column_ptr.get()));
auto column = column_matrix.SparseColumn<BinIdxType>(fid, rid_span.front() - gmat.base_rowid);
if (default_left) {
child_nodes_sizes = PartitionKernel<true, any_missing>(column, rid_span, left, right,
child_nodes_sizes = PartitionKernel<true, any_missing>(&column, rid_span, left, right,
gmat.base_rowid, pred);
} else {
child_nodes_sizes = PartitionKernel<false, any_missing>(column, rid_span, left, right,
child_nodes_sizes = PartitionKernel<false, any_missing>(&column, rid_span, left, right,
gmat.base_rowid, pred);
}
}

View File

@ -29,23 +29,17 @@ TEST(DenseColumn, Test) {
for (auto j = 0ull; j < dmat->Info().num_col_; j++) {
switch (column_matrix.GetTypeSize()) {
case kUint8BinsTypeSize: {
auto col = column_matrix.GetColumn<uint8_t, false>(j);
ASSERT_EQ(gmat.index[i * dmat->Info().num_col_ + j],
(*col.get()).GetGlobalBinIdx(i));
}
break;
auto col = column_matrix.DenseColumn<uint8_t, false>(j);
ASSERT_EQ(gmat.index[i * dmat->Info().num_col_ + j], col.GetGlobalBinIdx(i));
} break;
case kUint16BinsTypeSize: {
auto col = column_matrix.GetColumn<uint16_t, false>(j);
ASSERT_EQ(gmat.index[i * dmat->Info().num_col_ + j],
(*col.get()).GetGlobalBinIdx(i));
}
break;
auto col = column_matrix.DenseColumn<uint16_t, false>(j);
ASSERT_EQ(gmat.index[i * dmat->Info().num_col_ + j], col.GetGlobalBinIdx(i));
} break;
case kUint32BinsTypeSize: {
auto col = column_matrix.GetColumn<uint32_t, false>(j);
ASSERT_EQ(gmat.index[i * dmat->Info().num_col_ + j],
(*col.get()).GetGlobalBinIdx(i));
}
break;
auto col = column_matrix.DenseColumn<uint32_t, false>(j);
ASSERT_EQ(gmat.index[i * dmat->Info().num_col_ + j], col.GetGlobalBinIdx(i));
} break;
}
}
}
@ -53,12 +47,13 @@ TEST(DenseColumn, Test) {
}
template <typename BinIdxType>
inline void CheckSparseColumn(const Column<BinIdxType>& col_input, const GHistIndexMatrix& gmat) {
const SparseColumn<BinIdxType>& col = static_cast<const SparseColumn<BinIdxType>& >(col_input);
inline void CheckSparseColumn(const SparseColumnIter<BinIdxType>& col_input,
const GHistIndexMatrix& gmat) {
const SparseColumnIter<BinIdxType>& col =
static_cast<const SparseColumnIter<BinIdxType>&>(col_input);
ASSERT_EQ(col.Size(), gmat.index.Size());
for (auto i = 0ull; i < col.Size(); i++) {
ASSERT_EQ(gmat.index[gmat.row_ptr[col.GetRowIdx(i)]],
col.GetGlobalBinIdx(i));
ASSERT_EQ(gmat.index[gmat.row_ptr[col.GetRowIdx(i)]], col.GetGlobalBinIdx(i));
}
}
@ -75,32 +70,27 @@ TEST(SparseColumn, Test) {
}
switch (column_matrix.GetTypeSize()) {
case kUint8BinsTypeSize: {
auto col = column_matrix.GetColumn<uint8_t, true>(0);
CheckSparseColumn(*col.get(), gmat);
}
break;
auto col = column_matrix.SparseColumn<uint8_t>(0, 0);
CheckSparseColumn(col, gmat);
} break;
case kUint16BinsTypeSize: {
auto col = column_matrix.GetColumn<uint16_t, true>(0);
CheckSparseColumn(*col.get(), gmat);
}
break;
auto col = column_matrix.SparseColumn<uint16_t>(0, 0);
CheckSparseColumn(col, gmat);
} break;
case kUint32BinsTypeSize: {
auto col = column_matrix.GetColumn<uint32_t, true>(0);
CheckSparseColumn(*col.get(), gmat);
}
break;
auto col = column_matrix.SparseColumn<uint32_t>(0, 0);
CheckSparseColumn(col, gmat);
} break;
}
}
}
template <typename BinIdxType>
inline void CheckColumWithMissingValue(const Column<BinIdxType>& col_input,
inline void CheckColumWithMissingValue(const DenseColumnIter<BinIdxType, true>& col,
const GHistIndexMatrix& gmat) {
const DenseColumn<BinIdxType, true>& col = static_cast<const DenseColumn<BinIdxType, true>& >(col_input);
for (auto i = 0ull; i < col.Size(); i++) {
if (col.IsMissing(i)) continue;
EXPECT_EQ(gmat.index[gmat.row_ptr[i]],
col.GetGlobalBinIdx(i));
EXPECT_EQ(gmat.index[gmat.row_ptr[i]], col.GetGlobalBinIdx(i));
}
}
@ -117,20 +107,17 @@ TEST(DenseColumnWithMissing, Test) {
}
switch (column_matrix.GetTypeSize()) {
case kUint8BinsTypeSize: {
auto col = column_matrix.GetColumn<uint8_t, true>(0);
CheckColumWithMissingValue(*col.get(), gmat);
}
break;
auto col = column_matrix.DenseColumn<uint8_t, true>(0);
CheckColumWithMissingValue(col, gmat);
} break;
case kUint16BinsTypeSize: {
auto col = column_matrix.GetColumn<uint16_t, true>(0);
CheckColumWithMissingValue(*col.get(), gmat);
}
break;
auto col = column_matrix.DenseColumn<uint16_t, true>(0);
CheckColumWithMissingValue(col, gmat);
} break;
case kUint32BinsTypeSize: {
auto col = column_matrix.GetColumn<uint32_t, true>(0);
CheckColumWithMissingValue(*col.get(), gmat);
}
break;
auto col = column_matrix.DenseColumn<uint32_t, true>(0);
CheckColumWithMissingValue(col, gmat);
} break;
}
}
}