Unify the partitioner for hist and approx.

Co-authored-by: dmitry.razdoburdin <drazdobu@jfldaal005.jf.intel.com>
Co-authored-by: jiamingy <jm.yuan@outlook.com>
This commit is contained in:
Dmitry Razdoburdin
2022-10-19 20:49:20 +02:00
committed by GitHub
parent c69af90319
commit 5bd849f1b5
13 changed files with 358 additions and 450 deletions

View File

@@ -103,15 +103,18 @@ class SparseColumnIter : public Column<BinIdxT> {
template <typename BinIdxT, bool any_missing>
class DenseColumnIter : public Column<BinIdxT> {
public:
using ByteType = bool;
private:
using Base = Column<BinIdxT>;
/* flags for missing values in dense columns */
std::vector<bool> const& missing_flags_;
std::vector<ByteType> const& missing_flags_;
size_t feature_offset_;
public:
explicit DenseColumnIter(common::Span<const BinIdxT> index, bst_bin_t index_base,
std::vector<bool> const& missing_flags, size_t feature_offset)
std::vector<ByteType> const& missing_flags, size_t feature_offset)
: Base{index, index_base}, missing_flags_{missing_flags}, feature_offset_{feature_offset} {}
DenseColumnIter(DenseColumnIter const&) = delete;
DenseColumnIter(DenseColumnIter&&) = default;
@@ -153,6 +156,7 @@ class ColumnMatrix {
}
public:
using ByteType = bool;
/**
 * \brief Get the number of features.
 *
 * \return The size of type_, cast to bst_feature_t.
 */
bst_feature_t GetNumFeature() const { return static_cast<bst_feature_t>(type_.size()); }
@@ -195,6 +199,8 @@ class ColumnMatrix {
}
}
/**
 * \brief Report whether this object is considered initialized.
 *
 * \return true when type_ is non-empty, false otherwise.
 */
bool IsInitialized() const { return !type_.empty(); }
/**
* \brief Push batch of data for Quantile DMatrix support.
*
@@ -352,6 +358,13 @@ class ColumnMatrix {
fi->Read(&row_ind_);
fi->Read(&feature_offsets_);
std::vector<std::uint8_t> missing;
fi->Read(&missing);
missing_flags_.resize(missing.size());
std::transform(missing.cbegin(), missing.cend(), missing_flags_.begin(),
[](std::uint8_t flag) { return !!flag; });
index_base_ = index_base;
#if !DMLC_LITTLE_ENDIAN
std::underlying_type<BinTypeSize>::type v;
@@ -386,6 +399,11 @@ class ColumnMatrix {
#endif // !DMLC_LITTLE_ENDIAN
write_vec(row_ind_);
write_vec(feature_offsets_);
// dmlc cannot handle std::vector<bool>; copy the flags into a byte vector before writing
std::vector<std::uint8_t> missing(missing_flags_.size());
std::transform(missing_flags_.cbegin(), missing_flags_.cend(), missing.begin(),
[](bool flag) { return static_cast<std::uint8_t>(flag); });
write_vec(missing);
#if !DMLC_LITTLE_ENDIAN
auto v = static_cast<std::underlying_type<BinTypeSize>::type>(bins_type_size_);
@@ -413,7 +431,7 @@ class ColumnMatrix {
// index_base_[fid]: least bin id for feature fid
uint32_t const* index_base_;
std::vector<bool> missing_flags_;
std::vector<ByteType> missing_flags_;
BinTypeSize bins_type_size_;
bool any_missing_;
};