Small cleanup to Column. (#7898)
* Define forward iterator to hide the internal state.
This commit is contained in:
parent
ee382c4153
commit
1baad8650c
@ -13,6 +13,7 @@
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <utility> // std::move
|
||||
#include <vector>
|
||||
|
||||
#include "../data/gradient_index.h"
|
||||
@ -32,101 +33,96 @@ enum ColumnType : uint8_t { kDenseColumn, kSparseColumn };
|
||||
template <typename BinIdxType>
|
||||
class Column {
|
||||
public:
|
||||
static constexpr int32_t kMissingId = -1;
|
||||
|
||||
Column(ColumnType type, common::Span<const BinIdxType> index, const bst_bin_t index_base)
|
||||
: type_(type), index_(index), index_base_{index_base} {}
|
||||
static constexpr bst_bin_t kMissingId = -1;
|
||||
|
||||
Column(common::Span<const BinIdxType> index, bst_bin_t least_bin_idx)
|
||||
: index_(index), index_base_(least_bin_idx) {}
|
||||
virtual ~Column() = default;
|
||||
|
||||
uint32_t GetGlobalBinIdx(size_t idx) const {
|
||||
return index_base_ + static_cast<uint32_t>(index_[idx]);
|
||||
bst_bin_t GetGlobalBinIdx(size_t idx) const {
|
||||
return index_base_ + static_cast<bst_bin_t>(index_[idx]);
|
||||
}
|
||||
|
||||
BinIdxType GetFeatureBinIdx(size_t idx) const { return index_[idx]; }
|
||||
|
||||
uint32_t GetBaseIdx() const { return index_base_; }
|
||||
|
||||
common::Span<const BinIdxType> GetFeatureBinIdxPtr() const { return index_; }
|
||||
|
||||
ColumnType GetType() const { return type_; }
|
||||
|
||||
/* returns number of elements in column */
|
||||
size_t Size() const { return index_.size(); }
|
||||
|
||||
private:
|
||||
/* type of column */
|
||||
ColumnType type_;
|
||||
/* bin indexes in range [0, max_bins - 1] */
|
||||
common::Span<const BinIdxType> index_;
|
||||
/* bin index offset for specific feature */
|
||||
bst_bin_t const index_base_;
|
||||
};
|
||||
|
||||
template <typename BinIdxType>
|
||||
class SparseColumn : public Column<BinIdxType> {
|
||||
public:
|
||||
SparseColumn(ColumnType type, common::Span<const BinIdxType> index, bst_bin_t index_base,
|
||||
common::Span<const size_t> row_ind)
|
||||
: Column<BinIdxType>(type, index, index_base), row_ind_(row_ind) {}
|
||||
|
||||
const size_t* GetRowData() const { return row_ind_.data(); }
|
||||
|
||||
bst_bin_t GetBinIdx(size_t rid, size_t* state) const {
|
||||
const size_t column_size = this->Size();
|
||||
if (!((*state) < column_size)) {
|
||||
return this->kMissingId;
|
||||
}
|
||||
while ((*state) < column_size && GetRowIdx(*state) < rid) {
|
||||
++(*state);
|
||||
}
|
||||
if (((*state) < column_size) && GetRowIdx(*state) == rid) {
|
||||
return this->GetGlobalBinIdx(*state);
|
||||
} else {
|
||||
return this->kMissingId;
|
||||
}
|
||||
}
|
||||
|
||||
size_t GetInitialState(const size_t first_row_id) const {
|
||||
const size_t* row_data = GetRowData();
|
||||
const size_t column_size = this->Size();
|
||||
// search first nonzero row with index >= rid_span.front()
|
||||
const size_t* p = std::lower_bound(row_data, row_data + column_size, first_row_id);
|
||||
// column_size if all messing
|
||||
return p - row_data;
|
||||
}
|
||||
|
||||
size_t GetRowIdx(size_t idx) const { return row_ind_.data()[idx]; }
|
||||
|
||||
template <typename BinIdxT>
|
||||
class SparseColumnIter : public Column<BinIdxT> {
|
||||
private:
|
||||
using Base = Column<BinIdxT>;
|
||||
/* indexes of rows */
|
||||
common::Span<const size_t> row_ind_;
|
||||
size_t idx_;
|
||||
|
||||
size_t const* RowIndices() const { return row_ind_.data(); }
|
||||
|
||||
public:
|
||||
SparseColumnIter(common::Span<const BinIdxT> index, bst_bin_t least_bin_idx,
|
||||
common::Span<const size_t> row_ind, bst_row_t first_row_idx)
|
||||
: Base{index, least_bin_idx}, row_ind_(row_ind) {
|
||||
// first_row_id is the first row in the leaf partition
|
||||
const size_t* row_data = RowIndices();
|
||||
const size_t column_size = this->Size();
|
||||
// search first nonzero row with index >= rid_span.front()
|
||||
// note that the input row partition is always sorted.
|
||||
const size_t* p = std::lower_bound(row_data, row_data + column_size, first_row_idx);
|
||||
// column_size if all missing
|
||||
idx_ = p - row_data;
|
||||
}
|
||||
SparseColumnIter(SparseColumnIter const&) = delete;
|
||||
SparseColumnIter(SparseColumnIter&&) = default;
|
||||
|
||||
size_t GetRowIdx(size_t idx) const { return RowIndices()[idx]; }
|
||||
bst_bin_t operator[](size_t rid) {
|
||||
const size_t column_size = this->Size();
|
||||
if (!((idx_) < column_size)) {
|
||||
return this->kMissingId;
|
||||
}
|
||||
// find next non-missing row
|
||||
while ((idx_) < column_size && GetRowIdx(idx_) < rid) {
|
||||
++(idx_);
|
||||
}
|
||||
if (((idx_) < column_size) && GetRowIdx(idx_) == rid) {
|
||||
// non-missing row found
|
||||
return this->GetGlobalBinIdx(idx_);
|
||||
} else {
|
||||
// at the end of column
|
||||
return this->kMissingId;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename BinIdxType, bool any_missing>
|
||||
class DenseColumn : public Column<BinIdxType> {
|
||||
public:
|
||||
DenseColumn(ColumnType type, common::Span<const BinIdxType> index, uint32_t index_base,
|
||||
const std::vector<bool>& missing_flags, size_t feature_offset)
|
||||
: Column<BinIdxType>(type, index, index_base),
|
||||
missing_flags_(missing_flags),
|
||||
feature_offset_(feature_offset) {}
|
||||
bool IsMissing(size_t idx) const { return missing_flags_[feature_offset_ + idx]; }
|
||||
|
||||
int32_t GetBinIdx(size_t idx, size_t* state) const {
|
||||
if (any_missing) {
|
||||
return IsMissing(idx) ? this->kMissingId : this->GetGlobalBinIdx(idx);
|
||||
} else {
|
||||
return this->GetGlobalBinIdx(idx);
|
||||
}
|
||||
}
|
||||
|
||||
size_t GetInitialState(const size_t first_row_id) const { return 0; }
|
||||
|
||||
template <typename BinIdxT, bool any_missing>
|
||||
class DenseColumnIter : public Column<BinIdxT> {
|
||||
private:
|
||||
using Base = Column<BinIdxT>;
|
||||
/* flags for missing values in dense columns */
|
||||
const std::vector<bool>& missing_flags_;
|
||||
std::vector<bool> const& missing_flags_;
|
||||
size_t feature_offset_;
|
||||
|
||||
public:
|
||||
explicit DenseColumnIter(common::Span<const BinIdxT> index, bst_bin_t index_base,
|
||||
std::vector<bool> const& missing_flags, size_t feature_offset)
|
||||
: Base{index, index_base}, missing_flags_{missing_flags}, feature_offset_{feature_offset} {}
|
||||
DenseColumnIter(DenseColumnIter const&) = delete;
|
||||
DenseColumnIter(DenseColumnIter&&) = default;
|
||||
|
||||
bool IsMissing(size_t ridx) const { return missing_flags_[feature_offset_ + ridx]; }
|
||||
|
||||
bst_bin_t operator[](size_t ridx) const {
|
||||
if (any_missing) {
|
||||
return IsMissing(ridx) ? this->kMissingId : this->GetGlobalBinIdx(ridx);
|
||||
} else {
|
||||
return this->GetGlobalBinIdx(ridx);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief a collection of columns, with support for construction from
|
||||
@ -234,27 +230,26 @@ class ColumnMatrix {
|
||||
}
|
||||
}
|
||||
|
||||
/* Fetch an individual column. This code should be used with type swith
|
||||
to determine type of bin id's */
|
||||
template <typename BinIdxType, bool any_missing>
|
||||
std::unique_ptr<const Column<BinIdxType> > GetColumn(unsigned fid) const {
|
||||
CHECK_EQ(sizeof(BinIdxType), bins_type_size_);
|
||||
|
||||
const size_t feature_offset = feature_offsets_[fid]; // to get right place for certain feature
|
||||
const size_t column_size = feature_offsets_[fid + 1] - feature_offset;
|
||||
template <typename BinIdxType>
|
||||
auto SparseColumn(bst_feature_t fidx, bst_row_t first_row_idx) const {
|
||||
const size_t feature_offset = feature_offsets_[fidx]; // to get right place for certain feature
|
||||
const size_t column_size = feature_offsets_[fidx + 1] - feature_offset;
|
||||
common::Span<const BinIdxType> bin_index = {
|
||||
reinterpret_cast<const BinIdxType*>(&index_[feature_offset * bins_type_size_]),
|
||||
column_size};
|
||||
std::unique_ptr<const Column<BinIdxType> > res;
|
||||
if (type_[fid] == ColumnType::kDenseColumn) {
|
||||
CHECK_EQ(any_missing, any_missing_);
|
||||
res.reset(new DenseColumn<BinIdxType, any_missing>(type_[fid], bin_index, index_base_[fid],
|
||||
missing_flags_, feature_offset));
|
||||
} else {
|
||||
res.reset(new SparseColumn<BinIdxType>(type_[fid], bin_index, index_base_[fid],
|
||||
{&row_ind_[feature_offset], column_size}));
|
||||
return SparseColumnIter<BinIdxType>(bin_index, index_base_[fidx],
|
||||
{&row_ind_[feature_offset], column_size}, first_row_idx);
|
||||
}
|
||||
return res;
|
||||
|
||||
template <typename BinIdxType, bool any_missing>
|
||||
auto DenseColumn(bst_feature_t fidx) const {
|
||||
const size_t feature_offset = feature_offsets_[fidx]; // to get right place for certain feature
|
||||
const size_t column_size = feature_offsets_[fidx + 1] - feature_offset;
|
||||
common::Span<const BinIdxType> bin_index = {
|
||||
reinterpret_cast<const BinIdxType*>(&index_[feature_offset * bins_type_size_]),
|
||||
column_size};
|
||||
return std::move(DenseColumnIter<BinIdxType, any_missing>{
|
||||
bin_index, static_cast<bst_bin_t>(index_base_[fidx]), missing_flags_, feature_offset});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@ -342,6 +337,7 @@ class ColumnMatrix {
|
||||
}
|
||||
|
||||
BinTypeSize GetTypeSize() const { return bins_type_size_; }
|
||||
auto GetColumnType(bst_feature_t fidx) const { return type_[fidx]; }
|
||||
|
||||
// This is just an utility function
|
||||
bool NoMissingValues(const size_t n_elements, const size_t n_row, const size_t n_features) {
|
||||
|
||||
@ -52,23 +52,23 @@ class PartitionBuilder {
|
||||
// Handle dense columns
|
||||
// Analog of std::stable_partition, but in no-inplace manner
|
||||
template <bool default_left, bool any_missing, typename ColumnType, typename Predicate>
|
||||
inline std::pair<size_t, size_t> PartitionKernel(const ColumnType& column,
|
||||
inline std::pair<size_t, size_t> PartitionKernel(ColumnType* p_column,
|
||||
common::Span<const size_t> row_indices,
|
||||
common::Span<size_t> left_part,
|
||||
common::Span<size_t> right_part,
|
||||
size_t base_rowid, Predicate&& pred) {
|
||||
auto& column = *p_column;
|
||||
size_t* p_left_part = left_part.data();
|
||||
size_t* p_right_part = right_part.data();
|
||||
size_t nleft_elems = 0;
|
||||
size_t nright_elems = 0;
|
||||
auto state = column.GetInitialState(row_indices.front() - base_rowid);
|
||||
|
||||
auto p_row_indices = row_indices.data();
|
||||
auto n_samples = row_indices.size();
|
||||
|
||||
for (size_t i = 0; i < n_samples; ++i) {
|
||||
auto rid = p_row_indices[i];
|
||||
const int32_t bin_id = column.GetBinIdx(rid - base_rowid, &state);
|
||||
const int32_t bin_id = column[rid - base_rowid];
|
||||
if (any_missing && bin_id == ColumnType::kMissingId) {
|
||||
if (default_left) {
|
||||
p_left_part[nleft_elems++] = rid;
|
||||
@ -115,8 +115,6 @@ class PartitionBuilder {
|
||||
common::Span<size_t> right = GetRightBuffer(node_in_set, range.begin(), range.end());
|
||||
const bst_uint fid = tree[nid].SplitIndex();
|
||||
const bool default_left = tree[nid].DefaultLeft();
|
||||
const auto column_ptr = column_matrix.GetColumn<BinIdxType, any_missing>(fid);
|
||||
|
||||
bool is_cat = tree.GetSplitTypes()[nid] == FeatureType::kCategorical;
|
||||
auto node_cats = tree.NodeCats(nid);
|
||||
|
||||
@ -146,25 +144,23 @@ class PartitionBuilder {
|
||||
};
|
||||
|
||||
std::pair<size_t, size_t> child_nodes_sizes;
|
||||
if (column_ptr->GetType() == xgboost::common::kDenseColumn) {
|
||||
const common::DenseColumn<BinIdxType, any_missing>& column =
|
||||
static_cast<const common::DenseColumn<BinIdxType, any_missing>& >(*(column_ptr.get()));
|
||||
if (column_matrix.GetColumnType(fid) == xgboost::common::kDenseColumn) {
|
||||
auto column = column_matrix.DenseColumn<BinIdxType, any_missing>(fid);
|
||||
if (default_left) {
|
||||
child_nodes_sizes = PartitionKernel<true, any_missing>(column, rid_span, left, right,
|
||||
child_nodes_sizes = PartitionKernel<true, any_missing>(&column, rid_span, left, right,
|
||||
gmat.base_rowid, pred);
|
||||
} else {
|
||||
child_nodes_sizes = PartitionKernel<false, any_missing>(column, rid_span, left, right,
|
||||
child_nodes_sizes = PartitionKernel<false, any_missing>(&column, rid_span, left, right,
|
||||
gmat.base_rowid, pred);
|
||||
}
|
||||
} else {
|
||||
CHECK_EQ(any_missing, true);
|
||||
const common::SparseColumn<BinIdxType>& column
|
||||
= static_cast<const common::SparseColumn<BinIdxType>& >(*(column_ptr.get()));
|
||||
auto column = column_matrix.SparseColumn<BinIdxType>(fid, rid_span.front() - gmat.base_rowid);
|
||||
if (default_left) {
|
||||
child_nodes_sizes = PartitionKernel<true, any_missing>(column, rid_span, left, right,
|
||||
child_nodes_sizes = PartitionKernel<true, any_missing>(&column, rid_span, left, right,
|
||||
gmat.base_rowid, pred);
|
||||
} else {
|
||||
child_nodes_sizes = PartitionKernel<false, any_missing>(column, rid_span, left, right,
|
||||
child_nodes_sizes = PartitionKernel<false, any_missing>(&column, rid_span, left, right,
|
||||
gmat.base_rowid, pred);
|
||||
}
|
||||
}
|
||||
|
||||
@ -29,23 +29,17 @@ TEST(DenseColumn, Test) {
|
||||
for (auto j = 0ull; j < dmat->Info().num_col_; j++) {
|
||||
switch (column_matrix.GetTypeSize()) {
|
||||
case kUint8BinsTypeSize: {
|
||||
auto col = column_matrix.GetColumn<uint8_t, false>(j);
|
||||
ASSERT_EQ(gmat.index[i * dmat->Info().num_col_ + j],
|
||||
(*col.get()).GetGlobalBinIdx(i));
|
||||
}
|
||||
break;
|
||||
auto col = column_matrix.DenseColumn<uint8_t, false>(j);
|
||||
ASSERT_EQ(gmat.index[i * dmat->Info().num_col_ + j], col.GetGlobalBinIdx(i));
|
||||
} break;
|
||||
case kUint16BinsTypeSize: {
|
||||
auto col = column_matrix.GetColumn<uint16_t, false>(j);
|
||||
ASSERT_EQ(gmat.index[i * dmat->Info().num_col_ + j],
|
||||
(*col.get()).GetGlobalBinIdx(i));
|
||||
}
|
||||
break;
|
||||
auto col = column_matrix.DenseColumn<uint16_t, false>(j);
|
||||
ASSERT_EQ(gmat.index[i * dmat->Info().num_col_ + j], col.GetGlobalBinIdx(i));
|
||||
} break;
|
||||
case kUint32BinsTypeSize: {
|
||||
auto col = column_matrix.GetColumn<uint32_t, false>(j);
|
||||
ASSERT_EQ(gmat.index[i * dmat->Info().num_col_ + j],
|
||||
(*col.get()).GetGlobalBinIdx(i));
|
||||
}
|
||||
break;
|
||||
auto col = column_matrix.DenseColumn<uint32_t, false>(j);
|
||||
ASSERT_EQ(gmat.index[i * dmat->Info().num_col_ + j], col.GetGlobalBinIdx(i));
|
||||
} break;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -53,12 +47,13 @@ TEST(DenseColumn, Test) {
|
||||
}
|
||||
|
||||
template <typename BinIdxType>
|
||||
inline void CheckSparseColumn(const Column<BinIdxType>& col_input, const GHistIndexMatrix& gmat) {
|
||||
const SparseColumn<BinIdxType>& col = static_cast<const SparseColumn<BinIdxType>& >(col_input);
|
||||
inline void CheckSparseColumn(const SparseColumnIter<BinIdxType>& col_input,
|
||||
const GHistIndexMatrix& gmat) {
|
||||
const SparseColumnIter<BinIdxType>& col =
|
||||
static_cast<const SparseColumnIter<BinIdxType>&>(col_input);
|
||||
ASSERT_EQ(col.Size(), gmat.index.Size());
|
||||
for (auto i = 0ull; i < col.Size(); i++) {
|
||||
ASSERT_EQ(gmat.index[gmat.row_ptr[col.GetRowIdx(i)]],
|
||||
col.GetGlobalBinIdx(i));
|
||||
ASSERT_EQ(gmat.index[gmat.row_ptr[col.GetRowIdx(i)]], col.GetGlobalBinIdx(i));
|
||||
}
|
||||
}
|
||||
|
||||
@ -75,32 +70,27 @@ TEST(SparseColumn, Test) {
|
||||
}
|
||||
switch (column_matrix.GetTypeSize()) {
|
||||
case kUint8BinsTypeSize: {
|
||||
auto col = column_matrix.GetColumn<uint8_t, true>(0);
|
||||
CheckSparseColumn(*col.get(), gmat);
|
||||
}
|
||||
break;
|
||||
auto col = column_matrix.SparseColumn<uint8_t>(0, 0);
|
||||
CheckSparseColumn(col, gmat);
|
||||
} break;
|
||||
case kUint16BinsTypeSize: {
|
||||
auto col = column_matrix.GetColumn<uint16_t, true>(0);
|
||||
CheckSparseColumn(*col.get(), gmat);
|
||||
}
|
||||
break;
|
||||
auto col = column_matrix.SparseColumn<uint16_t>(0, 0);
|
||||
CheckSparseColumn(col, gmat);
|
||||
} break;
|
||||
case kUint32BinsTypeSize: {
|
||||
auto col = column_matrix.GetColumn<uint32_t, true>(0);
|
||||
CheckSparseColumn(*col.get(), gmat);
|
||||
}
|
||||
break;
|
||||
auto col = column_matrix.SparseColumn<uint32_t>(0, 0);
|
||||
CheckSparseColumn(col, gmat);
|
||||
} break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename BinIdxType>
|
||||
inline void CheckColumWithMissingValue(const Column<BinIdxType>& col_input,
|
||||
inline void CheckColumWithMissingValue(const DenseColumnIter<BinIdxType, true>& col,
|
||||
const GHistIndexMatrix& gmat) {
|
||||
const DenseColumn<BinIdxType, true>& col = static_cast<const DenseColumn<BinIdxType, true>& >(col_input);
|
||||
for (auto i = 0ull; i < col.Size(); i++) {
|
||||
if (col.IsMissing(i)) continue;
|
||||
EXPECT_EQ(gmat.index[gmat.row_ptr[i]],
|
||||
col.GetGlobalBinIdx(i));
|
||||
EXPECT_EQ(gmat.index[gmat.row_ptr[i]], col.GetGlobalBinIdx(i));
|
||||
}
|
||||
}
|
||||
|
||||
@ -117,20 +107,17 @@ TEST(DenseColumnWithMissing, Test) {
|
||||
}
|
||||
switch (column_matrix.GetTypeSize()) {
|
||||
case kUint8BinsTypeSize: {
|
||||
auto col = column_matrix.GetColumn<uint8_t, true>(0);
|
||||
CheckColumWithMissingValue(*col.get(), gmat);
|
||||
}
|
||||
break;
|
||||
auto col = column_matrix.DenseColumn<uint8_t, true>(0);
|
||||
CheckColumWithMissingValue(col, gmat);
|
||||
} break;
|
||||
case kUint16BinsTypeSize: {
|
||||
auto col = column_matrix.GetColumn<uint16_t, true>(0);
|
||||
CheckColumWithMissingValue(*col.get(), gmat);
|
||||
}
|
||||
break;
|
||||
auto col = column_matrix.DenseColumn<uint16_t, true>(0);
|
||||
CheckColumWithMissingValue(col, gmat);
|
||||
} break;
|
||||
case kUint32BinsTypeSize: {
|
||||
auto col = column_matrix.GetColumn<uint32_t, true>(0);
|
||||
CheckColumWithMissingValue(*col.get(), gmat);
|
||||
}
|
||||
break;
|
||||
auto col = column_matrix.DenseColumn<uint32_t, true>(0);
|
||||
CheckColumWithMissingValue(col, gmat);
|
||||
} break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user