Fix and cleanup for column matrix. (#7901)

* Fix missed type dispatching for dense columns with missing values.
* Code cleanup to reduce special cases.
* Reduce memory usage.
Jiaming Yuan 2022-05-16 21:11:50 +08:00 committed by GitHub
parent 1496789561
commit 4fcfd9c96e
10 changed files with 124 additions and 136 deletions

View File

@@ -125,16 +125,20 @@ class DenseColumnIter : public Column<BinIdxT> {
   }
 };
 
-/*! \brief a collection of columns, with support for construction from
-    GHistIndexMatrix. */
+/**
+ * \brief Column major matrix for gradient index. This matrix contains both dense columns
+ * and sparse columns; the type of each column is controlled by the sparse threshold. When
+ * the number of missing values in a column is below the threshold, it is classified as a
+ * dense column.
+ */
 class ColumnMatrix {
  public:
   // get number of features
   bst_feature_t GetNumFeature() const { return static_cast<bst_feature_t>(type_.size()); }
 
   // construct column matrix from GHistIndexMatrix
-  inline void Init(SparsePage const& page, const GHistIndexMatrix& gmat, double sparse_threshold,
-                   int32_t n_threads) {
+  void Init(SparsePage const& page, const GHistIndexMatrix& gmat, double sparse_threshold,
+            int32_t n_threads) {
     auto const nfeature = static_cast<bst_feature_t>(gmat.cut.Ptrs().size() - 1);
     const size_t nrow = gmat.row_ptr.size() - 1;
     // identify type of each column
@@ -145,13 +149,14 @@ class ColumnMatrix {
     for (bst_feature_t fid = 0; fid < nfeature; ++fid) {
       CHECK_LE(gmat.cut.Ptrs()[fid + 1] - gmat.cut.Ptrs()[fid], max_val);
     }
-    bool all_dense = gmat.IsDense();
+    bool all_dense_column = true;
     gmat.GetFeatureCounts(&feature_counts_[0]);
     // classify features
     for (bst_feature_t fid = 0; fid < nfeature; ++fid) {
       if (static_cast<double>(feature_counts_[fid]) < sparse_threshold * nrow) {
         type_[fid] = kSparseColumn;
-        all_dense = false;
+        all_dense_column = false;
       } else {
         type_[fid] = kDenseColumn;
       }
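A worked illustration of the threshold rule in the hunk above; this is a standalone sketch, not xgboost code, and the names Classify and feature_counts are hypothetical stand-ins for the members used in Init():

#include <cstddef>
#include <iostream>
#include <vector>

enum ColumnType { kDenseColumn, kSparseColumn };

// Sketch of the rule in ColumnMatrix::Init(): a column is stored sparse when
// its count of present (non-missing) entries falls below sparse_threshold * n_rows.
std::vector<ColumnType> Classify(std::vector<std::size_t> const& feature_counts,
                                 std::size_t n_rows, double sparse_threshold) {
  std::vector<ColumnType> types(feature_counts.size());
  for (std::size_t fid = 0; fid < feature_counts.size(); ++fid) {
    types[fid] = static_cast<double>(feature_counts[fid]) < sparse_threshold * n_rows
                     ? kSparseColumn
                     : kDenseColumn;
  }
  return types;
}

int main() {
  // 100 rows, threshold 0.2: a feature present in 15 rows is stored sparse,
  // one present in 90 rows is stored dense.
  auto types = Classify({15, 90}, 100, 0.2);
  std::cout << (types[0] == kSparseColumn) << (types[1] == kDenseColumn) << '\n';  // prints 11
}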
@@ -160,70 +165,51 @@ class ColumnMatrix {
     // want to compute storage boundary for each feature
     // using variants of prefix sum scan
     feature_offsets_.resize(nfeature + 1);
-    size_t accum_index_ = 0;
-    feature_offsets_[0] = accum_index_;
+    size_t accum_index = 0;
+    feature_offsets_[0] = accum_index;
     for (bst_feature_t fid = 1; fid < nfeature + 1; ++fid) {
       if (type_[fid - 1] == kDenseColumn) {
-        accum_index_ += static_cast<size_t>(nrow);
+        accum_index += static_cast<size_t>(nrow);
       } else {
-        accum_index_ += feature_counts_[fid - 1];
+        accum_index += feature_counts_[fid - 1];
       }
-      feature_offsets_[fid] = accum_index_;
+      feature_offsets_[fid] = accum_index;
     }
 
     SetTypeSize(gmat.max_num_bins);
-
-    index_.resize(feature_offsets_[nfeature] * bins_type_size_, 0);
-    if (!all_dense) {
+    auto storage_size =
+        feature_offsets_.back() * static_cast<std::underlying_type_t<BinTypeSize>>(bins_type_size_);
+    index_.resize(storage_size, 0);
+    if (!all_dense_column) {
       row_ind_.resize(feature_offsets_[nfeature]);
     }
 
     // store least bin id for each feature
     index_base_ = const_cast<uint32_t*>(gmat.cut.Ptrs().data());
 
-    const bool noMissingValues = NoMissingValues(gmat.row_ptr[nrow], nrow, nfeature);
-    any_missing_ = !noMissingValues;
+    any_missing_ = !gmat.IsDense();
 
     missing_flags_.clear();
-    if (noMissingValues) {
+    // pre-fill index_ for dense columns
+    BinTypeSize gmat_bin_size = gmat.index.GetBinTypeSize();
+    if (!any_missing_) {
       missing_flags_.resize(feature_offsets_[nfeature], false);
+      // row index is compressed, we need to dispatch it.
+      DispatchBinType(gmat_bin_size, [&, nrow, nfeature, n_threads](auto t) {
+        using RowBinIdxT = decltype(t);
+        SetIndexNoMissing(page, gmat.index.data<RowBinIdxT>(), nrow, nfeature, n_threads);
+      });
     } else {
       missing_flags_.resize(feature_offsets_[nfeature], true);
-    }
-
-    // pre-fill index_ for dense columns
-    if (all_dense) {
-      BinTypeSize gmat_bin_size = gmat.index.GetBinTypeSize();
-      if (gmat_bin_size == kUint8BinsTypeSize) {
-        SetIndexAllDense(page, gmat.index.data<uint8_t>(), gmat, nrow, nfeature, noMissingValues,
-                         n_threads);
-      } else if (gmat_bin_size == kUint16BinsTypeSize) {
-        SetIndexAllDense(page, gmat.index.data<uint16_t>(), gmat, nrow, nfeature, noMissingValues,
-                         n_threads);
-      } else {
-        CHECK_EQ(gmat_bin_size, kUint32BinsTypeSize);
-        SetIndexAllDense(page, gmat.index.data<uint32_t>(), gmat, nrow, nfeature, noMissingValues,
-                         n_threads);
-      }
-      /* For sparse DMatrix gmat.index.getBinTypeSize() returns always kUint32BinsTypeSize
-         but for ColumnMatrix we still have a chance to reduce the memory consumption */
-    } else {
-      if (bins_type_size_ == kUint8BinsTypeSize) {
-        SetIndex<uint8_t>(page, gmat.index.data<uint32_t>(), gmat, nfeature);
-      } else if (bins_type_size_ == kUint16BinsTypeSize) {
-        SetIndex<uint16_t>(page, gmat.index.data<uint32_t>(), gmat, nfeature);
-      } else {
-        CHECK_EQ(bins_type_size_, kUint32BinsTypeSize);
-        SetIndex<uint32_t>(page, gmat.index.data<uint32_t>(), gmat, nfeature);
-      }
+      SetIndexMixedColumns(page, gmat.index.data<uint32_t>(), gmat, nfeature);
     }
   }
 
   /* Set the number of bytes based on numeric limit of maximum number of bins provided by user */
-  void SetTypeSize(size_t max_num_bins) {
-    if ((max_num_bins - 1) <= static_cast<int>(std::numeric_limits<uint8_t>::max())) {
+  void SetTypeSize(size_t max_bin_per_feat) {
+    if ((max_bin_per_feat - 1) <= static_cast<int>(std::numeric_limits<uint8_t>::max())) {
       bins_type_size_ = kUint8BinsTypeSize;
-    } else if ((max_num_bins - 1) <= static_cast<int>(std::numeric_limits<uint16_t>::max())) {
+    } else if ((max_bin_per_feat - 1) <= static_cast<int>(std::numeric_limits<uint16_t>::max())) {
       bins_type_size_ = kUint16BinsTypeSize;
     } else {
       bins_type_size_ = kUint32BinsTypeSize;
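SetTypeSize() keys off the largest bin index a feature can take, max_bin_per_feat - 1, and picks the narrowest unsigned width that holds it. A minimal standalone sketch of the same rule, returning the byte width directly (BytesPerBinIndex is a hypothetical name, not part of xgboost):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <limits>

// Mirrors SetTypeSize(): choose the narrowest unsigned type that can hold
// the largest bin index, max_num_bins - 1.
std::size_t BytesPerBinIndex(std::size_t max_num_bins) {
  if (max_num_bins - 1 <= static_cast<std::size_t>(std::numeric_limits<std::uint8_t>::max())) {
    return 1;
  }
  if (max_num_bins - 1 <= static_cast<std::size_t>(std::numeric_limits<std::uint16_t>::max())) {
    return 2;
  }
  return 4;
}

int main() {
  std::cout << BytesPerBinIndex(256) << '\n';    // 1: indices 0..255 fit uint8_t
  std::cout << BytesPerBinIndex(257) << '\n';    // 2: index 256 needs uint16_t
  std::cout << BytesPerBinIndex(65537) << '\n';  // 4: index 65536 needs uint32_t
}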
@@ -252,98 +238,78 @@ class ColumnMatrix {
         bin_index, static_cast<bst_bin_t>(index_base_[fidx]), missing_flags_, feature_offset});
   }
 
-  template <typename T>
-  inline void SetIndexAllDense(SparsePage const& page, T const* index, const GHistIndexMatrix& gmat,
-                               const size_t nrow, const size_t nfeature, const bool noMissingValues,
-                               int32_t n_threads) {
-    T* local_index = reinterpret_cast<T*>(&index_[0]);
-
-    /* missing values make sense only for column with type kDenseColumn,
-       and if no missing values were observed it could be handled much faster. */
-    if (noMissingValues) {
-      ParallelFor(nrow, n_threads, [&](auto rid) {
-        const size_t ibegin = rid * nfeature;
-        const size_t iend = (rid + 1) * nfeature;
+  // all columns are dense and have no missing values
+  // FIXME(jiamingy): We don't need a column matrix if there's no missing value.
+  template <typename RowBinIdxT>
+  void SetIndexNoMissing(SparsePage const& page, RowBinIdxT const* row_index,
+                         const size_t n_samples, const size_t n_features, int32_t n_threads) {
+    DispatchBinType(bins_type_size_, [&](auto t) {
+      using ColumnBinT = decltype(t);
+      auto column_index = Span<ColumnBinT>{reinterpret_cast<ColumnBinT*>(index_.data()),
                                           index_.size() / sizeof(ColumnBinT)};
+      ParallelFor(n_samples, n_threads, [&](auto rid) {
+        const size_t ibegin = rid * n_features;
+        const size_t iend = (rid + 1) * n_features;
         size_t j = 0;
         for (size_t i = ibegin; i < iend; ++i, ++j) {
           const size_t idx = feature_offsets_[j];
-          local_index[idx + rid] = index[i];
+          // No need to add offset, as row index is compressed and stores the local index
+          column_index[idx + rid] = row_index[i];
         }
       });
-    } else {
-      /* to handle rows in all batches, sum of all batch sizes equal to gmat.row_ptr.size() - 1 */
-      auto get_bin_idx = [&](auto bin_id, auto rid, bst_feature_t fid) {
-        // T* begin = &local_index[feature_offsets_[fid]];
-        const size_t idx = feature_offsets_[fid];
-        /* rbegin allows to store indexes from specific SparsePage batch */
-        local_index[idx + rid] = bin_id;
-        missing_flags_[idx + rid] = false;
-      };
-      this->SetIndexSparse(page, index, gmat, nfeature, get_bin_idx);
-    }
+    });
   }
 
-  // FIXME(jiamingy): In the future we might want to simply use binary search to simplify
-  // this and remove the dependency on SparsePage. This way we can have quantilized
-  // matrix for host similar to `DeviceQuantileDMatrix`.
-  template <typename T, typename BinFn>
-  void SetIndexSparse(SparsePage const& batch, T* index, const GHistIndexMatrix& gmat,
-                      const size_t nfeature, BinFn&& assign_bin) {
-    std::vector<size_t> num_nonzeros(nfeature, 0ul);
-    const xgboost::Entry* data_ptr = batch.data.HostVector().data();
-    const std::vector<bst_row_t>& offset_vec = batch.offset.HostVector();
-    auto rbegin = 0;
-    const size_t batch_size = gmat.Size();
-    CHECK_LT(batch_size, offset_vec.size());
-    for (size_t rid = 0; rid < batch_size; ++rid) {
-      const size_t ibegin = gmat.row_ptr[rbegin + rid];
-      const size_t iend = gmat.row_ptr[rbegin + rid + 1];
-      const size_t size = offset_vec[rid + 1] - offset_vec[rid];
-      SparsePage::Inst inst = {data_ptr + offset_vec[rid], size};
-      CHECK_EQ(ibegin + inst.size(), iend);
-      size_t j = 0;
-      for (size_t i = ibegin; i < iend; ++i, ++j) {
-        const uint32_t bin_id = index[i];
-        auto fid = inst[j].index;
-        assign_bin(bin_id, rid, fid);
-      }
-    }
-  }
-
-  template <typename T>
-  inline void SetIndex(SparsePage const& page, uint32_t const* index, const GHistIndexMatrix& gmat,
-                       const size_t nfeature) {
-    T* local_index = reinterpret_cast<T*>(&index_[0]);
-
+  /**
+   * \brief Set column index for both dense and sparse columns
+   */
+  void SetIndexMixedColumns(SparsePage const& page, uint32_t const* row_index,
+                            const GHistIndexMatrix& gmat, size_t n_features) {
     std::vector<size_t> num_nonzeros;
-    num_nonzeros.resize(nfeature);
-    std::fill(num_nonzeros.begin(), num_nonzeros.end(), 0);
+    num_nonzeros.resize(n_features, 0);
 
-    auto get_bin_idx = [&](auto bin_id, auto rid, bst_feature_t fid) {
-      if (type_[fid] == kDenseColumn) {
-        T* begin = &local_index[feature_offsets_[fid]];
-        begin[rid] = bin_id - index_base_[fid];
-        missing_flags_[feature_offsets_[fid] + rid] = false;
-      } else {
-        T* begin = &local_index[feature_offsets_[fid]];
-        begin[num_nonzeros[fid]] = bin_id - index_base_[fid];
-        row_ind_[feature_offsets_[fid] + num_nonzeros[fid]] = rid;
-        ++num_nonzeros[fid];
-      }
-    };
-    this->SetIndexSparse(page, index, gmat, nfeature, get_bin_idx);
+    DispatchBinType(bins_type_size_, [&](auto t) {
+      using ColumnBinT = decltype(t);
+      ColumnBinT* local_index = reinterpret_cast<ColumnBinT*>(index_.data());
+
+      auto get_bin_idx = [&](auto bin_id, auto rid, bst_feature_t fid) {
+        if (type_[fid] == kDenseColumn) {
+          ColumnBinT* begin = &local_index[feature_offsets_[fid]];
+          begin[rid] = bin_id - index_base_[fid];
+          // not thread-safe with bool vector.
+          missing_flags_[feature_offsets_[fid] + rid] = false;
+        } else {
+          ColumnBinT* begin = &local_index[feature_offsets_[fid]];
+          begin[num_nonzeros[fid]] = bin_id - index_base_[fid];
+          row_ind_[feature_offsets_[fid] + num_nonzeros[fid]] = rid;
+          ++num_nonzeros[fid];
+        }
+      };
+
+      const xgboost::Entry* data_ptr = page.data.HostVector().data();
+      const std::vector<bst_row_t>& offset_vec = page.offset.HostVector();
+      const size_t batch_size = gmat.Size();
+      CHECK_LT(batch_size, offset_vec.size());
+      for (size_t rid = 0; rid < batch_size; ++rid) {
+        const size_t ibegin = gmat.row_ptr[rid];
+        const size_t iend = gmat.row_ptr[rid + 1];
+        const size_t size = offset_vec[rid + 1] - offset_vec[rid];
+        SparsePage::Inst inst = {data_ptr + offset_vec[rid], size};
+        CHECK_EQ(ibegin + inst.size(), iend);
+        size_t j = 0;
+        for (size_t i = ibegin; i < iend; ++i, ++j) {
+          const uint32_t bin_id = row_index[i];
+          auto fid = inst[j].index;
+          get_bin_idx(bin_id, rid, fid);
+        }
+      }
+    });
   }
 
   BinTypeSize GetTypeSize() const { return bins_type_size_; }
   auto GetColumnType(bst_feature_t fidx) const { return type_[fidx]; }
-
-  // This is just an utility function
-  bool NoMissingValues(const size_t n_elements, const size_t n_row, const size_t n_features) {
-    return n_elements == n_features * n_row;
-  }
-
   // And this returns part of state
   bool AnyMissing() const { return any_missing_; }
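On the dense, no-missing path, SetIndexNoMissing() above amounts to a parallel transpose: the gradient index stores bin ids row by row, while the column matrix stores them feature by feature starting at feature_offsets_[fid]. A serial sketch of that layout change, assuming every column is dense so feature_offsets_[fid] == fid * n_samples (ToColumnMajor is an illustrative name, not xgboost API; the real code also narrows the bin type and runs the outer loop in parallel):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Row-major bins (n_samples x n_features) -> column-major storage where
// column `fid` occupies the half-open range [fid * n_samples, (fid + 1) * n_samples).
std::vector<std::uint32_t> ToColumnMajor(std::vector<std::uint32_t> const& row_major,
                                         std::size_t n_samples, std::size_t n_features) {
  std::vector<std::uint32_t> col_major(row_major.size());
  for (std::size_t rid = 0; rid < n_samples; ++rid) {
    for (std::size_t fid = 0; fid < n_features; ++fid) {
      col_major[fid * n_samples + rid] = row_major[rid * n_features + fid];
    }
  }
  return col_major;
}

int main() {
  // 2 samples x 3 features, row-major: {r0f0, r0f1, r0f2, r1f0, r1f1, r1f2}
  auto cols = ToColumnMajor({1, 2, 3, 4, 5, 6}, 2, 3);
  std::cout << cols[1] << '\n';  // prints 4: sample 1's bin for feature 0
}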

View File

@@ -113,7 +113,7 @@ class HistogramCuts {
     auto end = ptrs[column_id + 1];
     auto beg = ptrs[column_id];
     auto it = std::upper_bound(values.cbegin() + beg, values.cbegin() + end, value);
-    bst_bin_t idx = it - values.cbegin();
+    auto idx = it - values.cbegin();
     idx -= !!(idx == end);
     return idx;
   }
@@ -189,12 +189,30 @@ inline HistogramCuts SketchOnDMatrix(DMatrix* m, int32_t max_bins, int32_t n_thr
   return out;
 }
 
-enum BinTypeSize : uint32_t {
+enum BinTypeSize : uint8_t {
   kUint8BinsTypeSize = 1,
   kUint16BinsTypeSize = 2,
   kUint32BinsTypeSize = 4
 };
 
+/**
+ * \brief Dispatch for bin type, fn is a function that accepts a scalar of the bin type.
+ */
+template <typename Fn>
+auto DispatchBinType(BinTypeSize type, Fn&& fn) {
+  switch (type) {
+    case kUint8BinsTypeSize: {
+      return fn(uint8_t{});
+    }
+    case kUint16BinsTypeSize: {
+      return fn(uint16_t{});
+    }
+    case kUint32BinsTypeSize: {
+      return fn(uint32_t{});
+    }
+  }
+}
+
 /**
  * \brief Optionally compressed gradient index. The compression works only with dense
  * data.

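DispatchBinType converts the runtime enum into a compile-time type by invoking a generic lambda with a value-initialized scalar of the matching width; callers recover the type with decltype, as SetIndexNoMissing does in the column-matrix hunk above. A self-contained usage sketch (the trailing return is added here so every control path returns; the hunk above relies on the switch being exhaustive):

#include <cstddef>
#include <cstdint>
#include <iostream>

enum BinTypeSize : std::uint8_t {
  kUint8BinsTypeSize = 1,
  kUint16BinsTypeSize = 2,
  kUint32BinsTypeSize = 4
};

template <typename Fn>
auto DispatchBinType(BinTypeSize type, Fn&& fn) {
  switch (type) {
    case kUint8BinsTypeSize:
      return fn(std::uint8_t{});
    case kUint16BinsTypeSize:
      return fn(std::uint16_t{});
    case kUint32BinsTypeSize:
      return fn(std::uint32_t{});
  }
  return fn(std::uint32_t{});  // unreachable for valid enum values
}

int main() {
  BinTypeSize t = kUint16BinsTypeSize;  // known only at runtime
  // The generic lambda is instantiated once per bin type; decltype(v)
  // recovers the concrete type for use in casts and templates.
  std::size_t width = DispatchBinType(t, [](auto v) { return sizeof(decltype(v)); });
  std::cout << width << '\n';  // prints 2
}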
View File

@@ -108,7 +108,7 @@ class PartitionBuilder {
   template <typename BinIdxType, bool any_missing, bool any_cat>
   void Partition(const size_t node_in_set, const size_t nid, const common::Range1d range,
-                 const int32_t split_cond, GHistIndexMatrix const& gmat,
+                 const bst_bin_t split_cond, GHistIndexMatrix const& gmat,
                  const ColumnMatrix& column_matrix, const RegTree& tree, const size_t* rid) {
     common::Span<const size_t> rid_span(rid + range.begin(), rid + range.end());
     common::Span<size_t> left = GetLeftBuffer(node_in_set, range.begin(), range.end());

View File

@@ -28,7 +28,7 @@ void EncodeTreeLeafHost(RegTree const& tree, std::vector<bst_node_t> const& posi
     sorted_pos[i] = position[ridx[i]];
   }
   // find the first non-sampled row
-  auto begin_pos =
+  size_t begin_pos =
       std::distance(sorted_pos.cbegin(), std::find_if(sorted_pos.cbegin(), sorted_pos.cend(),
                                                       [](bst_node_t nidx) { return nidx >= 0; }));
   CHECK_LE(begin_pos, sorted_pos.size());

View File

@@ -264,7 +264,7 @@ class GlobalApproxUpdater : public TreeUpdater {
  public:
   explicit GlobalApproxUpdater(GenericParameter const *ctx, ObjInfo task)
-      : task_{task}, TreeUpdater(ctx) {
+      : TreeUpdater(ctx), task_{task} {
     monitor_.Init(__func__);
   }

View File

@@ -355,11 +355,11 @@ void HistRowPartitioner::FindSplitConditions(const std::vector<CPUExpandEntry> &
     const bst_float split_pt = tree[nid].SplitCond();
     const uint32_t lower_bound = gmat.cut.Ptrs()[fid];
     const uint32_t upper_bound = gmat.cut.Ptrs()[fid + 1];
-    int32_t split_cond = -1;
+    bst_bin_t split_cond = -1;
     // convert floating-point split_pt into corresponding bin_id
     // split_cond = -1 indicates that split_pt is less than all known cut points
     CHECK_LT(upper_bound, static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
-    for (uint32_t bound = lower_bound; bound < upper_bound; ++bound) {
+    for (auto bound = lower_bound; bound < upper_bound; ++bound) {
       if (split_pt == gmat.cut.Values()[bound]) {
         split_cond = static_cast<int32_t>(bound);
       }

View File

@@ -324,7 +324,7 @@ class QuantileHistMaker: public TreeUpdater {
   std::unique_ptr<HistogramBuilder<CPUExpandEntry>> histogram_builder_;
   ObjInfo task_;
   // Context for number of threads
-  GenericParameter const* ctx_;
+  Context const* ctx_;
   std::unique_ptr<common::Monitor> monitor_;
 };

View File

@@ -15,6 +15,7 @@ TEST(DenseColumn, Test) {
   int32_t max_num_bins[] = {static_cast<int32_t>(std::numeric_limits<uint8_t>::max()) + 1,
                             static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 1,
                             static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 2};
+  BinTypeSize last{kUint8BinsTypeSize};
   for (int32_t max_num_bin : max_num_bins) {
     auto dmat = RandomDataGenerator(100, 10, 0.0).GenerateDMatrix();
     auto sparse_thresh = 0.2;
@@ -24,7 +25,10 @@ TEST(DenseColumn, Test) {
     for (auto const& page : dmat->GetBatches<SparsePage>()) {
       column_matrix.Init(page, gmat, sparse_thresh, common::OmpGetNumThreads(0));
     }
+    ASSERT_GE(column_matrix.GetTypeSize(), last);
+    ASSERT_LE(column_matrix.GetTypeSize(), kUint32BinsTypeSize);
+    last = column_matrix.GetTypeSize();
+    ASSERT_FALSE(column_matrix.AnyMissing());
     for (auto i = 0ull; i < dmat->Info().num_row_; i++) {
       for (auto j = 0ull; j < dmat->Info().num_col_; j++) {
         switch (column_matrix.GetTypeSize()) {
@@ -105,6 +109,7 @@ TEST(DenseColumnWithMissing, Test) {
     for (auto const& page : dmat->GetBatches<SparsePage>()) {
      column_matrix.Init(page, gmat, 0.2, common::OmpGetNumThreads(0));
     }
+    ASSERT_TRUE(column_matrix.AnyMissing());
     switch (column_matrix.GetTypeSize()) {
       case kUint8BinsTypeSize: {
         auto col = column_matrix.DenseColumn<uint8_t, true>(0);

View File

@@ -130,7 +130,6 @@ TEST_F(TestPartitionBasedSplit, CPUHist) {
 namespace {
 auto CompareOneHotAndPartition(bool onehot) {
   int static constexpr kRows = 128, kCols = 1;
-  using GradientSumT = double;
   std::vector<FeatureType> ft(kCols, FeatureType::kCategorical);
   TrainParam param;

View File

@@ -35,7 +35,7 @@ TEST(QuantileHist, Partitioner) {
   for (auto const& page : Xy->GetBatches<SparsePage>()) {
     GHistIndexMatrix gmat;
-    gmat.Init(page, {}, cuts, 64, false, 0.5, ctx.Threads());
+    gmat.Init(page, {}, cuts, 64, true, 0.5, ctx.Threads());
     bst_feature_t const split_ind = 0;
     common::ColumnMatrix column_indices;
     column_indices.Init(page, gmat, 0.5, ctx.Threads());