Implement sketching with Hessian on GPU. (#9399)

- Prepare for implementing approx on GPU.
- Unify the code path between weighted and uniform sketching on DMatrix.
This commit is contained in:
Jiaming Yuan
2023-07-24 15:43:03 +08:00
committed by GitHub
parent 851cba931e
commit a196443a07
14 changed files with 446 additions and 230 deletions

View File

@@ -131,7 +131,7 @@ EllpackPageImpl::EllpackPageImpl(Context const* ctx, DMatrix* dmat, const BatchP
monitor_.Start("Quantiles");
// Create the quantile sketches for the dmatrix and initialize HistogramCuts.
row_stride = GetRowStride(dmat);
cuts_ = common::DeviceSketch(ctx->gpu_id, dmat, param.max_bin);
cuts_ = common::DeviceSketch(ctx, dmat, param.max_bin);
monitor_.Stop("Quantiles");
monitor_.Start("InitCompressedData");

View File

@@ -21,7 +21,7 @@ GHistIndexMatrix::GHistIndexMatrix() : columns_{std::make_unique<common::ColumnM
GHistIndexMatrix::GHistIndexMatrix(Context const *ctx, DMatrix *p_fmat, bst_bin_t max_bins_per_feat,
double sparse_thresh, bool sorted_sketch,
common::Span<float> hess)
common::Span<float const> hess)
: max_numeric_bins_per_feat{max_bins_per_feat} {
CHECK(p_fmat->SingleColBlock());
// We use sorted sketching for approx tree method since it's more efficient in

View File

@@ -160,7 +160,7 @@ class GHistIndexMatrix {
* \brief Constrcutor for SimpleDMatrix.
*/
GHistIndexMatrix(Context const* ctx, DMatrix* x, bst_bin_t max_bins_per_feat,
double sparse_thresh, bool sorted_sketch, common::Span<float> hess = {});
double sparse_thresh, bool sorted_sketch, common::Span<float const> hess = {});
/**
* \brief Constructor for Iterative DMatrix. Initialize basic information and prepare
* for push batch.

View File

@@ -25,8 +25,8 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
cache_info_.erase(id);
MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);
std::unique_ptr<common::HistogramCuts> cuts;
cuts = std::make_unique<common::HistogramCuts>(
common::DeviceSketch(ctx->gpu_id, this, param.max_bin, 0));
cuts =
std::make_unique<common::HistogramCuts>(common::DeviceSketch(ctx, this, param.max_bin, 0));
this->InitializeSparsePage(ctx); // reset after use.
row_stride = GetRowStride(this);