[EM] Multi-level quantile sketching for GPU. (#10813)

This commit is contained in:
Jiaming Yuan
2024-09-10 13:08:34 +08:00
committed by GitHub
parent 3ef8383d93
commit ed5f33df16
6 changed files with 111 additions and 34 deletions

View File

@@ -143,17 +143,30 @@ class SketchContainer {
*/
void Push(Context const* ctx, Span<Entry const> entries, Span<size_t> columns_ptr,
common::Span<OffsetT> cuts_ptr, size_t total_cuts, Span<float> weights = {});
/**
 * @brief Prune the quantile structure.
 *
 * @param ctx Context passed through to the pruning implementation.
 *            NOTE(review): presumably selects the device to run on -- confirm in the
 *            definition, which is not part of this header excerpt.
 * @param to The maximum size of pruned quantile. If the size of quantile structure is
 *           already less than `to`, then no operation is performed.
 */
void Prune(Context const* ctx, size_t to);
/**
 * @brief Merge another set of sketches.
 *
 * @param ctx Context passed through to the merge implementation.
 *            NOTE(review): presumably selects the device to run on -- confirm in the
 *            definition, which is not part of this header excerpt.
 * @param that_columns_ptr Column pointer of the quantile summary being merged.
 * @param that Columns of the other quantile summary.
 */
void Merge(Context const* ctx, Span<OffsetT const> that_columns_ptr,
Span<SketchEntry const> that);
/**
 * @brief Shrink the internal data structure to reduce memory usage. Can be used after
 *        prune.
 *
 * The buffer returned by `Other()` is emptied and its storage released; the buffer
 * returned by `Current()` keeps its elements and only gives up spare capacity.
 */
void ShrinkToFit() {
  // Release the spare buffer before compacting the current one.
  // NOTE(review): shrink_to_fit on a non-empty vector presumably reallocates and copies,
  // so dropping the spare storage first should lower the peak footprint -- confirm
  // against the container type behind Current()/Other().
  auto&& spare = this->Other();
  spare.clear();
  spare.shrink_to_fit();
  this->Current().shrink_to_fit();
}
/**
 * @brief Merge quantiles from other GPU workers.
 *
 * @param ctx Context passed through to the implementation.
 *            NOTE(review): presumably selects the device to run on -- confirm in the
 *            definition.
 * @param is_column_split NOTE(review): presumably indicates that the data is split by
 *                        column across workers -- confirm against the callers.
 */
void AllReduce(Context const* ctx, bool is_column_split);

View File

@@ -3,6 +3,7 @@
*/
#include <algorithm> // for max
#include <numeric> // for partial_sum
#include <utility> // for pair
#include <vector> // for vector
#include "../collective/allreduce.h" // for Allreduce
@@ -29,11 +30,39 @@ void MakeSketches(Context const* ctx,
float missing, std::shared_ptr<common::HistogramCuts> cuts, MetaInfo const& info,
ExternalDataInfo* p_ext_info) {
xgboost_NVTX_FN_RANGE();
std::unique_ptr<common::SketchContainer> sketch;
/**
* A variant of: A Fast Algorithm for Approximate Quantiles in High Speed Data Streams
*
* The original algorithm was designed for CPU where input is a stream with individual
* elements. For GPU, we process the data in batches. As a result, the implementation
* here simply uses the user input batch as the basic unit of sketching blocks. The
* number of blocks per-level grows exponentially.
*/
std::vector<std::pair<std::unique_ptr<common::SketchContainer>, bst_idx_t>> sketches;
auto& ext_info = *p_ext_info;
auto lazy_init_sketch = [&] {
// Lazy because we need the `n_features`.
sketches.emplace_back(std::make_unique<common::SketchContainer>(
proxy->Info().feature_types, p.max_bin, ext_info.n_features,
data::BatchSamples(proxy), dh::GetDevice(ctx)),
0);
};
// Workaround empty input with CPU ctx.
Context new_ctx;
Context const* p_ctx;
if (ctx->IsCUDA()) {
p_ctx = ctx;
} else {
new_ctx.UpdateAllowUnknown(Args{{"device", dh::GetDevice(ctx).Name()}});
p_ctx = &new_ctx;
}
do {
/**
* Get the data shape.
*/
// We use a do-while loop here because the first batch is fetched in the constructor.
CHECK_LT(ctx->Ordinal(), common::AllVisibleGPUs());
common::SetDevice(dh::GetDevice(ctx).ordinal);
@@ -46,28 +75,38 @@ void MakeSketches(Context const* ctx,
CHECK_EQ(ext_info.n_features, ::xgboost::data::BatchColumns(proxy))
<< "Inconsistent number of columns.";
}
auto batch_rows = data::BatchSamples(proxy);
ext_info.accumulated_rows += batch_rows;
/**
* Handle sketching.
*/
if (!ref) {
if (!sketch) {
sketch = std::make_unique<common::SketchContainer>(
proxy->Info().feature_types, p.max_bin, ext_info.n_features, data::BatchSamples(proxy),
dh::GetDevice(ctx));
if (sketches.empty()) {
lazy_init_sketch();
}
if (sketches.back().second > (1ul << (sketches.size() - 1))) {
auto n_cuts_per_feat =
common::detail::RequiredSampleCutsPerColumn(p.max_bin, ext_info.accumulated_rows);
// Prune to a single block
sketches.back().first->Prune(p_ctx, n_cuts_per_feat);
sketches.back().first->ShrinkToFit();
sketches.back().second = 1;
lazy_init_sketch(); // Add a new level.
}
proxy->Info().weights_.SetDevice(dh::GetDevice(ctx));
cuda_impl::Dispatch(proxy, [&](auto const& value) {
// Workaround empty input with CPU ctx.
Context new_ctx;
Context const* p_ctx;
if (ctx->IsCUDA()) {
p_ctx = ctx;
} else {
new_ctx.UpdateAllowUnknown(Args{{"device", dh::GetDevice(ctx).Name()}});
p_ctx = &new_ctx;
}
common::AdapterDeviceSketch(p_ctx, value, p.max_bin, proxy->Info(), missing, sketch.get());
common::AdapterDeviceSketch(p_ctx, value, p.max_bin, proxy->Info(), missing,
sketches.back().first.get());
sketches.back().second++;
});
}
auto batch_rows = data::BatchSamples(proxy);
ext_info.accumulated_rows += batch_rows;
/**
* Rest of the data shape.
*/
dh::device_vector<size_t> row_counts(batch_rows + 1, 0);
common::Span<size_t> row_counts_span(row_counts.data().get(), row_counts.size());
ext_info.row_stride =
@@ -87,7 +126,28 @@ void MakeSketches(Context const* ctx,
// Get reference
common::SetDevice(dh::GetDevice(ctx).ordinal);
if (!ref) {
sketch->MakeCuts(ctx, cuts.get(), info.IsColumnSplit());
HostDeviceVector<FeatureType> ft;
common::SketchContainer final_sketch(
sketches.empty() ? ft : sketches.front().first->FeatureTypes(), p.max_bin,
ext_info.n_features, ext_info.accumulated_rows, dh::GetDevice(ctx));
// Iterate in reverse order since the last container might contain a summary that is not yet pruned.
for (auto it = sketches.crbegin(); it != sketches.crend(); ++it) {
auto& sketch = *it;
CHECK_GE(sketch.second, 1);
if (sketch.second > 1) {
sketch.first->Prune(p_ctx, common::detail::RequiredSampleCutsPerColumn(
p.max_bin, ext_info.accumulated_rows));
sketch.first->ShrinkToFit();
}
final_sketch.Merge(p_ctx, sketch.first->ColumnsPtr(), sketch.first->Data());
final_sketch.FixError();
}
sketches.clear();
sketches.shrink_to_fit();
final_sketch.MakeCuts(ctx, cuts.get(), info.IsColumnSplit());
} else {
GetCutsFromRef(ctx, ref, ext_info.n_features, p, cuts.get());
}

View File

@@ -289,11 +289,11 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPol
auto page = std::make_shared<S>();
this->exce_.Run([&] {
std::unique_ptr<typename FormatStreamPolicy::FormatT> fmt{
this->CreatePageFormat(this->param_)};
self->CreatePageFormat(self->param_)};
auto name = self->cache_info_->ShardName();
auto [offset, length] = self->cache_info_->View(fetch_it);
std::unique_ptr<typename FormatStreamPolicy::ReaderT> fi{
this->CreateReader(name, offset, length)};
self->CreateReader(name, offset, length)};
CHECK(fmt->Read(page.get(), fi.get()));
});
return page;