Extract Sketch Entry from hist maker. (#7503)

* Extract Sketch Entry from hist maker.

* Add a new sketch container for sorted inputs.
* Optimize bin search.
Jiaming Yuan 2021-12-18 05:36:56 +08:00 committed by GitHub
parent b4a1236cfc
commit 9ab73f737e
15 changed files with 393 additions and 217 deletions

View File

@@ -15,7 +15,6 @@
 #include "random.h"
 #include "column_matrix.h"
 #include "quantile.h"
-#include "./../tree/updater_quantile_hist.h"
 #include "../data/gradient_index.h"
 #if defined(XGBOOST_MM_PREFETCH_PRESENT)

View File

@@ -92,18 +92,20 @@ class HistogramCuts {
   // Return the index of a cut point that is strictly greater than the input
   // value, or the last available index if none exists
-  BinIdx SearchBin(float value, uint32_t column_id) const {
-    auto beg = cut_ptrs_.ConstHostVector().at(column_id);
-    auto end = cut_ptrs_.ConstHostVector().at(column_id + 1);
-    const auto &values = cut_values_.ConstHostVector();
+  BinIdx SearchBin(float value, uint32_t column_id, std::vector<uint32_t> const& ptrs,
+                   std::vector<float> const& values) const {
+    auto end = ptrs[column_id + 1];
+    auto beg = ptrs[column_id];
     auto it = std::upper_bound(values.cbegin() + beg, values.cbegin() + end, value);
     BinIdx idx = it - values.cbegin();
-    if (idx == end) {
-      idx -= 1;
-    }
+    idx -= !!(idx == end);
     return idx;
   }
+  BinIdx SearchBin(float value, uint32_t column_id) const {
+    return this->SearchBin(value, column_id, Ptrs(), Values());
+  }
   /**
    * \brief Search the bin index for numerical feature.
    */
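The rewritten SearchBin does two things for hot loops: callers can pass cached references to the cut pointers and values, avoiding a ConstHostVector() lookup per element, and the end-of-column clamp is branchless. A minimal standalone sketch of the same search logic (plain STL; the function name is invented for illustration):

#include <algorithm>
#include <cstdint>
#include <vector>

// Find the bin for `value` in feature `fid`, given CSC-style cut pointers and
// the flattened cut values. Mirrors the upper_bound + branchless clamp above.
uint32_t SearchBinDemo(std::vector<uint32_t> const& ptrs, std::vector<float> const& values,
                       uint32_t fid, float value) {
  auto beg = ptrs[fid];
  auto end = ptrs[fid + 1];
  auto it = std::upper_bound(values.cbegin() + beg, values.cbegin() + end, value);
  auto idx = static_cast<uint32_t>(it - values.cbegin());
  idx -= !!(idx == end);  // clamp to the last cut without branching
  return idx;
}

std::upper_bound returns the first cut strictly greater than the value, so subtracting !!(idx == end) folds the past-the-end case into straight-line code instead of a branch.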
@@ -129,7 +131,13 @@ class HistogramCuts {
   }
 };
 
-inline HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins,
+/**
+ * \brief Run CPU sketching on DMatrix.
+ *
+ * \param use_sorted Whether we should use SortedCSC for sketching. It's more efficient
+ *                   but consumes more memory.
+ */
+inline HistogramCuts SketchOnDMatrix(DMatrix* m, int32_t max_bins, bool use_sorted = false,
                                      Span<float> const hessian = {}) {
   HistogramCuts out;
   auto const& info = m->Info();
@@ -146,13 +154,23 @@ inline HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins,
       reduced[i] += entries_per_column[i];
     }
   }
-  HostSketchContainer container(reduced, max_bins,
-                                m->Info().feature_types.ConstHostSpan(),
-                                HostSketchContainer::UseGroup(info), threads);
-  for (auto const& page : m->GetBatches<SparsePage>()) {
-    container.PushRowPage(page, info, hessian);
+
+  if (!use_sorted) {
+    HostSketchContainer container(max_bins, m->Info(), reduced, HostSketchContainer::UseGroup(info),
+                                  hessian, threads);
+    for (auto const& page : m->GetBatches<SparsePage>()) {
+      container.PushRowPage(page, info, hessian);
+    }
+    container.MakeCuts(&out);
+  } else {
+    SortedSketchContainer container{
+        max_bins, m->Info(), reduced, HostSketchContainer::UseGroup(info), hessian, threads};
+    for (auto const& page : m->GetBatches<SortedCSCPage>()) {
+      container.PushColPage(page, info, hessian);
+    }
+    container.MakeCuts(&out);
   }
-  container.MakeCuts(&out);
   return out;
 }
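For callers, the new flag is the only visible change. A hedged usage sketch, assuming XGBoost's internal headers are on the include path:

#include "../src/common/hist_util.h"  // internal header; path is an assumption

// The approx tree method passes use_sorted=true to trade memory for sketching
// speed; hist keeps the default row-wise (SparsePage) path.
xgboost::common::HistogramCuts BuildCuts(xgboost::DMatrix* dmat, int32_t max_bins,
                                         bool use_sorted_sketch) {
  return xgboost::common::SketchOnDMatrix(dmat, max_bins, use_sorted_sketch);
}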

View File

@@ -12,31 +12,25 @@
 namespace xgboost {
 namespace common {
 
-HostSketchContainer::HostSketchContainer(
-    std::vector<bst_row_t> columns_size, int32_t max_bins,
-    common::Span<FeatureType const> feature_types, bool use_group,
-    int32_t n_threads)
+template <typename WQSketch>
+SketchContainerImpl<WQSketch>::SketchContainerImpl(std::vector<bst_row_t> columns_size,
+                                                   int32_t max_bins,
+                                                   common::Span<FeatureType const> feature_types,
+                                                   bool use_group, int32_t n_threads)
     : feature_types_(feature_types.cbegin(), feature_types.cend()),
-      columns_size_{std::move(columns_size)}, max_bins_{max_bins},
-      use_group_ind_{use_group}, n_threads_{n_threads} {
+      columns_size_{std::move(columns_size)},
+      max_bins_{max_bins},
+      use_group_ind_{use_group},
+      n_threads_{n_threads} {
   monitor_.Init(__func__);
   CHECK_NE(columns_size_.size(), 0);
   sketches_.resize(columns_size_.size());
   CHECK_GE(n_threads_, 1);
   categories_.resize(columns_size_.size());
-  ParallelFor(sketches_.size(), n_threads_, Sched::Auto(), [&](auto i) {
-    auto n_bins = std::min(static_cast<size_t>(max_bins_), columns_size_[i]);
-    n_bins = std::max(n_bins, static_cast<decltype(n_bins)>(1));
-    auto eps = 1.0 / (static_cast<float>(n_bins) * WQSketch::kFactor);
-    if (!IsCat(this->feature_types_, i)) {
-      sketches_[i].Init(columns_size_[i], eps);
-      sketches_[i].inqueue.queue.resize(sketches_[i].limit_size * 2);
-    }
-  });
 }
 
-std::vector<bst_row_t>
-HostSketchContainer::CalcColumnSize(SparsePage const &batch,
-                                    bst_feature_t const n_columns,
-                                    size_t const nthreads) {
+template <typename WQSketch>
+std::vector<bst_row_t> SketchContainerImpl<WQSketch>::CalcColumnSize(SparsePage const &batch,
+                                                                     bst_feature_t const n_columns,
+                                                                     size_t const nthreads) {
   auto page = batch.GetView();
@@ -45,7 +39,7 @@ HostSketchContainer::CalcColumnSize(SparsePage const &batch,
     column.resize(n_columns, 0);
   }
 
-  ParallelFor(omp_ulong(page.Size()), nthreads, [&](omp_ulong i) {
+  ParallelFor(page.Size(), nthreads, [&](omp_ulong i) {
     auto &local_column_sizes = column_sizes.at(omp_get_thread_num());
     auto row = page[i];
    auto const *p_row = row.data();
@@ -54,7 +48,7 @@ HostSketchContainer::CalcColumnSize(SparsePage const &batch,
     }
   });
   std::vector<bst_row_t> entries_per_columns(n_columns, 0);
-  ParallelFor(bst_omp_uint(n_columns), nthreads, [&](bst_omp_uint i) {
+  ParallelFor(n_columns, nthreads, [&](bst_omp_uint i) {
     for (auto const &thread : column_sizes) {
       entries_per_columns[i] += thread[i];
     }
@@ -62,8 +56,10 @@ HostSketchContainer::CalcColumnSize(SparsePage const &batch,
   return entries_per_columns;
 }
 
-std::vector<bst_feature_t> HostSketchContainer::LoadBalance(
-    SparsePage const &batch, bst_feature_t n_columns, size_t const nthreads) {
+template <typename WQSketch>
+std::vector<bst_feature_t> SketchContainerImpl<WQSketch>::LoadBalance(SparsePage const &batch,
+                                                                      bst_feature_t n_columns,
+                                                                      size_t const nthreads) {
   /* Some sparse datasets have their mass concentrating on small number of features. To
    * avoid waiting for a few threads running forever, we here distribute different number
    * of columns to different threads according to number of entries.
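CalcColumnSize and LoadBalance together implement a two-step pattern: threads count entries into private buffers that are then reduced, and columns are partitioned so each thread receives roughly the same number of entries rather than the same number of columns. A self-contained OpenMP sketch of both steps (all names here are illustrative, not XGBoost's):

#include <omp.h>
#include <cstddef>
#include <vector>

// rows: per-row column indices (CSR-like). Returns entries per column.
std::vector<std::size_t> CountColumns(std::vector<std::vector<std::size_t>> const& rows,
                                      std::size_t n_columns) {
  std::vector<std::vector<std::size_t>> tloc(omp_get_max_threads(),
                                             std::vector<std::size_t>(n_columns, 0));
#pragma omp parallel for
  for (std::ptrdiff_t i = 0; i < static_cast<std::ptrdiff_t>(rows.size()); ++i) {
    auto& local = tloc[omp_get_thread_num()];  // thread-private counts, no sharing
    for (auto c : rows[i]) ++local[c];
  }
  std::vector<std::size_t> total(n_columns, 0);
  for (auto const& local : tloc)  // reduce the per-thread buffers
    for (std::size_t c = 0; c < n_columns; ++c) total[c] += local[c];
  return total;
}

// Assign contiguous column ranges so every thread gets ~equal entries.
std::vector<std::size_t> BalanceColumns(std::vector<std::size_t> const& entries,
                                        std::size_t n_threads) {
  std::size_t total = 0;
  for (auto e : entries) total += e;
  std::vector<std::size_t> boundaries{0};  // start column of each thread's range
  std::size_t per_thread = total / n_threads + 1, acc = 0;
  for (std::size_t c = 0; c < entries.size(); ++c) {
    acc += entries[c];
    if (acc >= per_thread && boundaries.size() < n_threads) {
      boundaries.push_back(c + 1);
      acc = 0;
    }
  }
  while (boundaries.size() <= n_threads) boundaries.push_back(entries.size());
  return boundaries;  // thread t owns columns [boundaries[t], boundaries[t+1])
}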
@@ -101,9 +97,8 @@ std::vector<bst_feature_t> HostSketchContainer::LoadBalance(
 namespace {
 // Function to merge hessian and sample weights
-std::vector<float> MergeWeights(MetaInfo const &info,
-                                Span<float> const hessian,
-                                bool use_group, int32_t n_threads) {
+std::vector<float> MergeWeights(MetaInfo const &info, Span<float const> hessian, bool use_group,
+                                int32_t n_threads) {
   CHECK_EQ(hessian.size(), info.num_row_);
   std::vector<float> results(hessian.size());
   auto const &group_ptr = info.group_ptr_;
@@ -148,8 +143,9 @@ std::vector<float> UnrollGroupWeights(MetaInfo const &info) {
 }
 }  // anonymous namespace
 
-void HostSketchContainer::PushRowPage(
-    SparsePage const &page, MetaInfo const &info, Span<float> hessian) {
+template <typename WQSketch>
+void SketchContainerImpl<WQSketch>::PushRowPage(SparsePage const &page, MetaInfo const &info,
+                                                Span<float const> hessian) {
   monitor_.Start(__func__);
   bst_feature_t n_columns = info.num_col_;
   auto is_dense = info.num_nonzero_ == info.num_col_ * info.num_row_;
@@ -216,11 +212,12 @@ void HostSketchContainer::PushRowPage(
   monitor_.Stop(__func__);
 }
 
-void HostSketchContainer::GatherSketchInfo(
-    std::vector<WQSketch::SummaryContainer> const &reduced,
+template <typename WQSketch>
+void SketchContainerImpl<WQSketch>::GatherSketchInfo(
+    std::vector<typename WQSketch::SummaryContainer> const &reduced,
     std::vector<size_t> *p_worker_segments,
     std::vector<bst_row_t> *p_sketches_scan,
-    std::vector<WQSketch::Entry> *p_global_sketches) {
+    std::vector<typename WQSketch::Entry> *p_global_sketches) {
   auto& worker_segments = *p_worker_segments;
   worker_segments.resize(1, 0);
   auto world = rabit::GetWorldSize();
@@ -251,8 +248,8 @@ void HostSketchContainer::GatherSketchInfo(
   auto total = worker_segments.back();
   auto& global_sketches = *p_global_sketches;
-  global_sketches.resize(total, WQSketch::Entry{0, 0, 0, 0});
-  auto worker_sketch = Span<WQSketch::Entry>{global_sketches}.subspan(
+  global_sketches.resize(total, typename WQSketch::Entry{0, 0, 0, 0});
+  auto worker_sketch = Span<typename WQSketch::Entry>{global_sketches}.subspan(
       worker_segments[rank], worker_segments[rank + 1] - worker_segments[rank]);
   size_t cursor = 0;
   for (auto const &sketch : reduced) {
@@ -261,14 +258,15 @@ void HostSketchContainer::GatherSketchInfo(
     cursor += sketch.size;
   }
 
-  static_assert(sizeof(WQSketch::Entry) / 4 == sizeof(float), "");
+  static_assert(sizeof(typename WQSketch::Entry) / 4 == sizeof(float), "");
   rabit::Allreduce<rabit::op::Sum>(
       reinterpret_cast<float *>(global_sketches.data()),
-      global_sketches.size() * sizeof(WQSketch::Entry) / sizeof(float));
+      global_sketches.size() * sizeof(typename WQSketch::Entry) / sizeof(float));
 }
 
-void HostSketchContainer::AllReduce(
-    std::vector<WQSketch::SummaryContainer> *p_reduced,
+template <typename WQSketch>
+void SketchContainerImpl<WQSketch>::AllReduce(
+    std::vector<typename WQSketch::SummaryContainer> *p_reduced,
     std::vector<int32_t>* p_num_cuts) {
   monitor_.Start(__func__);
   auto& num_cuts = *p_num_cuts;
@@ -291,7 +289,7 @@ void HostSketchContainer::AllReduce(
         std::min(global_column_size[i],
                  static_cast<size_t>(max_bins_ * WQSketch::kFactor)));
     if (global_column_size[i] != 0) {
-      WQSketch::SummaryContainer out;
+      typename WQSketch::SummaryContainer out;
       sketches_[i].GetSummary(&out);
       reduced[i].Reserve(intermediate_num_cuts);
       CHECK(reduced[i].data);
@@ -309,11 +307,11 @@ void HostSketchContainer::AllReduce(
   std::vector<size_t> worker_segments(1, 0);  // CSC pointer to sketches.
   std::vector<bst_row_t> sketches_scan((n_columns + 1) * world, 0);
-  std::vector<WQSketch::Entry> global_sketches;
+  std::vector<typename WQSketch::Entry> global_sketches;
   this->GatherSketchInfo(reduced, &worker_segments, &sketches_scan,
                          &global_sketches);
 
-  std::vector<WQSketch::SummaryContainer> final_sketches(n_columns);
+  std::vector<typename WQSketch::SummaryContainer> final_sketches(n_columns);
   ParallelFor(n_columns, n_threads_, [&](auto fidx) {
     int32_t intermediate_num_cuts = num_cuts[fidx];
     auto nbytes =
@@ -321,8 +319,8 @@ void HostSketchContainer::AllReduce(
     for (int32_t i = 1; i < world + 1; ++i) {
       auto size = worker_segments.at(i) - worker_segments[i - 1];
-      auto worker_sketches = Span<WQSketch::Entry>{global_sketches}.subspan(
-          worker_segments[i - 1], size);
+      auto worker_sketches =
+          Span<typename WQSketch::Entry>{global_sketches}.subspan(worker_segments[i - 1], size);
       auto worker_scan =
           Span<bst_row_t>(sketches_scan)
               .subspan((i - 1) * (n_columns + 1), (n_columns + 1));
@@ -330,8 +328,7 @@ void HostSketchContainer::AllReduce(
       auto worker_feature = worker_sketches.subspan(
          worker_scan[fidx], worker_scan[fidx + 1] - worker_scan[fidx]);
       CHECK(worker_feature.data());
-      WQSummary<float, float> summary(worker_feature.data(),
-                                      worker_feature.size());
+      typename WQSketch::Summary summary(worker_feature.data(), worker_feature.size());
       auto &out = final_sketches.at(fidx);
       out.Reduce(summary, nbytes);
     }
@@ -342,8 +339,9 @@ void HostSketchContainer::AllReduce(
   monitor_.Stop(__func__);
 }
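The distributed merge leans on each sketch Entry being exactly four floats, which the static_assert in GatherSketchInfo above pins down: every worker writes only its own segment of a zero-initialized global buffer, so an element-wise float sum across workers behaves like an all-gather. A standalone illustration (FakeAllreduceSum stands in for rabit::Allreduce, and the Entry layout is an assumption mirroring the code above):

#include <cstddef>
#include <vector>

struct Entry {  // the four-float layout the static_assert above guarantees
  float rmin, rmax, wmin, value;
};
static_assert(sizeof(Entry) / 4 == sizeof(float), "Entry must be four floats");

// Stand-in for rabit::Allreduce<rabit::op::Sum>: in a real cluster this sums
// `n` floats element-wise across all workers.
void FakeAllreduceSum(float* data, std::size_t n) { (void)data; (void)n; }

void ShareSketches(std::vector<Entry>* entries) {
  // Each worker has written its own segment; everyone else's segment is zero,
  // so an element-wise sum reconstructs the full gathered buffer everywhere.
  FakeAllreduceSum(reinterpret_cast<float*>(entries->data()),
                   entries->size() * sizeof(Entry) / sizeof(float));
}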
-void AddCutPoint(WQuantileSketch<float, float>::SummaryContainer const &summary,
-                 int max_bin, HistogramCuts *cuts) {
+template <typename SketchType>
+void AddCutPoint(typename SketchType::SummaryContainer const &summary, int max_bin,
+                 HistogramCuts *cuts) {
   size_t required_cuts = std::min(summary.size, static_cast<size_t>(max_bin));
   auto &cut_values = cuts->cut_values_.HostVector();
   for (size_t i = 1; i < required_cuts; ++i) {
@@ -361,20 +359,21 @@ void AddCategories(std::set<bst_cat_t> const &categories, HistogramCuts *cuts) {
   }
 }
 
-void HostSketchContainer::MakeCuts(HistogramCuts* cuts) {
+template <typename WQSketch>
+void SketchContainerImpl<WQSketch>::MakeCuts(HistogramCuts* cuts) {
   monitor_.Start(__func__);
-  std::vector<WQSketch::SummaryContainer> reduced;
+  std::vector<typename WQSketch::SummaryContainer> reduced;
   std::vector<int32_t> num_cuts;
   this->AllReduce(&reduced, &num_cuts);
 
   cuts->min_vals_.HostVector().resize(sketches_.size(), 0.0f);
-  std::vector<WQSketch::SummaryContainer> final_summaries(reduced.size());
+  std::vector<typename WQSketch::SummaryContainer> final_summaries(reduced.size());
 
   ParallelFor(reduced.size(), n_threads_, Sched::Guided(), [&](size_t fidx) {
     if (IsCat(feature_types_, fidx)) {
       return;
     }
-    WQSketch::SummaryContainer &a = final_summaries[fidx];
+    typename WQSketch::SummaryContainer &a = final_summaries[fidx];
     size_t max_num_bins = std::min(num_cuts[fidx], max_bins_);
     a.Reserve(max_num_bins + 1);
     CHECK(a.data);
@@ -392,11 +391,11 @@ void HostSketchContainer::MakeCuts(HistogramCuts* cuts) {
   for (size_t fid = 0; fid < reduced.size(); ++fid) {
     size_t max_num_bins = std::min(num_cuts[fid], max_bins_);
-    WQSketch::SummaryContainer const& a = final_summaries[fid];
+    typename WQSketch::SummaryContainer const& a = final_summaries[fid];
     if (IsCat(feature_types_, fid)) {
       AddCategories(categories_.at(fid), cuts);
     } else {
-      AddCutPoint(a, max_num_bins, cuts);
+      AddCutPoint<WQSketch>(a, max_num_bins, cuts);
       // push a value that is greater than anything
       const bst_float cpt = (a.size > 0) ? a.data[a.size - 1].value
                                          : cuts->min_vals_.HostVector()[fid];
@@ -413,5 +412,64 @@ void HostSketchContainer::MakeCuts(HistogramCuts* cuts) {
   }
   monitor_.Stop(__func__);
 }
+template class SketchContainerImpl<WQuantileSketch<float, float>>;
+template class SketchContainerImpl<WXQuantileSketch<float, float>>;
+
+HostSketchContainer::HostSketchContainer(int32_t max_bins, MetaInfo const &info,
+                                         std::vector<size_t> columns_size, bool use_group,
+                                         Span<float const> hessian, int32_t n_threads)
+    : SketchContainerImpl{columns_size, max_bins, info.feature_types.ConstHostSpan(), use_group,
+                          n_threads} {
+  monitor_.Init(__func__);
+  ParallelFor(sketches_.size(), n_threads_, Sched::Auto(), [&](auto i) {
+    auto n_bins = std::min(static_cast<size_t>(max_bins_), columns_size_[i]);
+    n_bins = std::max(n_bins, static_cast<decltype(n_bins)>(1));
+    auto eps = 1.0 / (static_cast<float>(n_bins) * WQSketch::kFactor);
+    if (!IsCat(this->feature_types_, i)) {
+      sketches_[i].Init(columns_size_[i], eps);
+      sketches_[i].inqueue.queue.resize(sketches_[i].limit_size * 2);
+    }
+  });
+}
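The per-column epsilon ties sketch accuracy to the bin budget: more bins demand a finer rank error. A standalone illustration of the arithmetic, assuming kFactor = 8 for WQSketch::kFactor (check the header for the actual constant):

#include <algorithm>
#include <cstddef>
#include <iostream>

int main() {
  std::size_t max_bins = 256, column_size = 1u << 20;  // example inputs
  double const kFactor = 8.0;  // assumed value of WQSketch::kFactor
  auto n_bins = std::max<std::size_t>(std::min(max_bins, column_size), 1);
  double eps = 1.0 / (static_cast<double>(n_bins) * kFactor);
  // 256 bins -> eps = 1/2048: quantile ranks are accurate to ~0.05% of the
  // total weight, comfortably below the 1/256 spacing between cut points.
  std::cout << "eps = " << eps << '\n';
}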
+void SortedSketchContainer::PushColPage(SparsePage const &page, MetaInfo const &info,
+                                        Span<float const> hessian) {
+  monitor_.Start(__func__);
+  // glue these conditions using the ternary operator to avoid making data copies.
+  auto const &weights =
+      hessian.empty() ? (use_group_ind_ ? UnrollGroupWeights(info)    // use group weight
+                                        : info.weights_.HostVector())  // use sample weight
+                      : MergeWeights(info, hessian, use_group_ind_,
+                                     n_threads_);  // use hessian merged with group/sample weights
+  CHECK_EQ(weights.size(), info.num_row_);
+
+  auto view = page.GetView();
+  ParallelFor(view.Size(), n_threads_, [&](size_t fidx) {
+    auto column = view[fidx];
+    auto &sketch = sketches_[fidx];
+    sketch.Init(max_bins_);
+    // first pass: total weight is needed before Push can set its rank goals
+    sketch.sum_total = 0.0;
+    for (auto c : column) {
+      sketch.sum_total += weights[c.index];
+    }
+    // second pass
+    if (IsCat(feature_types_, fidx)) {
+      for (auto c : column) {
+        categories_[fidx].emplace(AsCat(c.fvalue));
+      }
+    } else {
+      for (auto c : column) {
+        sketch.Push(c.fvalue, weights[c.index], max_bins_);
+      }
+    }
+    if (!IsCat(feature_types_, fidx) && !column.empty()) {
+      sketch.Finalize(max_bins_);
+    }
+  });
+  monitor_.Stop(__func__);
+}
 }  // namespace common
 }  // namespace xgboost

View File

@@ -702,11 +702,9 @@ class HistogramCuts;
 /*!
  * A sketch matrix storing sketches for each feature.
  */
-class HostSketchContainer {
- public:
-  using WQSketch = WQuantileSketch<float, float>;
-
- private:
+template <typename WQSketch>
+class SketchContainerImpl {
+ protected:
   std::vector<WQSketch> sketches_;
   std::vector<std::set<bst_cat_t>> categories_;
   std::vector<FeatureType> const feature_types_;
@@ -724,7 +722,7 @@ class HostSketchContainer {
    * \param max_bins maximum number of bins for each feature.
    * \param use_group whether is assigned to group to data instance.
    */
-  HostSketchContainer(std::vector<bst_row_t> columns_size, int32_t max_bins,
-                      common::Span<FeatureType const> feature_types, bool use_group,
-                      int32_t n_threads);
+  SketchContainerImpl(std::vector<bst_row_t> columns_size, int32_t max_bins,
+                      common::Span<FeatureType const> feature_types, bool use_group,
+                      int32_t n_threads);
@@ -755,20 +753,139 @@ class HostSketchContainer {
     return group_ind;
   }
   // Gather sketches from all workers.
-  void GatherSketchInfo(std::vector<WQSketch::SummaryContainer> const &reduced,
+  void GatherSketchInfo(std::vector<typename WQSketch::SummaryContainer> const &reduced,
                         std::vector<bst_row_t> *p_worker_segments,
                         std::vector<bst_row_t> *p_sketches_scan,
-                        std::vector<WQSketch::Entry> *p_global_sketches);
+                        std::vector<typename WQSketch::Entry> *p_global_sketches);
   // Merge sketches from all workers.
-  void AllReduce(std::vector<WQSketch::SummaryContainer> *p_reduced,
+  void AllReduce(std::vector<typename WQSketch::SummaryContainer> *p_reduced,
                  std::vector<int32_t> *p_num_cuts);
   /* \brief Push a CSR matrix. */
-  void PushRowPage(SparsePage const &page, MetaInfo const &info,
-                   Span<float> const hessian = {});
+  void PushRowPage(SparsePage const &page, MetaInfo const &info, Span<float const> hessian = {});
 
   void MakeCuts(HistogramCuts* cuts);
 };
+
+class HostSketchContainer : public SketchContainerImpl<WQuantileSketch<float, float>> {
+ public:
+  using WQSketch = WQuantileSketch<float, float>;
+
+ public:
+  HostSketchContainer(int32_t max_bins, MetaInfo const &info, std::vector<size_t> columns_size,
+                      bool use_group, Span<float const> hessian, int32_t n_threads);
+};
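The shape of the refactor in one screen: a single templated base owns the per-feature sketches plus the merge and allreduce machinery, and thin subclasses pick the sketch type. A minimal sketch of that structure (names illustrative, not XGBoost's):

#include <cstddef>
#include <vector>

// Stand-ins for XGBoost's WQuantileSketch / WXQuantileSketch.
struct UnsortedSketch {};
struct SortedFriendlySketch {};

template <typename Sketch>
class ContainerBase {
 protected:
  std::vector<Sketch> sketches_;  // one sketch per feature; shared machinery
 public:
  explicit ContainerBase(std::size_t n_features) : sketches_(n_features) {}
};

// Row-wise (CSR) input keeps the weighted-quantile sketch...
class RowContainer : public ContainerBase<UnsortedSketch> {
  using ContainerBase::ContainerBase;
};
// ...while sorted CSC input can use the cheaper streaming variant.
class ColumnContainer : public ContainerBase<SortedFriendlySketch> {
  using ContainerBase::ContainerBase;
};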
+/**
+ * \brief Quantile structure that accepts sorted data, extracted from the hist maker.
+ */
+struct SortedQuantile {
+  /*! \brief total sum of amount to be met */
+  double sum_total{0.0};
+  /*! \brief statistics used in the sketch */
+  double rmin, wmin;
+  /*! \brief last seen feature value */
+  bst_float last_fvalue;
+  /*! \brief current size of sketch */
+  double next_goal;
+  // pointer to the sketch to put things in
+  common::WXQuantileSketch<bst_float, bst_float>* sketch;
+
+  // initialize the space
+  inline void Init(unsigned max_size) {
+    next_goal = -1.0f;
+    rmin = wmin = 0.0f;
+    sketch->temp.Reserve(max_size + 1);
+    sketch->temp.size = 0;
+  }
+  /*!
+   * \brief push a new element to sketch
+   * \param fvalue feature value, comes in sorted ascending order
+   * \param w weight
+   * \param max_size
+   */
+  inline void Push(bst_float fvalue, bst_float w, unsigned max_size) {
+    if (next_goal == -1.0f) {
+      next_goal = 0.0f;
+      last_fvalue = fvalue;
+      wmin = w;
+      return;
+    }
+    if (last_fvalue != fvalue) {
+      double rmax = rmin + wmin;
+      if (rmax >= next_goal && sketch->temp.size != max_size) {
+        if (sketch->temp.size == 0 ||
+            last_fvalue > sketch->temp.data[sketch->temp.size - 1].value) {
+          // push to sketch
+          sketch->temp.data[sketch->temp.size] =
+              common::WXQuantileSketch<bst_float, bst_float>::Entry(
+                  static_cast<bst_float>(rmin), static_cast<bst_float>(rmax),
+                  static_cast<bst_float>(wmin), last_fvalue);
+          CHECK_LT(sketch->temp.size, max_size) << "invalid maximum size max_size=" << max_size
+                                                << ", temp.size=" << sketch->temp.size;
+          ++sketch->temp.size;
+        }
+        if (sketch->temp.size == max_size) {
+          next_goal = sum_total * 2.0f + 1e-5f;
+        } else {
+          next_goal = static_cast<bst_float>(sketch->temp.size * sum_total / max_size);
+        }
+      } else {
+        if (rmax >= next_goal) {
+          LOG(DEBUG) << "INFO: rmax=" << rmax << ", sum_total=" << sum_total
+                     << ", next_goal=" << next_goal << ", size=" << sketch->temp.size;
+        }
+      }
+      rmin = rmax;
+      wmin = w;
+      last_fvalue = fvalue;
+    } else {
+      wmin += w;
+    }
+  }
+  /*! \brief push final unfinished value to the sketch */
+  inline void Finalize(unsigned max_size) {
+    double rmax = rmin + wmin;
+    if (sketch->temp.size == 0 || last_fvalue > sketch->temp.data[sketch->temp.size - 1].value) {
+      CHECK_LE(sketch->temp.size, max_size)
+          << "Finalize: invalid maximum size, max_size=" << max_size
+          << ", temp.size=" << sketch->temp.size;
+      // push to sketch
+      sketch->temp.data[sketch->temp.size] = common::WXQuantileSketch<bst_float, bst_float>::Entry(
+          static_cast<bst_float>(rmin), static_cast<bst_float>(rmax), static_cast<bst_float>(wmin),
+          last_fvalue);
+      ++sketch->temp.size;
+    }
+    sketch->PushTemp();
+  }
+};
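SortedQuantile exploits the input order: because values arrive ascending, the running rank rmin + wmin is exact, so an entry can be emitted whenever the rank crosses the next evenly spaced goal. A self-contained sketch of that selection logic on a plain sorted vector (an illustrative simplification; it keeps only the values, not the full rmin/rmax/wmin entries):

#include <cstddef>
#include <vector>

// Pick at most max_size approximate weighted quantiles from a sorted column.
// Two passes, as in PushColPage above: sum the weights first, then emit a
// value each time the accumulated rank crosses the next goal.
std::vector<float> SortedQuantiles(std::vector<float> const& values,
                                   std::vector<float> const& weights, std::size_t max_size) {
  double sum_total = 0.0;
  for (auto w : weights) sum_total += w;  // first pass

  std::vector<float> out;
  double rank = 0.0, next_goal = 0.0;
  for (std::size_t i = 0; i < values.size(); ++i) {  // second pass
    rank += weights[i];
    if (rank >= next_goal && out.size() < max_size &&
        (out.empty() || values[i] > out.back())) {  // skip duplicate values
      out.push_back(values[i]);
      next_goal = static_cast<double>(out.size()) * sum_total / max_size;
    }
  }
  return out;
}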
+class SortedSketchContainer : public SketchContainerImpl<WXQuantileSketch<float, float>> {
+  std::vector<SortedQuantile> sketches_;
+  using Super = SketchContainerImpl<WXQuantileSketch<float, float>>;
+
+ public:
+  explicit SortedSketchContainer(int32_t max_bins, MetaInfo const &info,
+                                 std::vector<size_t> columns_size, bool use_group,
+                                 Span<float const> hessian, int32_t n_threads)
+      : SketchContainerImpl{columns_size, max_bins, info.feature_types.ConstHostSpan(), use_group,
+                            n_threads} {
+    monitor_.Init(__func__);
+    sketches_.resize(info.num_col_);
+    size_t i = 0;
+    for (auto &sketch : sketches_) {
+      sketch.sketch = &Super::sketches_[i];
+      sketch.Init(max_bins_);
+      auto eps = 2.0 / max_bins;
+      sketch.sketch->Init(columns_size_[i], eps);
+      ++i;
+    }
+  }
+  /**
+   * \brief Push a sorted CSC page.
+   */
+  void PushColPage(SparsePage const &page, MetaInfo const &info, Span<float const> hessian);
+};
 }  // namespace common
 }  // namespace xgboost
 #endif  // XGBOOST_COMMON_QUANTILE_H_

View File

@@ -118,7 +118,7 @@ void GHistIndexMatrix::PushBatch(SparsePage const &batch,
                     [](auto idx, auto) { return idx; });
   }
 
-  common::ParallelFor(bst_omp_uint(nbins), n_threads, [&](bst_omp_uint idx) {
+  common::ParallelFor(nbins, n_threads, [&](bst_omp_uint idx) {
     for (int32_t tid = 0; tid < n_threads; ++tid) {
       hit_count[idx] += hit_count_tloc_[tid * nbins + idx];
       hit_count_tloc_[tid * nbins + idx] = 0;  // reset for next batch
@@ -126,8 +126,11 @@ void GHistIndexMatrix::PushBatch(SparsePage const &batch,
   });
 }
 
-void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_bins, common::Span<float> hess) {
-  cut = common::SketchOnDMatrix(p_fmat, max_bins, hess);
+void GHistIndexMatrix::Init(DMatrix *p_fmat, int max_bins, bool sorted_sketch,
+                            common::Span<float> hess) {
+  // We use sorted sketching for approx tree method since it's more efficient in
+  // computation time (but higher memory usage).
+  cut = common::SketchOnDMatrix(p_fmat, max_bins, sorted_sketch, hess);
   max_num_bins = max_bins;
   const int32_t nthread = omp_get_max_threads();
View File

@@ -37,14 +37,14 @@ class GHistIndexMatrix {
   size_t base_rowid{0};
 
   GHistIndexMatrix() = default;
-  GHistIndexMatrix(DMatrix* x, int32_t max_bin, common::Span<float> hess = {}) {
-    this->Init(x, max_bin, hess);
+  GHistIndexMatrix(DMatrix* x, int32_t max_bin, bool sorted_sketch, common::Span<float> hess = {}) {
+    this->Init(x, max_bin, sorted_sketch, hess);
   }
   // Create a global histogram matrix, given cut
-  void Init(DMatrix* p_fmat, int max_num_bins, common::Span<float> hess);
+  void Init(DMatrix* p_fmat, int max_num_bins, bool sorted_sketch, common::Span<float> hess);
   void Init(SparsePage const& page, common::Span<FeatureType const> ft,
-            common::HistogramCuts const &cuts, int32_t max_bins_per_feat,
-            bool is_dense, int32_t n_threads);
+            common::HistogramCuts const& cuts, int32_t max_bins_per_feat, bool is_dense,
+            int32_t n_threads);
 
   // specific method for sparse data as no possibility to reduce allocated memory
   template <typename BinIdxType, typename GetOffset>
@@ -57,7 +59,9 @@ class GHistIndexMatrix {
     const size_t batch_size = batch.Size();
     CHECK_LT(batch_size, offset_vec.size());
     BinIdxType* index_data = index_data_span.data();
-    common::ParallelFor(omp_ulong(batch_size), batch_threads, [&](omp_ulong i) {
+    auto const& ptrs = cut.Ptrs();
+    auto const& values = cut.Values();
+    common::ParallelFor(batch_size, batch_threads, [&](omp_ulong i) {
       const int tid = omp_get_thread_num();
       size_t ibegin = row_ptr[rbegin + i];
       size_t iend = row_ptr[rbegin + i + 1];
@@ -71,7 +73,7 @@ class GHistIndexMatrix {
           index_data[ibegin + j] = get_offset(bin_idx, j);
           ++hit_count_tloc_[tid * nbins + bin_idx];
         } else {
-          uint32_t idx = cut.SearchBin(inst[j]);
+          uint32_t idx = cut.SearchBin(inst[j].fvalue, inst[j].index, ptrs, values);
           index_data[ibegin + j] = get_offset(idx, j);
           ++hit_count_tloc_[tid * nbins + idx];
         }
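Note how the new overload pays off here: cut.Ptrs() and cut.Values() are fetched once outside the parallel loop instead of once per element inside SearchBin. Reusing the SearchBinDemo sketch from earlier, the hoisted pattern looks like this (CsrEntry is an illustrative stand-in for XGBoost's Entry):

#include <cstdint>
#include <vector>

struct CsrEntry {
  uint32_t index;  // feature id
  float fvalue;    // feature value
};

// Bin every entry of a CSR batch; the cut arrays are looked up once by the
// caller and passed through, matching the hoisting in the diff above.
std::vector<uint32_t> BinBatch(std::vector<CsrEntry> const& entries,
                               std::vector<uint32_t> const& cut_ptrs,
                               std::vector<float> const& cut_values) {
  std::vector<uint32_t> out;
  out.reserve(entries.size());
  for (auto const& e : entries) {
    out.push_back(SearchBinDemo(cut_ptrs, cut_values, e.index, e.fvalue));
  }
  return out;
}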

View File

@@ -97,7 +97,9 @@ BatchSet<GHistIndexMatrix> SimpleDMatrix::GetGradientIndex(const BatchParam& param) {
   if (!gradient_index_ || (batch_param_ != param && param != BatchParam{}) || param.regen) {
     CHECK_GE(param.max_bin, 2);
     CHECK_EQ(param.gpu_id, -1);
-    gradient_index_.reset(new GHistIndexMatrix(this, param.max_bin, param.hess));
+    // Used only by approx.
+    auto sorted_sketch = param.regen;
+    gradient_index_.reset(new GHistIndexMatrix(this, param.max_bin, sorted_sketch, param.hess));
     batch_param_ = param;
     CHECK_EQ(batch_param_.hess.data(), param.hess.data());
   }

View File

@@ -159,12 +159,12 @@ BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches() {
 
 BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(const BatchParam& param) {
   CHECK_GE(param.max_bin, 2);
-  if (param.hess.empty()) {
+  if (param.hess.empty() && !param.regen) {
     // hist method doesn't support full external memory implementation, so we concatenate
     // all index here.
     if (!ghist_index_page_ || (param != batch_param_ && param != BatchParam{})) {
       this->InitializeSparsePage();
-      ghist_index_page_.reset(new GHistIndexMatrix{this, param.max_bin});
+      ghist_index_page_.reset(new GHistIndexMatrix{this, param.max_bin, param.regen});
       this->InitializeSparsePage();
       batch_param_ = param;
     }
@@ -175,18 +175,21 @@ BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(const BatchParam&
   auto id = MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_);
   this->InitializeSparsePage();
-  if (!cache_info_.at(id)->written || (batch_param_ != param && param != BatchParam{})) {
+  if (!cache_info_.at(id)->written || (batch_param_ != param && param != BatchParam{}) ||
+      param.regen) {
     cache_info_.erase(id);
     MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_);
-    auto cuts = common::SketchOnDMatrix(this, param.max_bin, param.hess);
+    // Use sorted sketch for approx.
+    auto sorted_sketch = param.regen;
+    auto cuts = common::SketchOnDMatrix(this, param.max_bin, sorted_sketch, param.hess);
     this->InitializeSparsePage();  // reset after use.
     batch_param_ = param;
 
     ghist_index_source_.reset();
     CHECK_NE(cuts.Values().size(), 0);
     auto ft = this->info_.feature_types.ConstHostSpan();
-    ghist_index_source_.reset(new GradientIndexPageSource(
-        this->missing_, this->ctx_.Threads(), this->Info().num_col_,
-        this->n_batches_, cache_info_.at(id), param, std::move(cuts),
-        this->IsDense(), param.max_bin, ft, sparse_page_source_));
+    ghist_index_source_.reset(
+        new GradientIndexPageSource(this->missing_, this->ctx_.Threads(), this->Info().num_col_,
+                                    this->n_batches_, cache_info_.at(id), param, std::move(cuts),
+                                    this->IsDense(), param.max_bin, ft, sparse_page_source_));
   } else {
View File

@@ -369,92 +369,7 @@ class BaseMaker: public TreeUpdater {
       }
     }
   }
-  /*! \brief common helper data structure to build sketch */
-  struct SketchEntry {
-    /*! \brief total sum of amount to be met */
-    double sum_total;
-    /*! \brief statistics used in the sketch */
-    double rmin, wmin;
-    /*! \brief last seen feature value */
-    bst_float last_fvalue;
-    /*! \brief current size of sketch */
-    double next_goal;
-    // pointer to the sketch to put things in
-    common::WXQuantileSketch<bst_float, bst_float> *sketch;
-    // initialize the space
-    inline void Init(unsigned max_size) {
-      next_goal = -1.0f;
-      rmin = wmin = 0.0f;
-      sketch->temp.Reserve(max_size + 1);
-      sketch->temp.size = 0;
-    }
-    /*!
-     * \brief push a new element to sketch
-     * \param fvalue feature value, comes in sorted ascending order
-     * \param w weight
-     * \param max_size
-     */
-    inline void Push(bst_float fvalue, bst_float w, unsigned max_size) {
-      if (next_goal == -1.0f) {
-        next_goal = 0.0f;
-        last_fvalue = fvalue;
-        wmin = w;
-        return;
-      }
-      if (last_fvalue != fvalue) {
-        double rmax = rmin + wmin;
-        if (rmax >= next_goal && sketch->temp.size != max_size) {
-          if (sketch->temp.size == 0 ||
-              last_fvalue > sketch->temp.data[sketch->temp.size-1].value) {
-            // push to sketch
-            sketch->temp.data[sketch->temp.size] =
-                common::WXQuantileSketch<bst_float, bst_float>::
-                Entry(static_cast<bst_float>(rmin),
-                      static_cast<bst_float>(rmax),
-                      static_cast<bst_float>(wmin), last_fvalue);
-            CHECK_LT(sketch->temp.size, max_size)
-                << "invalid maximum size max_size=" << max_size
-                << ", stemp.size" << sketch->temp.size;
-            ++sketch->temp.size;
-          }
-          if (sketch->temp.size == max_size) {
-            next_goal = sum_total * 2.0f + 1e-5f;
-          } else {
-            next_goal = static_cast<bst_float>(sketch->temp.size * sum_total / max_size);
-          }
-        } else {
-          if (rmax >= next_goal) {
-            LOG(TRACKER) << "INFO: rmax=" << rmax
-                         << ", sum_total=" << sum_total
-                         << ", naxt_goal=" << next_goal
-                         << ", size=" << sketch->temp.size;
-          }
-        }
-        rmin = rmax;
-        wmin = w;
-        last_fvalue = fvalue;
-      } else {
-        wmin += w;
-      }
-    }
-    /*! \brief push final unfinished value to the sketch */
-    inline void Finalize(unsigned max_size) {
-      double rmax = rmin + wmin;
-      if (sketch->temp.size == 0 || last_fvalue > sketch->temp.data[sketch->temp.size-1].value) {
-        CHECK_LE(sketch->temp.size, max_size)
-            << "Finalize: invalid maximum size, max_size=" << max_size
-            << ", stemp.size=" << sketch->temp.size;
-        // push to sketch
-        sketch->temp.data[sketch->temp.size] =
-            common::WXQuantileSketch<bst_float, bst_float>::
-            Entry(static_cast<bst_float>(rmin),
-                  static_cast<bst_float>(rmax),
-                  static_cast<bst_float>(wmin), last_fvalue);
-        ++sketch->temp.size;
-      }
-      sketch->PushTemp();
-    }
-  };
+  using SketchEntry = common::SortedQuantile;
   /*! \brief training parameter of tree grower */
   TrainParam param_;
   /*! \brief queue of nodes to be expanded */
View File

@@ -14,7 +14,7 @@ TEST(DenseColumn, Test) {
                                  static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 2};
   for (size_t max_num_bin : max_num_bins) {
     auto dmat = RandomDataGenerator(100, 10, 0.0).GenerateDMatrix();
-    GHistIndexMatrix gmat(dmat.get(), max_num_bin);
+    GHistIndexMatrix gmat(dmat.get(), max_num_bin, false);
     ColumnMatrix column_matrix;
     column_matrix.Init(gmat, 0.2);
@@ -61,7 +61,7 @@ TEST(SparseColumn, Test) {
                                  static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 2};
   for (size_t max_num_bin : max_num_bins) {
     auto dmat = RandomDataGenerator(100, 1, 0.85).GenerateDMatrix();
-    GHistIndexMatrix gmat(dmat.get(), max_num_bin);
+    GHistIndexMatrix gmat(dmat.get(), max_num_bin, false);
     ColumnMatrix column_matrix;
     column_matrix.Init(gmat, 0.5);
     switch (column_matrix.GetTypeSize()) {
@@ -101,7 +101,7 @@ TEST(DenseColumnWithMissing, Test) {
                                  static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 2 };
   for (size_t max_num_bin : max_num_bins) {
     auto dmat = RandomDataGenerator(100, 1, 0.5).GenerateDMatrix();
-    GHistIndexMatrix gmat(dmat.get(), max_num_bin);
+    GHistIndexMatrix gmat(dmat.get(), max_num_bin, false);
     ColumnMatrix column_matrix;
     column_matrix.Init(gmat, 0.2);
     switch (column_matrix.GetTypeSize()) {
@@ -130,7 +130,7 @@ void TestGHistIndexMatrixCreation(size_t nthreads) {
   /* This should create multiple sparse pages */
   std::unique_ptr<DMatrix> dmat{ CreateSparsePageDMatrix(kEntries) };
   omp_set_num_threads(nthreads);
-  GHistIndexMatrix gmat(dmat.get(), 256);
+  GHistIndexMatrix gmat(dmat.get(), 256, false);
 }
 
 TEST(HistIndexCreationWithExternalMemory, Test) {

View File

@@ -223,13 +223,19 @@ TEST(HistUtil, DenseCutsAccuracyTestWeights) {
       auto w = GenerateRandomWeights(num_rows);
       dmat->Info().weights_.HostVector() = w;
       for (auto num_bins : bin_sizes) {
-        HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins);
-        ValidateCuts(cuts, dmat.get(), num_bins);
+        {
+          HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, true);
+          ValidateCuts(cuts, dmat.get(), num_bins);
+        }
+        {
+          HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, false);
+          ValidateCuts(cuts, dmat.get(), num_bins);
+        }
       }
     }
   }
 }
 
-TEST(HistUtil, QuantileWithHessian) {
+void TestQuantileWithHessian(bool use_sorted) {
   int bin_sizes[] = {2, 16, 256, 512};
   int sizes[] = {1000, 1500};
   int num_columns = 5;
@@ -243,13 +249,13 @@ TEST(HistUtil, QuantileWithHessian) {
     dmat->Info().weights_.HostVector() = w;
     for (auto num_bins : bin_sizes) {
-      HistogramCuts cuts_hess = SketchOnDMatrix(dmat.get(), num_bins, hessian);
+      HistogramCuts cuts_hess = SketchOnDMatrix(dmat.get(), num_bins, use_sorted, hessian);
       for (size_t i = 0; i < w.size(); ++i) {
         dmat->Info().weights_.HostVector()[i] = w[i] * hessian[i];
       }
       ValidateCuts(cuts_hess, dmat.get(), num_bins);
 
-      HistogramCuts cuts_wh = SketchOnDMatrix(dmat.get(), num_bins);
+      HistogramCuts cuts_wh = SketchOnDMatrix(dmat.get(), num_bins, use_sorted);
       ValidateCuts(cuts_wh, dmat.get(), num_bins);
       ASSERT_EQ(cuts_hess.Values().size(), cuts_wh.Values().size());
@@ -262,6 +268,11 @@
   }
 }
 
+TEST(HistUtil, QuantileWithHessian) {
+  TestQuantileWithHessian(true);
+  TestQuantileWithHessian(false);
+}
+
 TEST(HistUtil, DenseCutsExternalMemory) {
   int bin_sizes[] = {2, 16, 256, 512};
   int sizes[] = {100, 1000, 1500};
@@ -292,7 +303,7 @@ TEST(HistUtil, IndexBinBound) {
   for (auto max_bin : bin_sizes) {
     auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
-    GHistIndexMatrix hmat(p_fmat.get(), max_bin);
+    GHistIndexMatrix hmat(p_fmat.get(), max_bin, false);
     EXPECT_EQ(hmat.index.Size(), kRows*kCols);
     EXPECT_EQ(expected_bin_type_sizes[bin_id++], hmat.index.GetBinTypeSize());
   }
@@ -315,7 +326,7 @@ TEST(HistUtil, IndexBinData) {
   for (auto max_bin : kBinSizes) {
     auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
-    GHistIndexMatrix hmat(p_fmat.get(), max_bin);
+    GHistIndexMatrix hmat(p_fmat.get(), max_bin, false);
     uint32_t* offsets = hmat.index.Offset();
     EXPECT_EQ(hmat.index.Size(), kRows*kCols);
     switch (max_bin) {

View File

@@ -19,7 +19,22 @@ TEST(Quantile, LoadBalance) {
   }
   CHECK_EQ(n_cols, kCols);
 }
 
+namespace {
+template <bool use_column>
+using ContainerType = std::conditional_t<use_column, SortedSketchContainer, HostSketchContainer>;
+
+// Dispatch for push page.
+void PushPage(SortedSketchContainer* container, SparsePage const& page, MetaInfo const& info,
+              Span<float const> hessian) {
+  container->PushColPage(page, info, hessian);
+}
+void PushPage(HostSketchContainer* container, SparsePage const& page, MetaInfo const& info,
+              Span<float const> hessian) {
+  container->PushRowPage(page, info, hessian);
+}
+}  // anonymous namespace
+
+template <bool use_column>
 void TestDistributedQuantile(size_t rows, size_t cols) {
   std::string msg {"Skipping AllReduce test"};
   int32_t constexpr kWorkers = 4;
@@ -48,12 +63,23 @@ void TestDistributedQuantile(size_t rows, size_t cols) {
                .Lower(.0f)
                .Upper(1.0f)
                .GenerateDMatrix();
-  HostSketchContainer sketch_distributed(
-      column_size, n_bins, m->Info().feature_types.ConstHostSpan(), false,
-      OmpGetNumThreads(0));
-  for (auto const &page : m->GetBatches<SparsePage>()) {
-    sketch_distributed.PushRowPage(page, m->Info());
+
+  std::vector<float> hessian(rows, 1.0);
+  auto hess = Span<float const>{hessian};
+  ContainerType<use_column> sketch_distributed(n_bins, m->Info(), column_size, false, hess,
+                                               OmpGetNumThreads(0));
+
+  if (use_column) {
+    for (auto const& page : m->GetBatches<SortedCSCPage>()) {
+      PushPage(&sketch_distributed, page, m->Info(), hess);
+    }
+  } else {
+    for (auto const& page : m->GetBatches<SparsePage>()) {
+      PushPage(&sketch_distributed, page, m->Info(), hess);
+    }
   }
+
   HistogramCuts distributed_cuts;
   sketch_distributed.MakeCuts(&distributed_cuts);
@@ -61,17 +87,25 @@ void TestDistributedQuantile(size_t rows, size_t cols) {
   rabit::Finalize();
   CHECK_EQ(rabit::GetWorldSize(), 1);
 
   std::for_each(column_size.begin(), column_size.end(), [=](auto& size) { size *= world; });
-  HostSketchContainer sketch_on_single_node(
-      column_size, n_bins, m->Info().feature_types.ConstHostSpan(), false,
-      OmpGetNumThreads(0));
+  m->Info().num_row_ = world * rows;
+  ContainerType<use_column> sketch_on_single_node(n_bins, m->Info(), column_size, false, hess,
+                                                  OmpGetNumThreads(0));
+  m->Info().num_row_ = rows;
 
   for (auto rank = 0; rank < world; ++rank) {
     auto m = RandomDataGenerator{rows, cols, sparsity}
                  .Seed(rank)
                  .Lower(.0f)
                  .Upper(1.0f)
                  .GenerateDMatrix();
-    for (auto const& page : m->GetBatches<SparsePage>()) {
-      sketch_on_single_node.PushRowPage(page, m->Info());
+    if (use_column) {
+      for (auto const& page : m->GetBatches<SortedCSCPage>()) {
+        PushPage(&sketch_on_single_node, page, m->Info(), hess);
+      }
+    } else {
+      for (auto const& page : m->GetBatches<SparsePage>()) {
+        PushPage(&sketch_on_single_node, page, m->Info(), hess);
+      }
     }
   }
@@ -87,7 +121,7 @@ void TestDistributedQuantile(size_t rows, size_t cols) {
   ASSERT_EQ(sptrs.size(), dptrs.size());
   for (size_t i = 0; i < sptrs.size(); ++i) {
-    ASSERT_EQ(sptrs[i], dptrs[i]);
+    ASSERT_EQ(sptrs[i], dptrs[i]) << i;
   }
 
   ASSERT_EQ(svals.size(), dvals.size());
@@ -104,14 +138,28 @@ void TestDistributedQuantile(size_t rows, size_t cols) {
 TEST(Quantile, DistributedBasic) {
 #if defined(__unix__)
   constexpr size_t kRows = 10, kCols = 10;
-  TestDistributedQuantile(kRows, kCols);
+  TestDistributedQuantile<false>(kRows, kCols);
 #endif
 }
 
 TEST(Quantile, Distributed) {
 #if defined(__unix__)
-  constexpr size_t kRows = 1000, kCols = 200;
-  TestDistributedQuantile(kRows, kCols);
+  constexpr size_t kRows = 4000, kCols = 200;
+  TestDistributedQuantile<false>(kRows, kCols);
+#endif
+}
+
+TEST(Quantile, SortedDistributedBasic) {
+#if defined(__unix__)
+  constexpr size_t kRows = 10, kCols = 10;
+  TestDistributedQuantile<true>(kRows, kCols);
+#endif
+}
+
+TEST(Quantile, SortedDistributed) {
+#if defined(__unix__)
+  constexpr size_t kRows = 4000, kCols = 200;
+  TestDistributedQuantile<true>(kRows, kCols);
+#endif
 }

View File

@@ -36,7 +36,7 @@ TEST(GradientIndex, FromCategoricalBasic) {
   BatchParam p(0, max_bins);
   GHistIndexMatrix gidx;
 
-  gidx.Init(m.get(), max_bins, {});
+  gidx.Init(m.get(), max_bins, false, {});
 
   auto x_copy = x;
   std::sort(x_copy.begin(), x_copy.end());
View File

@@ -29,7 +29,7 @@ template <typename GradientSumT> void TestEvaluateSplits() {
   size_t constexpr kMaxBins = 4;
 
   // dense, no missing values
-  GHistIndexMatrix gmat(dmat.get(), kMaxBins);
+  GHistIndexMatrix gmat(dmat.get(), kMaxBins, false);
   common::RowSetCollection row_set_collection;
   std::vector<size_t> &row_indices = *row_set_collection.Data();
   row_indices.resize(kRows);

View File

@@ -162,7 +162,7 @@ class QuantileHistMock : public QuantileHistMaker {
     // kNRows samples with kNCols features
     auto dmat = RandomDataGenerator(kNRows, kNCols, sparsity).Seed(3).GenerateDMatrix();
 
-    GHistIndexMatrix gmat(dmat.get(), kMaxBins);
+    GHistIndexMatrix gmat(dmat.get(), kMaxBins, false);
     ColumnMatrix cm;
 
     // treat everything as dense, as this is what we intend to test here
@@ -253,7 +253,7 @@ class QuantileHistMock : public QuantileHistMaker {
   void TestInitData() {
     size_t constexpr kMaxBins = 4;
-    GHistIndexMatrix gmat(dmat_.get(), kMaxBins);
+    GHistIndexMatrix gmat(dmat_.get(), kMaxBins, false);
 
     RegTree tree = RegTree();
     tree.param.UpdateAllowUnknown(cfg_);
@@ -270,7 +270,7 @@ class QuantileHistMock : public QuantileHistMaker {
   void TestInitDataSampling() {
     size_t constexpr kMaxBins = 4;
-    GHistIndexMatrix gmat(dmat_.get(), kMaxBins);
+    GHistIndexMatrix gmat(dmat_.get(), kMaxBins, false);
 
     RegTree tree = RegTree();
     tree.param.UpdateAllowUnknown(cfg_);