Initial GPU support for the approx tree method. (#9414)

This commit is contained in:
Jiaming Yuan
2023-07-31 15:50:28 +08:00
committed by GitHub
parent 8f0efb4ab3
commit 912e341d57
23 changed files with 639 additions and 360 deletions

View File

@@ -11,7 +11,6 @@
#include "../common/categorical.h"
#include "../common/cuda_context.cuh"
#include "../common/hist_util.cuh"
#include "../common/random.h"
#include "../common/transform_iterator.h" // MakeIndexTransformIter
#include "./ellpack_page.cuh"
#include "device_adapter.cuh" // for HasInfInData
@@ -131,7 +130,11 @@ EllpackPageImpl::EllpackPageImpl(Context const* ctx, DMatrix* dmat, const BatchP
monitor_.Start("Quantiles");
// Create the quantile sketches for the dmatrix and initialize HistogramCuts.
row_stride = GetRowStride(dmat);
cuts_ = common::DeviceSketch(ctx, dmat, param.max_bin);
if (!param.hess.empty()) {
cuts_ = common::DeviceSketchWithHessian(ctx, dmat, param.max_bin, param.hess);
} else {
cuts_ = common::DeviceSketch(ctx, dmat, param.max_bin);
}
monitor_.Stop("Quantiles");
monitor_.Start("InitCompressedData");

View File

@@ -7,13 +7,12 @@
#include <algorithm>
#include <limits>
#include <memory>
#include <utility> // std::forward
#include <utility> // for forward
#include "../common/column_matrix.h"
#include "../common/hist_util.h"
#include "../common/numeric.h"
#include "../common/threading_utils.h"
#include "../common/transform_iterator.h" // MakeIndexTransformIter
#include "../common/transform_iterator.h" // for MakeIndexTransformIter
namespace xgboost {

View File

@@ -8,12 +8,12 @@
#include <algorithm>
#include <limits>
#include <numeric> // for accumulate
#include <type_traits>
#include <vector>
#include "../common/error_msg.h" // for InconsistentMaxBin
#include "../common/random.h"
#include "../common/threading_utils.h"
#include "../collective/communicator-inl.h" // for GetWorldSize, GetRank, Allgather
#include "../common/error_msg.h" // for InconsistentMaxBin
#include "./simple_batch_iterator.h"
#include "adapter.h"
#include "batch_utils.h" // for CheckEmpty, RegenGHist

View File

@@ -8,7 +8,6 @@
#include "./sparse_page_dmatrix.h"
#include "../collective/communicator-inl.h"
#include "./simple_batch_iterator.h"
#include "batch_utils.h" // for RegenGHist
#include "gradient_index.h"

View File

@@ -1,13 +1,15 @@
/**
* Copyright 2021-2023 by XGBoost contributors
*/
#include <memory>
#include <memory> // for unique_ptr
#include "../common/hist_util.cuh"
#include "batch_utils.h" // for CheckEmpty, RegenGHist
#include "../common/hist_util.h" // for HistogramCuts
#include "batch_utils.h" // for CheckEmpty, RegenGHist
#include "ellpack_page.cuh"
#include "sparse_page_dmatrix.h"
#include "sparse_page_source.h"
#include "xgboost/context.h" // for Context
#include "xgboost/data.h" // for BatchParam
namespace xgboost::data {
BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
@@ -25,8 +27,13 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
cache_info_.erase(id);
MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);
std::unique_ptr<common::HistogramCuts> cuts;
cuts =
std::make_unique<common::HistogramCuts>(common::DeviceSketch(ctx, this, param.max_bin, 0));
if (!param.hess.empty()) {
cuts = std::make_unique<common::HistogramCuts>(
common::DeviceSketchWithHessian(ctx, this, param.max_bin, param.hess));
} else {
cuts =
std::make_unique<common::HistogramCuts>(common::DeviceSketch(ctx, this, param.max_bin));
}
this->InitializeSparsePage(ctx); // reset after use.
row_stride = GetRowStride(this);
@@ -35,10 +42,10 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
batch_param_ = param;
auto ft = this->info_.feature_types.ConstDeviceSpan();
ellpack_page_source_.reset(); // release resources.
ellpack_page_source_.reset(new EllpackPageSource(
ellpack_page_source_.reset(); // make sure resource is released before making new ones.
ellpack_page_source_ = std::make_shared<EllpackPageSource>(
this->missing_, ctx->Threads(), this->Info().num_col_, this->n_batches_, cache_info_.at(id),
param, std::move(cuts), this->IsDense(), row_stride, ft, sparse_page_source_, ctx->gpu_id));
param, std::move(cuts), this->IsDense(), row_stride, ft, sparse_page_source_, ctx->gpu_id);
} else {
CHECK(sparse_page_source_);
ellpack_page_source_->Reset();