Initial GPU support for the approx tree method. (#9414)
This commit is contained in:
@@ -11,7 +11,6 @@
|
||||
#include "../common/categorical.h"
|
||||
#include "../common/cuda_context.cuh"
|
||||
#include "../common/hist_util.cuh"
|
||||
#include "../common/random.h"
|
||||
#include "../common/transform_iterator.h" // MakeIndexTransformIter
|
||||
#include "./ellpack_page.cuh"
|
||||
#include "device_adapter.cuh" // for HasInfInData
|
||||
@@ -131,7 +130,11 @@ EllpackPageImpl::EllpackPageImpl(Context const* ctx, DMatrix* dmat, const BatchP
|
||||
monitor_.Start("Quantiles");
|
||||
// Create the quantile sketches for the dmatrix and initialize HistogramCuts.
|
||||
row_stride = GetRowStride(dmat);
|
||||
cuts_ = common::DeviceSketch(ctx, dmat, param.max_bin);
|
||||
if (!param.hess.empty()) {
|
||||
cuts_ = common::DeviceSketchWithHessian(ctx, dmat, param.max_bin, param.hess);
|
||||
} else {
|
||||
cuts_ = common::DeviceSketch(ctx, dmat, param.max_bin);
|
||||
}
|
||||
monitor_.Stop("Quantiles");
|
||||
|
||||
monitor_.Start("InitCompressedData");
|
||||
|
||||
@@ -7,13 +7,12 @@
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <utility> // std::forward
|
||||
#include <utility> // for forward
|
||||
|
||||
#include "../common/column_matrix.h"
|
||||
#include "../common/hist_util.h"
|
||||
#include "../common/numeric.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "../common/transform_iterator.h" // MakeIndexTransformIter
|
||||
#include "../common/transform_iterator.h" // for MakeIndexTransformIter
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
|
||||
@@ -8,12 +8,12 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <numeric> // for accumulate
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "../common/error_msg.h" // for InconsistentMaxBin
|
||||
#include "../common/random.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "../collective/communicator-inl.h" // for GetWorldSize, GetRank, Allgather
|
||||
#include "../common/error_msg.h" // for InconsistentMaxBin
|
||||
#include "./simple_batch_iterator.h"
|
||||
#include "adapter.h"
|
||||
#include "batch_utils.h" // for CheckEmpty, RegenGHist
|
||||
|
||||
@@ -8,7 +8,6 @@
|
||||
#include "./sparse_page_dmatrix.h"
|
||||
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "./simple_batch_iterator.h"
|
||||
#include "batch_utils.h" // for RegenGHist
|
||||
#include "gradient_index.h"
|
||||
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
/**
|
||||
* Copyright 2021-2023 by XGBoost contributors
|
||||
*/
|
||||
#include <memory>
|
||||
#include <memory> // for unique_ptr
|
||||
|
||||
#include "../common/hist_util.cuh"
|
||||
#include "batch_utils.h" // for CheckEmpty, RegenGHist
|
||||
#include "../common/hist_util.h" // for HistogramCuts
|
||||
#include "batch_utils.h" // for CheckEmpty, RegenGHist
|
||||
#include "ellpack_page.cuh"
|
||||
#include "sparse_page_dmatrix.h"
|
||||
#include "sparse_page_source.h"
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/data.h" // for BatchParam
|
||||
|
||||
namespace xgboost::data {
|
||||
BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
|
||||
@@ -25,8 +27,13 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
|
||||
cache_info_.erase(id);
|
||||
MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);
|
||||
std::unique_ptr<common::HistogramCuts> cuts;
|
||||
cuts =
|
||||
std::make_unique<common::HistogramCuts>(common::DeviceSketch(ctx, this, param.max_bin, 0));
|
||||
if (!param.hess.empty()) {
|
||||
cuts = std::make_unique<common::HistogramCuts>(
|
||||
common::DeviceSketchWithHessian(ctx, this, param.max_bin, param.hess));
|
||||
} else {
|
||||
cuts =
|
||||
std::make_unique<common::HistogramCuts>(common::DeviceSketch(ctx, this, param.max_bin));
|
||||
}
|
||||
this->InitializeSparsePage(ctx); // reset after use.
|
||||
|
||||
row_stride = GetRowStride(this);
|
||||
@@ -35,10 +42,10 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
|
||||
batch_param_ = param;
|
||||
|
||||
auto ft = this->info_.feature_types.ConstDeviceSpan();
|
||||
ellpack_page_source_.reset(); // release resources.
|
||||
ellpack_page_source_.reset(new EllpackPageSource(
|
||||
ellpack_page_source_.reset(); // make sure resource is released before making new ones.
|
||||
ellpack_page_source_ = std::make_shared<EllpackPageSource>(
|
||||
this->missing_, ctx->Threads(), this->Info().num_col_, this->n_batches_, cache_info_.at(id),
|
||||
param, std::move(cuts), this->IsDense(), row_stride, ft, sparse_page_source_, ctx->gpu_id));
|
||||
param, std::move(cuts), this->IsDense(), row_stride, ft, sparse_page_source_, ctx->gpu_id);
|
||||
} else {
|
||||
CHECK(sparse_page_source_);
|
||||
ellpack_page_source_->Reset();
|
||||
|
||||
Reference in New Issue
Block a user