Unify the partitioner for hist and approx.
Co-authored-by: dmitry.razdoburdin <drazdobu@jfldaal005.jf.intel.com> Co-authored-by: jiamingy <jm.yuan@outlook.com>
This commit is contained in:
committed by
GitHub
parent
c69af90319
commit
5bd849f1b5
@@ -103,15 +103,18 @@ class SparseColumnIter : public Column<BinIdxT> {
|
||||
|
||||
template <typename BinIdxT, bool any_missing>
|
||||
class DenseColumnIter : public Column<BinIdxT> {
|
||||
public:
|
||||
using ByteType = bool;
|
||||
|
||||
private:
|
||||
using Base = Column<BinIdxT>;
|
||||
/* flags for missing values in dense columns */
|
||||
std::vector<bool> const& missing_flags_;
|
||||
std::vector<ByteType> const& missing_flags_;
|
||||
size_t feature_offset_;
|
||||
|
||||
public:
|
||||
explicit DenseColumnIter(common::Span<const BinIdxT> index, bst_bin_t index_base,
|
||||
std::vector<bool> const& missing_flags, size_t feature_offset)
|
||||
std::vector<ByteType> const& missing_flags, size_t feature_offset)
|
||||
: Base{index, index_base}, missing_flags_{missing_flags}, feature_offset_{feature_offset} {}
|
||||
DenseColumnIter(DenseColumnIter const&) = delete;
|
||||
DenseColumnIter(DenseColumnIter&&) = default;
|
||||
@@ -153,6 +156,7 @@ class ColumnMatrix {
|
||||
}
|
||||
|
||||
public:
|
||||
using ByteType = bool;
|
||||
// get number of features
|
||||
bst_feature_t GetNumFeature() const { return static_cast<bst_feature_t>(type_.size()); }
|
||||
|
||||
@@ -195,6 +199,8 @@ class ColumnMatrix {
|
||||
}
|
||||
}
|
||||
|
||||
bool IsInitialized() const { return !type_.empty(); }
|
||||
|
||||
/**
|
||||
* \brief Push batch of data for Quantile DMatrix support.
|
||||
*
|
||||
@@ -352,6 +358,13 @@ class ColumnMatrix {
|
||||
|
||||
fi->Read(&row_ind_);
|
||||
fi->Read(&feature_offsets_);
|
||||
|
||||
std::vector<std::uint8_t> missing;
|
||||
fi->Read(&missing);
|
||||
missing_flags_.resize(missing.size());
|
||||
std::transform(missing.cbegin(), missing.cend(), missing_flags_.begin(),
|
||||
[](std::uint8_t flag) { return !!flag; });
|
||||
|
||||
index_base_ = index_base;
|
||||
#if !DMLC_LITTLE_ENDIAN
|
||||
std::underlying_type<BinTypeSize>::type v;
|
||||
@@ -386,6 +399,11 @@ class ColumnMatrix {
|
||||
#endif // !DMLC_LITTLE_ENDIAN
|
||||
write_vec(row_ind_);
|
||||
write_vec(feature_offsets_);
|
||||
// dmlc can not handle bool vector
|
||||
std::vector<std::uint8_t> missing(missing_flags_.size());
|
||||
std::transform(missing_flags_.cbegin(), missing_flags_.cend(), missing.begin(),
|
||||
[](bool flag) { return static_cast<std::uint8_t>(flag); });
|
||||
write_vec(missing);
|
||||
|
||||
#if !DMLC_LITTLE_ENDIAN
|
||||
auto v = static_cast<std::underlying_type<BinTypeSize>::type>(bins_type_size_);
|
||||
@@ -413,7 +431,7 @@ class ColumnMatrix {
|
||||
|
||||
// index_base_[fid]: least bin id for feature fid
|
||||
uint32_t const* index_base_;
|
||||
std::vector<bool> missing_flags_;
|
||||
std::vector<ByteType> missing_flags_;
|
||||
BinTypeSize bins_type_size_;
|
||||
bool any_missing_;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user