[EM] Return a full DMatrix instead of a Ellpack from the GPU sampler. (#10753)

Jiaming Yuan 2024-08-28 01:05:11 +08:00 committed by GitHub
parent d6ebcfb032
commit bde1265caf
20 changed files with 525 additions and 214 deletions

.gitignore

@@ -63,6 +63,7 @@ java/xgboost4j-demo/data/
 java/xgboost4j-demo/tmp/
 java/xgboost4j-demo/model/
 nb-configuration*
 # Eclipse
 .project
 .cproject
@@ -154,3 +155,6 @@ model*.json
 *.rds
 Rplots.pdf
 *.zip
+# nsys
+*.nsys-rep


@@ -110,8 +110,15 @@ class MetaInfo {
   * @brief Validate all metainfo.
   */
  void Validate(DeviceOrd device) const;
- MetaInfo Slice(common::Span<int32_t const> ridxs) const;
+ /**
+  * @brief Slice the meta info.
+  *
+  * The device of ridxs is specified by the ctx object.
+  *
+  * @param ridxs Index of selected rows.
+  * @param nnz   The number of non-missing values.
+  */
+ MetaInfo Slice(Context const* ctx, common::Span<bst_idx_t const> ridxs, bst_idx_t nnz) const;
  MetaInfo Copy() const;
  /**
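For reference, a minimal usage sketch of the new signature; the Context, row indices, and nnz value below are hypothetical and not taken from the diff:

  // Hypothetical sketch: slice three rows of an existing MetaInfo `info` on the CPU.
  Context ctx;                           // defaults to the CPU device
  std::vector<bst_idx_t> rows{0, 2, 5};  // rows to keep
  bst_idx_t nnz = 9;                     // non-missing values among the selected rows
  MetaInfo sliced = info.Slice(&ctx, common::Span<bst_idx_t const>{rows.data(), rows.size()}, nnz);
  // sliced.num_row_ == rows.size(); sliced.num_nonzero_ == nnz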


@@ -508,6 +508,11 @@ xgboost::common::Span<T> ToSpan(DeviceUVector<T> &vec) {
   return {vec.data(), vec.size()};
 }
+template <typename T>
+xgboost::common::Span<std::add_const_t<T>> ToSpan(DeviceUVector<T> const &vec) {
+  return {vec.data(), vec.size()};
+}
 // thrust begin, similiar to std::begin
 template <typename T>
 thrust::device_ptr<T> tbegin(xgboost::HostDeviceVector<T>& vector) {  // NOLINT
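A small sketch of where the added const overload helps (the function name is illustrative): a read-only DeviceUVector can now be viewed without a const_cast.

  // Hypothetical sketch: the const overload yields a read-only span directly.
  void ReadOnlyUse(dh::DeviceUVector<float> const& vec) {
    xgboost::common::Span<float const> ro = dh::ToSpan(vec);  // uses the new overload
    // ... pass `ro` to a kernel that only reads the values ...
  }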


@@ -76,7 +76,7 @@ struct IterOp {
 // returns a thrust iterator for a tensor view.
 template <typename T, std::int32_t kDim>
 auto tcbegin(TensorView<T, kDim> v) {  // NOLINT
-  return dh::MakeTransformIterator<T>(
+  return thrust::make_transform_iterator(
       thrust::make_counting_iterator(0ul),
       detail::IterOp<std::add_const_t<std::remove_const_t<T>>, kDim>{v});
 }
@@ -85,5 +85,16 @@ template <typename T, std::int32_t kDim>
 auto tcend(TensorView<T, kDim> v) {  // NOLINT
   return tcbegin(v) + v.Size();
 }
+template <typename T, std::int32_t kDim>
+auto tbegin(TensorView<T, kDim> v) {  // NOLINT
+  return thrust::make_transform_iterator(thrust::make_counting_iterator(0ul),
+                                         detail::IterOp<std::remove_const_t<T>, kDim>{v});
+}
+template <typename T, std::int32_t kDim>
+auto tend(TensorView<T, kDim> v) {  // NOLINT
+  return tbegin(v) + v.Size();
+}
 }  // namespace xgboost::linalg
 #endif  // XGBOOST_COMMON_LINALG_OP_CUH_
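The non-const tbegin/tend make a device TensorView usable as a thrust output range. A minimal sketch, assuming a CUDA `Context* ctx` and the usual Matrix::Reshape/View helpers used elsewhere in this commit:

  // Hypothetical sketch: fill a device matrix through the new mutable iterators.
  linalg::Matrix<float> m;
  m.SetDevice(ctx->Device());
  m.Reshape(4, 3);
  auto d_view = m.View(ctx->Device());
  thrust::fill(ctx->CUDACtx()->CTP(), linalg::tbegin(d_view), linalg::tend(d_view), 1.0f);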


@@ -351,8 +351,10 @@ void MetaInfo::LoadBinary(dmlc::Stream *fi) {
   this->has_categorical_ = LoadFeatureType(feature_type_names, &feature_types.HostVector());
 }

+namespace {
 template <typename T>
-std::vector<T> Gather(const std::vector<T> &in, common::Span<int const> ridxs, size_t stride = 1) {
+std::vector<T> Gather(const std::vector<T>& in, common::Span<bst_idx_t const> ridxs,
+                      size_t stride = 1) {
   if (in.empty()) {
     return {};
   }
@@ -361,16 +363,56 @@ std::vector<T> Gather(const std::vector<T> &in, common::Span<int const> ridxs, s
   for (auto i = 0ull; i < size; i++) {
     auto ridx = ridxs[i];
     for (size_t j = 0; j < stride; ++j) {
-      out[i * stride +j] = in[ridx * stride + j];
+      out[i * stride + j] = in[ridx * stride + j];
     }
   }
   return out;
 }
+}  // namespace

-MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
+namespace cuda_impl {
+void SliceMetaInfo(Context const* ctx, MetaInfo const& info, common::Span<bst_idx_t const> ridx,
+                   MetaInfo* p_out);
+#if !defined(XGBOOST_USE_CUDA)
+void SliceMetaInfo(Context const*, MetaInfo const&, common::Span<bst_idx_t const>, MetaInfo*) {
+  common::AssertGPUSupport();
+}
+#endif
+}  // namespace cuda_impl
+
+MetaInfo MetaInfo::Slice(Context const* ctx, common::Span<bst_idx_t const> ridxs,
+                         bst_idx_t nnz) const {
+  /**
+   * Shape
+   */
   MetaInfo out;
   out.num_row_ = ridxs.size();
   out.num_col_ = this->num_col_;
+  out.num_nonzero_ = nnz;
+  /**
+   * Feature Info
+   */
+  out.feature_weights.SetDevice(ctx->Device());
+  out.feature_weights.Resize(this->feature_weights.Size());
+  out.feature_weights.Copy(this->feature_weights);
+  out.feature_names = this->feature_names;
+  out.feature_types.SetDevice(ctx->Device());
+  out.feature_types.Resize(this->feature_types.Size());
+  out.feature_types.Copy(this->feature_types);
+  out.feature_type_names = this->feature_type_names;
+  /**
+   * Sample Info
+   */
+  if (ctx->IsCUDA()) {
+    cuda_impl::SliceMetaInfo(ctx, *this, ridxs, &out);
+    return out;
+  }
   // Groups is maintained by a higher level Python function. We should aim at deprecating
   // the slice function.
   if (this->labels.Size() != this->num_row_) {
@@ -386,13 +428,11 @@ MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
     });
   }

-  out.labels_upper_bound_.HostVector() =
-      Gather(this->labels_upper_bound_.HostVector(), ridxs);
-  out.labels_lower_bound_.HostVector() =
-      Gather(this->labels_lower_bound_.HostVector(), ridxs);
+  out.labels_upper_bound_.HostVector() = Gather(this->labels_upper_bound_.HostVector(), ridxs);
+  out.labels_lower_bound_.HostVector() = Gather(this->labels_lower_bound_.HostVector(), ridxs);
   // weights
   if (this->weights_.Size() + 1 == this->group_ptr_.size()) {
     auto& h_weights = out.weights_.HostVector();
     // Assuming all groups are available.
     out.weights_.HostVector() = h_weights;
   } else {
@@ -414,14 +454,6 @@ MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
     });
   }

-  out.feature_weights.Resize(this->feature_weights.Size());
-  out.feature_weights.Copy(this->feature_weights);
-  out.feature_names = this->feature_names;
-  out.feature_types.Resize(this->feature_types.Size());
-  out.feature_types.Copy(this->feature_types);
-  out.feature_type_names = this->feature_type_names;
-
   return out;
 }
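For reference, the stride parameter of the CPU Gather() helper above selects stride consecutive values per row index: with in = {0, 1, 2, 3, 4, 5} (a flattened 2-column, row-major matrix), ridx = {2, 0}, and stride = 2, the gathered result is {4, 5, 0, 1}. (Worked example; values are made up.)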


@@ -1,9 +1,11 @@
 /**
- * Copyright 2019-2022 by XGBoost Contributors
+ * Copyright 2019-2024, XGBoost Contributors
  *
  * \file data.cu
  * \brief Handles setting metainfo from array interface.
  */
+#include <thrust/gather.h>  // for gather
 #include "../common/cuda_context.cuh"
 #include "../common/device_helpers.cuh"
 #include "../common/linalg_op.cuh"
@@ -169,6 +171,62 @@ void MetaInfo::SetInfoFromCUDA(Context const& ctx, StringView key, Json array) {
   }
 }

+namespace {
+void Gather(Context const* ctx, linalg::MatrixView<float const> in,
+            common::Span<bst_idx_t const> ridx, linalg::Matrix<float>* p_out) {
+  if (in.Empty()) {
+    return;
+  }
+  auto& out = *p_out;
+  out.Reshape(ridx.size(), in.Shape(1));
+  auto d_out = out.View(ctx->Device());
+
+  auto cuctx = ctx->CUDACtx();
+  auto map_it = thrust::make_transform_iterator(thrust::make_counting_iterator(0ull),
+                                                [=] XGBOOST_DEVICE(bst_idx_t i) {
+                                                  auto [r, c] = linalg::UnravelIndex(i, in.Shape());
+                                                  return (ridx[r] * in.Shape(1)) + c;
+                                                });
+  CHECK_NE(in.Shape(1), 0);
+  thrust::gather(cuctx->TP(), map_it, map_it + out.Size(), linalg::tcbegin(in),
+                 linalg::tbegin(d_out));
+}
+
+template <typename T>
+void Gather(Context const* ctx, HostDeviceVector<T> const& in, common::Span<bst_idx_t const> ridx,
+            HostDeviceVector<T>* p_out) {
+  if (in.Empty()) {
+    return;
+  }
+  in.SetDevice(ctx->Device());
+
+  auto& out = *p_out;
+  out.SetDevice(ctx->Device());
+  out.Resize(ridx.size());
+  auto d_out = out.DeviceSpan();
+
+  auto cuctx = ctx->CUDACtx();
+  auto d_in = in.ConstDeviceSpan();
+  thrust::gather(cuctx->TP(), dh::tcbegin(ridx), dh::tcend(ridx), dh::tcbegin(d_in),
+                 dh::tbegin(d_out));
+}
+}  // anonymous namespace
+
+namespace cuda_impl {
+void SliceMetaInfo(Context const* ctx, MetaInfo const& info, common::Span<bst_idx_t const> ridx,
+                   MetaInfo* p_out) {
+  auto& out = *p_out;
+  Gather(ctx, info.labels.View(ctx->Device()), ridx, &p_out->labels);
+  Gather(ctx, info.base_margin_.View(ctx->Device()), ridx, &p_out->base_margin_);
+  Gather(ctx, info.labels_lower_bound_, ridx, &out.labels_lower_bound_);
+  Gather(ctx, info.labels_upper_bound_, ridx, &out.labels_upper_bound_);
+  Gather(ctx, info.weights_, ridx, &out.weights_);
+}
+}  // namespace cuda_impl
+
 template <typename AdapterT>
 DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread,
                          const std::string& cache_prefix, DataSplitMode data_split_mode) {


@@ -1,12 +1,13 @@
 /**
  * Copyright 2019-2024, XGBoost contributors
  */
+#include <cuda/functional>  // for proclaim_return_type
 #include <thrust/iterator/discard_iterator.h>
 #include <thrust/iterator/transform_output_iterator.h>

 #include <algorithm>  // for copy
 #include <utility>    // for move
 #include <vector>     // for vector

 #include "../common/categorical.h"
 #include "../common/cuda_context.cuh"
@@ -576,4 +577,17 @@ EllpackDeviceAccessor EllpackPageImpl::GetHostAccessor(
       common::CompressedIterator<uint32_t>(h_gidx_buffer->data(), NumSymbols()),
       feature_types};
 }
+
+[[nodiscard]] bst_idx_t EllpackPageImpl::NumNonMissing(
+    Context const* ctx, common::Span<FeatureType const> feature_types) const {
+  auto d_acc = this->GetDeviceAccessor(ctx->Device(), feature_types);
+  using T = typename decltype(d_acc.gidx_iter)::value_type;
+  auto it = thrust::make_transform_iterator(
+      thrust::make_counting_iterator(0ull),
+      cuda::proclaim_return_type<T>([=] __device__(std::size_t i) { return d_acc.gidx_iter[i]; }));
+  auto nnz = thrust::count_if(ctx->CUDACtx()->CTP(), it, it + d_acc.row_stride * d_acc.n_rows,
+                              cuda::proclaim_return_type<bool>(
+                                  [=] __device__(T gidx) { return gidx != d_acc.NullValue(); }));
+  return nnz;
+}
 }  // namespace xgboost


@@ -236,6 +236,11 @@ class EllpackPageImpl {
   [[nodiscard]] EllpackDeviceAccessor GetHostAccessor(
       Context const* ctx, std::vector<common::CompressedByteT>* h_gidx_buffer,
       common::Span<FeatureType const> feature_types = {}) const;
+  /**
+   * @brief Calculate the number of non-missing values.
+   */
+  [[nodiscard]] bst_idx_t NumNonMissing(Context const* ctx,
+                                        common::Span<FeatureType const> feature_types) const;

  private:
   /**


@@ -101,6 +101,17 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
   // Synchronise worker columns
 }

+IterativeDMatrix::IterativeDMatrix(std::shared_ptr<EllpackPage> ellpack, MetaInfo const& info,
+                                   BatchParam batch) {
+  this->ellpack_ = ellpack;
+  CHECK_EQ(this->Info().num_row_, 0);
+  CHECK_EQ(this->Info().num_col_, 0);
+  this->Info().Extend(info, true, true);
+  this->Info().num_nonzero_ = info.num_nonzero_;
+  CHECK_EQ(this->Info().num_row_, info.num_row_);
+  this->batch_ = batch;
+}
+
 BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(Context const* ctx,
                                                           BatchParam const& param) {
   if (param.Initialized()) {
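A brief sketch of how this new constructor is intended to be used, mirroring the sampler changes later in this commit; `concatenated`, `src`, and `batch_param` are illustrative names only:

  // Hypothetical sketch: wrap an already-built ELLPACK page in a DMatrix facade.
  std::shared_ptr<EllpackPage> concatenated = std::make_shared<EllpackPage>();
  // ... fill *concatenated->Impl() from one or more source pages ...
  auto p_fmat = std::make_shared<data::IterativeDMatrix>(concatenated, src->Info(), batch_param);
  CHECK_EQ(p_fmat->Info().num_row_, src->Info().num_row_);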


@@ -48,6 +48,11 @@ class IterativeDMatrix : public QuantileDMatrix {
                    std::shared_ptr<DMatrix> ref, DataIterResetCallback *reset,
                    XGDMatrixCallbackNext *next, float missing, int nthread,
                    bst_bin_t max_bin);
+  /**
+   * @param Directly construct a QDM from an existing one.
+   */
+  IterativeDMatrix(std::shared_ptr<EllpackPage> ellpack, MetaInfo const &info, BatchParam batch);
+
   ~IterativeDMatrix() override = default;

   bool EllpackExists() const override { return static_cast<bool>(ellpack_); }


@@ -31,6 +31,9 @@ const MetaInfo& SimpleDMatrix::Info() const { return info_; }
 DMatrix* SimpleDMatrix::Slice(common::Span<int32_t const> ridxs) {
   auto out = new SimpleDMatrix;
   SparsePage& out_page = *out->sparse_page_;
+  // Convert to uint64 to avoid a breaking change in the C API. The performance impact is
+  // small since we have to iteratve through the sparse page.
+  std::vector<bst_idx_t> h_ridx(ridxs.data(), ridxs.data() + ridxs.size());
   for (auto const& page : this->GetBatches<SparsePage>()) {
     auto batch = page.GetView();
     auto& h_data = out_page.data.HostVector();
@@ -42,8 +45,8 @@ DMatrix* SimpleDMatrix::Slice(common::Span<int32_t const> ridxs) {
       std::copy(inst.begin(), inst.end(), std::back_inserter(h_data));
       h_offset.emplace_back(rptr);
     }
-    out->Info() = this->Info().Slice(ridxs);
-    out->Info().num_nonzero_ = h_offset.back();
+    auto ctx = this->fmat_ctx_.MakeCPU();
+    out->Info() = this->Info().Slice(&ctx, h_ridx, h_offset.back());
   }
   out->fmat_ctx_ = this->fmat_ctx_;
   return out;


@@ -14,12 +14,12 @@
 #include "../../common/cuda_context.cuh"  // for CUDAContext
 #include "../../common/random.h"
+#include "../../data/ellpack_page.cuh"     // for EllpackPageImpl
+#include "../../data/iterative_dmatrix.h"  // for IterativeDMatrix
 #include "../param.h"
 #include "gradient_based_sampler.cuh"

-namespace xgboost {
-namespace tree {
+namespace xgboost::tree {

 /*! \brief A functor that returns random weights. */
 class RandomWeight : public thrust::unary_function<size_t, float> {
  public:
@@ -58,12 +58,14 @@ struct IsNonZero : public thrust::unary_function<GradientPair, bool> {
 };

 /*! \brief A functor that clears the row indexes with empty gradient. */
-struct ClearEmptyRows : public thrust::binary_function<GradientPair, size_t, size_t> {
+struct ClearEmptyRows : public thrust::binary_function<GradientPair, bst_idx_t, bst_idx_t> {
+  static constexpr bst_idx_t InvalidRow() { return std::numeric_limits<std::size_t>::max(); }
   XGBOOST_DEVICE size_t operator()(const GradientPair& gpair, size_t row_index) const {
     if (gpair.GetGrad() != 0 || gpair.GetHess() != 0) {
       return row_index;
     } else {
-      return std::numeric_limits<std::size_t>::max();
+      return InvalidRow();
     }
   }
 };
@@ -148,10 +150,9 @@ class PoissonSampling : public thrust::binary_function<GradientPair, size_t, Gra
 NoSampling::NoSampling(BatchParam batch_param) : batch_param_(std::move(batch_param)) {}

-GradientBasedSample NoSampling::Sample(Context const* ctx, common::Span<GradientPair> gpair,
+GradientBasedSample NoSampling::Sample(Context const*, common::Span<GradientPair> gpair,
                                        DMatrix* dmat) {
-  auto page = (*dmat->GetBatches<EllpackPage>(ctx, batch_param_).begin()).Impl();
-  return {dmat->Info().num_row_, page, gpair};
+  return {dmat->Info().num_row_, dmat, gpair};
 }

 ExternalMemoryNoSampling::ExternalMemoryNoSampling(BatchParam batch_param)
@@ -159,37 +160,39 @@ ExternalMemoryNoSampling::ExternalMemoryNoSampling(BatchParam batch_param)
 GradientBasedSample ExternalMemoryNoSampling::Sample(Context const* ctx,
                                                      common::Span<GradientPair> gpair,
-                                                     DMatrix* dmat) {
+                                                     DMatrix* p_fmat) {
+  std::shared_ptr<EllpackPage> new_page;
   if (!page_concatenated_) {
     // Concatenate all the external memory ELLPACK pages into a single in-memory page.
-    page_.reset(nullptr);
     bst_idx_t offset = 0;
-    for (auto& batch : dmat->GetBatches<EllpackPage>(ctx, batch_param_)) {
+    for (auto& batch : p_fmat->GetBatches<EllpackPage>(ctx, batch_param_)) {
       auto page = batch.Impl();
-      if (!page_) {
-        page_ = std::make_unique<EllpackPageImpl>(ctx, page->CutsShared(), page->is_dense,
-                                                  page->row_stride, dmat->Info().num_row_);
+      if (!new_page) {
+        new_page = std::make_shared<EllpackPage>();
+        *new_page->Impl() = EllpackPageImpl(ctx, page->CutsShared(), page->is_dense,
+                                            page->row_stride, p_fmat->Info().num_row_);
       }
-      bst_idx_t num_elements = page_->Copy(ctx, page, offset);
+      bst_idx_t num_elements = new_page->Impl()->Copy(ctx, page, offset);
       offset += num_elements;
     }
     page_concatenated_ = true;
+    this->p_fmat_new_ =
+        std::make_unique<data::IterativeDMatrix>(new_page, p_fmat->Info(), batch_param_);
   }
-  return {dmat->Info().num_row_, page_.get(), gpair};
+  return {p_fmat->Info().num_row_, this->p_fmat_new_.get(), gpair};
 }

 UniformSampling::UniformSampling(BatchParam batch_param, float subsample)
-    : batch_param_{std::move(batch_param)}, subsample_(subsample) {}
+    : batch_param_{std::move(batch_param)}, subsample_{subsample} {}

 GradientBasedSample UniformSampling::Sample(Context const* ctx, common::Span<GradientPair> gpair,
-                                            DMatrix* dmat) {
+                                            DMatrix* p_fmat) {
   // Set gradient pair to 0 with p = 1 - subsample
   auto cuctx = ctx->CUDACtx();
   thrust::replace_if(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair),
                      thrust::counting_iterator<std::size_t>(0),
                      BernoulliTrial(common::GlobalRandom()(), subsample_), GradientPair());
-  auto page = (*dmat->GetBatches<EllpackPage>(ctx, batch_param_).begin()).Impl();
-  return {dmat->Info().num_row_, page, gpair};
+  return {p_fmat->Info().num_row_, p_fmat, gpair};
 }

 ExternalMemoryUniformSampling::ExternalMemoryUniformSampling(size_t n_rows,
@@ -203,13 +206,17 @@ GradientBasedSample ExternalMemoryUniformSampling::Sample(Context const* ctx,
                                                           common::Span<GradientPair> gpair,
                                                           DMatrix* dmat) {
   auto cuctx = ctx->CUDACtx();
+  std::shared_ptr<EllpackPage> new_page = std::make_shared<EllpackPage>();
+  auto page = new_page->Impl();
   // Set gradient pair to 0 with p = 1 - subsample
   thrust::replace_if(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair),
                      thrust::counting_iterator<std::size_t>(0),
                      BernoulliTrial(common::GlobalRandom()(), subsample_), GradientPair{});
   // Count the sampled rows.
-  size_t sample_rows =
+  bst_idx_t sample_rows =
       thrust::count_if(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), IsNonZero{});
   // Compact gradient pairs.
@@ -227,17 +234,25 @@ GradientBasedSample ExternalMemoryUniformSampling::Sample(Context const* ctx,
   auto batch_iterator = dmat->GetBatches<EllpackPage>(ctx, batch_param_);
   auto first_page = (*batch_iterator.begin()).Impl();
   // Create a new ELLPACK page with empty rows.
-  page_.reset();  // Release the device memory first before reallocating
-  page_.reset(new EllpackPageImpl(ctx, first_page->CutsShared(), first_page->is_dense,
-                                  first_page->row_stride, sample_rows));
+  *page = EllpackPageImpl{ctx, first_page->CutsShared(), first_page->is_dense,
+                          first_page->row_stride, sample_rows};
   // Compact the ELLPACK pages into the single sample page.
-  thrust::fill(cuctx->CTP(), page_->gidx_buffer.begin(), page_->gidx_buffer.end(), 0);
+  thrust::fill(cuctx->CTP(), page->gidx_buffer.begin(), page->gidx_buffer.end(), 0);
   for (auto& batch : batch_iterator) {
-    page_->Compact(ctx, batch.Impl(), dh::ToSpan(sample_row_index_));
+    page->Compact(ctx, batch.Impl(), dh::ToSpan(sample_row_index_));
   }

-  return {sample_rows, page_.get(), dh::ToSpan(gpair_)};
+  // Select the metainfo
+  dmat->Info().feature_types.SetDevice(ctx->Device());
+  auto nnz = page->NumNonMissing(ctx, dmat->Info().feature_types.ConstDeviceSpan());
+  compact_row_index_.resize(sample_rows);
+  thrust::copy_if(
+      cuctx->TP(), sample_row_index_.cbegin(), sample_row_index_.cend(), compact_row_index_.begin(),
+      [] XGBOOST_DEVICE(std::size_t idx) { return idx != ClearEmptyRows::InvalidRow(); });
+  // Create the new DMatrix
+  this->p_fmat_new_ = std::make_unique<data::IterativeDMatrix>(
+      new_page, dmat->Info().Slice(ctx, dh::ToSpan(compact_row_index_), nnz), batch_param_);
+  return {sample_rows, this->p_fmat_new_.get(), dh::ToSpan(gpair_)};
 }
 GradientBasedSampling::GradientBasedSampling(std::size_t n_rows, BatchParam batch_param,
@@ -254,14 +269,12 @@ GradientBasedSample GradientBasedSampling::Sample(Context const* ctx,
   size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex(
       ctx, gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_);
-  auto page = (*dmat->GetBatches<EllpackPage>(ctx, batch_param_).begin()).Impl();

   // Perform Poisson sampling in place.
   thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair),
                     thrust::counting_iterator<size_t>(0), dh::tbegin(gpair),
                     PoissonSampling(dh::ToSpan(threshold_), threshold_index,
                                     RandomWeight(common::GlobalRandom()())));
-  return {n_rows, page, gpair};
+  return {n_rows, dmat, gpair};
 }
 ExternalMemoryGradientBasedSampling::ExternalMemoryGradientBasedSampling(size_t n_rows,
@@ -277,6 +290,8 @@ GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(Context const* c
                                                                 common::Span<GradientPair> gpair,
                                                                 DMatrix* dmat) {
   auto cuctx = ctx->CUDACtx();
+  std::shared_ptr<EllpackPage> new_page = std::make_shared<EllpackPage>();
+  auto page = new_page->Impl();
   bst_idx_t n_rows = dmat->Info().num_row_;
   size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex(
       ctx, gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_);
@@ -293,24 +308,33 @@ GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(Context const* c
   thrust::copy_if(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), gpair_.begin(), IsNonZero());
   // Index the sample rows.
   thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin(),
-                    IsNonZero());
+                    IsNonZero{});
   thrust::exclusive_scan(cuctx->CTP(), sample_row_index_.begin(), sample_row_index_.end(),
                          sample_row_index_.begin());
   thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin(),
-                    sample_row_index_.begin(), ClearEmptyRows());
+                    sample_row_index_.begin(), ClearEmptyRows{});
   auto batch_iterator = dmat->GetBatches<EllpackPage>(ctx, batch_param_);
   auto first_page = (*batch_iterator.begin()).Impl();
   // Create a new ELLPACK page with empty rows.
-  page_.reset();  // Release the device memory first before reallocating
-  page_.reset(new EllpackPageImpl{ctx, first_page->CutsShared(), dmat->IsDense(),
-                                  first_page->row_stride, sample_rows});
-  // Compact the ELLPACK pages into the single sample page.
-  thrust::fill(cuctx->CTP(), page_->gidx_buffer.begin(), page_->gidx_buffer.end(), 0);
-  for (auto& batch : batch_iterator) {
-    page_->Compact(ctx, batch.Impl(), dh::ToSpan(sample_row_index_));
-  }
-  return {sample_rows, page_.get(), dh::ToSpan(gpair_)};
+  *page = EllpackPageImpl{ctx, first_page->CutsShared(), dmat->IsDense(), first_page->row_stride,
+                          sample_rows};
+  // Compact the ELLPACK pages into the single sample page.
+  thrust::fill(cuctx->CTP(), page->gidx_buffer.begin(), page->gidx_buffer.end(), 0);
+  for (auto& batch : batch_iterator) {
+    page->Compact(ctx, batch.Impl(), dh::ToSpan(sample_row_index_));
+  }
+  // Select the metainfo
+  dmat->Info().feature_types.SetDevice(ctx->Device());
+  auto nnz = page->NumNonMissing(ctx, dmat->Info().feature_types.ConstDeviceSpan());
+  compact_row_index_.resize(sample_rows);
+  thrust::copy_if(
+      cuctx->TP(), sample_row_index_.cbegin(), sample_row_index_.cend(), compact_row_index_.begin(),
+      [] XGBOOST_DEVICE(std::size_t idx) { return idx != ClearEmptyRows::InvalidRow(); });
+  // Create the new DMatrix
+  this->p_fmat_new_ = std::make_unique<data::IterativeDMatrix>(
+      new_page, dmat->Info().Slice(ctx, dh::ToSpan(compact_row_index_), nnz), batch_param_);
+  return {sample_rows, this->p_fmat_new_.get(), dh::ToSpan(gpair_)};
 }

 GradientBasedSampler::GradientBasedSampler(Context const* /*ctx*/, size_t n_rows,
@@ -378,5 +402,4 @@ size_t GradientBasedSampler::CalculateThresholdIndex(Context const* ctx,
       thrust::min_element(cuctx->CTP(), dh::tbegin(grad_sum), dh::tend(grad_sum));
   return thrust::distance(dh::tbegin(grad_sum), min) + 1;
 }
-};  // namespace tree
-};  // namespace xgboost
+};  // namespace xgboost::tree


@@ -5,7 +5,7 @@
 #include <cstddef>  // for size_t

 #include "../../common/device_vector.cuh"  // for device_vector, caching_device_vector
-#include "../../data/ellpack_page.cuh"     // for EllpackPageImpl
+#include "../../common/timer.h"            // for Monitor
 #include "xgboost/base.h"                  // for GradientPair
 #include "xgboost/data.h"                  // for BatchParam
 #include "xgboost/span.h"                  // for Span
@@ -13,11 +13,11 @@
 namespace xgboost::tree {
 struct GradientBasedSample {
   /*!\brief Number of sampled rows. */
-  std::size_t sample_rows;
+  bst_idx_t sample_rows;
   /*!\brief Sampled rows in ELLPACK format. */
-  EllpackPageImpl const* page;
+  DMatrix* p_fmat;
   /*!\brief Gradient pairs for the sampled rows. */
-  common::Span<GradientPair> gpair;
+  common::Span<GradientPair const> gpair;
 };

 class SamplingStrategy {
@@ -48,7 +48,7 @@ class ExternalMemoryNoSampling : public SamplingStrategy {
  private:
   BatchParam batch_param_;
-  std::unique_ptr<EllpackPageImpl> page_{nullptr};
+  std::unique_ptr<DMatrix> p_fmat_new_{nullptr};
   bool page_concatenated_{false};
 };
@@ -74,9 +74,10 @@ class ExternalMemoryUniformSampling : public SamplingStrategy {
  private:
   BatchParam batch_param_;
   float subsample_;
-  std::unique_ptr<EllpackPageImpl> page_;
+  std::unique_ptr<DMatrix> p_fmat_new_{nullptr};
   dh::device_vector<GradientPair> gpair_{};
-  dh::caching_device_vector<size_t> sample_row_index_;
+  dh::caching_device_vector<bst_idx_t> sample_row_index_;
+  dh::device_vector<bst_idx_t> compact_row_index_;
 };

 /*! \brief Gradient-based sampling in in-memory mode.. */
@@ -105,9 +106,10 @@ class ExternalMemoryGradientBasedSampling : public SamplingStrategy {
   float subsample_;
   dh::device_vector<float> threshold_;
   dh::device_vector<float> grad_sum_;
-  std::unique_ptr<EllpackPageImpl> page_;
+  std::unique_ptr<DMatrix> p_fmat_new_{nullptr};
   dh::device_vector<GradientPair> gpair_;
-  dh::device_vector<size_t> sample_row_index_;
+  dh::device_vector<bst_idx_t> sample_row_index_;
+  dh::device_vector<bst_idx_t> compact_row_index_;
 };

 /*! \brief Draw a sample of rows from a DMatrix.
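A brief sketch of what a consumer now gets back from the sampler after this change; the variable names are illustrative, and the final comment states the expected invariant rather than anything asserted by the diff:

  // Hypothetical sketch: the sampler hands back a DMatrix plus read-only gradients.
  GradientBasedSample sample = sampler->Sample(ctx, dh::ToSpan(gpair), p_fmat);
  DMatrix* sampled = sample.p_fmat;                   // the input, or a newly built IterativeDMatrix
  common::Span<GradientPair const> g = sample.gpair;  // const after this change
  // sample.sample_rows is expected to equal sampled->Info().num_row_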


@@ -119,9 +119,9 @@ struct DeviceSplitCandidate {
 };

 namespace cuda_impl {
-inline BatchParam HistBatch(TrainParam const& param, bool prefetch_copy = true) {
+inline BatchParam HistBatch(TrainParam const& param) {
   auto p = BatchParam{param.max_bin, TrainParam::DftSparseThreshold()};
-  p.prefetch_copy = prefetch_copy;
+  p.prefetch_copy = true;
   p.n_prefetch_batches = 1;
   return p;
 }
@@ -134,6 +134,14 @@ inline BatchParam ApproxBatch(TrainParam const& p, common::Span<float const> hes
                               ObjInfo const& task) {
   return BatchParam{p.max_bin, hess, !task.const_hess};
 }
+
+// Empty parameter to prevent regen, only used to control external memory prefetching.
+inline BatchParam StaticBatch(bool prefetch_copy) {
+  BatchParam p;
+  p.prefetch_copy = prefetch_copy;
+  p.n_prefetch_batches = 1;
+  return p;
+}
 }  // namespace cuda_impl

 template <typename T>
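A short sketch of the intended split between the two helpers, matching how the updater uses them later in this commit (`param`, `p_fmat`, and `ctx` are assumed to exist; the loop body is illustrative only):

  // HistBatch() carries the binning configuration and may trigger an ELLPACK (re)build;
  // StaticBatch() is deliberately empty so iterating batches can never regenerate pages.
  auto build_param  = cuda_impl::HistBatch(param);                    // once, in InitDataOnce
  auto access_param = cuda_impl::StaticBatch(/*prefetch_copy=*/true); // inside the update loop
  for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx, access_param)) {
    auto acc = page.Impl()->GetDeviceAccessor(ctx->Device());
    // read-only access to the already-built page ...
  }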


@@ -52,6 +52,11 @@ DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);
 using cuda_impl::ApproxBatch;
 using cuda_impl::HistBatch;
+// Both the approx and hist initializes the DMatrix before creating the actual
+// implementation (InitDataOnce). Therefore, the `GPUHistMakerDevice` can use an empty
+// parameter to avoid any regen.
+using cuda_impl::StaticBatch;

 // GPU tree updater implementation.
 struct GPUHistMakerDevice {
  private:
@@ -64,6 +69,7 @@ struct GPUHistMakerDevice {
   // node idx for each sample
   dh::device_vector<bst_node_t> positions_;
   std::unique_ptr<RowPartitioner> row_partitioner_;
+  std::shared_ptr<common::HistogramCuts const> cuts_{nullptr};

  public:
   // Extra data for each node that is passed to the update position function
@@ -75,13 +81,12 @@ struct GPUHistMakerDevice {
   static_assert(std::is_trivially_copyable_v<NodeSplitData>);

  public:
-  EllpackPageImpl const* page{nullptr};
   common::Span<FeatureType const> feature_types;

   DeviceHistogramStorage<> hist{};

   dh::device_vector<GradientPair> d_gpair;  // storage for gpair;
-  common::Span<GradientPair> gpair;
+  common::Span<GradientPair const> gpair;

   dh::device_vector<int> monotone_constraints;
@@ -99,19 +104,21 @@ struct GPUHistMakerDevice {
   std::unique_ptr<FeatureGroups> feature_groups;

-  GPUHistMakerDevice(Context const* ctx, bool is_external_memory,
-                     common::Span<FeatureType const> _feature_types, bst_idx_t _n_rows,
+  GPUHistMakerDevice(Context const* ctx, std::shared_ptr<common::HistogramCuts const> cuts,
+                     bool is_external_memory, common::Span<FeatureType const> _feature_types,
                      TrainParam _param, std::shared_ptr<common::ColumnSampler> column_sampler,
-                     uint32_t n_features, BatchParam batch_param, MetaInfo const& info)
-      : evaluator_{_param, n_features, ctx->Device()},
+                     BatchParam batch_param, MetaInfo const& info)
+      : evaluator_{_param, static_cast<bst_feature_t>(info.num_col_), ctx->Device()},
         ctx_(ctx),
         feature_types{_feature_types},
         param(std::move(_param)),
         column_sampler_(std::move(column_sampler)),
-        interaction_constraints(param, n_features),
-        info_{info} {
-    sampler = std::make_unique<GradientBasedSampler>(ctx, _n_rows, batch_param, param.subsample,
-                                                     param.sampling_method, is_external_memory);
+        interaction_constraints(param, info.num_col_),
+        info_{info},
+        cuts_{std::move(cuts)} {
+    sampler =
+        std::make_unique<GradientBasedSampler>(ctx, info.num_row_, batch_param, param.subsample,
+                                               param.sampling_method, is_external_memory);
     if (!param.monotone_constraints.empty()) {
       // Copy assigning an empty vector causes an exception in MSVC debug builds
       monotone_constraints = param.monotone_constraints;
@@ -123,19 +130,19 @@ struct GPUHistMakerDevice {
   ~GPUHistMakerDevice() = default;

-  void InitFeatureGroupsOnce() {
+  void InitFeatureGroupsOnce(MetaInfo const& info) {
     if (!feature_groups) {
-      CHECK(page);
-      feature_groups = std::make_unique<FeatureGroups>(page->Cuts(), page->is_dense,
+      CHECK(cuts_);
+      feature_groups = std::make_unique<FeatureGroups>(*cuts_, info.IsDense(),
                                                        dh::MaxSharedMemoryOptin(ctx_->Ordinal()),
                                                        sizeof(GradientPairPrecise));
     }
   }

   // Reset values for each update iteration
-  void Reset(HostDeviceVector<GradientPair>* dh_gpair, DMatrix* dmat, int64_t num_columns) {
-    auto const& info = dmat->Info();
-    this->column_sampler_->Init(ctx_, num_columns, info.feature_weights.HostVector(),
+  [[nodiscard]] DMatrix* Reset(HostDeviceVector<GradientPair>* dh_gpair, DMatrix* p_fmat) {
+    auto const& info = p_fmat->Info();
+    this->column_sampler_->Init(ctx_, p_fmat->Info().num_col_, info.feature_weights.HostVector(),
                                 param.colsample_bynode, param.colsample_bylevel,
                                 param.colsample_bytree);
     dh::safe_cuda(cudaSetDevice(ctx_->Ordinal()));
@@ -148,54 +155,54 @@ struct GPUHistMakerDevice {
     dh::safe_cuda(cudaMemcpyAsync(d_gpair.data().get(), dh_gpair->ConstDevicePointer(),
                                   dh_gpair->Size() * sizeof(GradientPair),
                                   cudaMemcpyDeviceToDevice));
-    auto sample = sampler->Sample(ctx_, dh::ToSpan(d_gpair), dmat);
-    page = sample.page;
-    gpair = sample.gpair;
+    auto sample = sampler->Sample(ctx_, dh::ToSpan(d_gpair), p_fmat);
+    this->gpair = sample.gpair;
+    p_fmat = sample.p_fmat;
+    CHECK(p_fmat->SingleColBlock());

-    this->evaluator_.Reset(page->Cuts(), feature_types, dmat->Info().num_col_, param,
-                           dmat->Info().IsColumnSplit(), ctx_->Device());
+    this->evaluator_.Reset(*cuts_, feature_types, p_fmat->Info().num_col_, param,
+                           p_fmat->Info().IsColumnSplit(), ctx_->Device());

-    quantiser = std::make_unique<GradientQuantiser>(ctx_, this->gpair, dmat->Info());
+    quantiser = std::make_unique<GradientQuantiser>(ctx_, this->gpair, p_fmat->Info());

     if (!row_partitioner_) {
       row_partitioner_ = std::make_unique<RowPartitioner>();
     }
-    row_partitioner_->Reset(ctx_, sample.sample_rows, page->base_rowid);
-    CHECK_EQ(page->base_rowid, 0);
+    row_partitioner_->Reset(ctx_, sample.sample_rows, 0);

     // Init histogram
-    hist.Init(ctx_->Device(), page->Cuts().TotalBins());
+    hist.Init(ctx_->Device(), this->cuts_->TotalBins());
     hist.Reset(ctx_);

-    this->InitFeatureGroupsOnce();
+    this->InitFeatureGroupsOnce(info);

     this->histogram_.Reset(ctx_, feature_groups->DeviceAccessor(ctx_->Device()), false);
+    return p_fmat;
   }
-  GPUExpandEntry EvaluateRootSplit(GradientPairInt64 root_sum) {
+  GPUExpandEntry EvaluateRootSplit(DMatrix const * p_fmat, GradientPairInt64 root_sum) {
     int nidx = RegTree::kRoot;
     GPUTrainingParam gpu_param(param);
     auto sampled_features = column_sampler_->GetFeatureSet(0);
     sampled_features->SetDevice(ctx_->Device());
     common::Span<bst_feature_t> feature_set =
         interaction_constraints.Query(sampled_features->DeviceSpan(), nidx);
-    auto matrix = page->GetDeviceAccessor(ctx_->Device());
     EvaluateSplitInputs inputs{nidx, 0, root_sum, feature_set, hist.GetNodeHistogram(nidx)};
     EvaluateSplitSharedInputs shared_inputs{
         gpu_param,
         *quantiser,
         feature_types,
-        matrix.feature_segments,
-        matrix.gidx_fvalue_map,
-        matrix.min_fvalue,
-        matrix.is_dense && !collective::IsDistributed()
+        cuts_->cut_ptrs_.ConstDeviceSpan(),
+        cuts_->cut_values_.ConstDeviceSpan(),
+        cuts_->min_vals_.ConstDeviceSpan(),
+        p_fmat->IsDense() && !collective::IsDistributed()
     };
     auto split = this->evaluator_.EvaluateSingleSplit(ctx_, inputs, shared_inputs);
     return split;
   }

-  void EvaluateSplits(const std::vector<GPUExpandEntry>& candidates, const RegTree& tree,
-                      common::Span<GPUExpandEntry> pinned_candidates_out) {
+  void EvaluateSplits(DMatrix const* p_fmat, const std::vector<GPUExpandEntry>& candidates,
+                      const RegTree& tree, common::Span<GPUExpandEntry> pinned_candidates_out) {
     if (candidates.empty()) {
       return;
     }
@@ -204,12 +211,11 @@ struct GPUHistMakerDevice {
     dh::TemporaryArray<DeviceSplitCandidate> splits_out(2 * candidates.size());
     std::vector<bst_node_t> nidx(2 * candidates.size());
     auto h_node_inputs = pinned2.GetSpan<EvaluateSplitInputs>(2 * candidates.size());
-    auto matrix = page->GetDeviceAccessor(ctx_->Device());
-    EvaluateSplitSharedInputs shared_inputs{GPUTrainingParam{param}, *quantiser, feature_types,
-                                            matrix.feature_segments, matrix.gidx_fvalue_map,
-                                            matrix.min_fvalue,
-                                            // is_dense represents the local data
-                                            matrix.is_dense && !collective::IsDistributed()};
+    EvaluateSplitSharedInputs shared_inputs{
+        GPUTrainingParam{param}, *quantiser, feature_types, cuts_->cut_ptrs_.ConstDeviceSpan(),
+        cuts_->cut_values_.ConstDeviceSpan(), cuts_->min_vals_.ConstDeviceSpan(),
+        // is_dense represents the local data
+        p_fmat->IsDense() && !collective::IsDistributed()};
     dh::TemporaryArray<GPUExpandEntry> entries(2 * candidates.size());
     // Store the feature set ptrs so they dont go out of scope before the kernel is called
     std::vector<std::shared_ptr<HostDeviceVector<bst_feature_t>>> feature_sets;
@@ -254,7 +260,7 @@ struct GPUHistMakerDevice {
     this->monitor.Stop(__func__);
   }

-  void BuildHist(int nidx) {
+  void BuildHist(EllpackPageImpl const* page, int nidx) {
     auto d_node_hist = hist.GetNodeHistogram(nidx);
     auto d_ridx = row_partitioner_->GetRows(nidx);
     this->histogram_.BuildHistogram(ctx_->CUDACtx(), page->GetDeviceAccessor(ctx_->Device()),
@@ -272,9 +278,8 @@ struct GPUHistMakerDevice {
     auto d_node_hist_histogram = hist.GetNodeHistogram(nidx_histogram);
     auto d_node_hist_subtraction = hist.GetNodeHistogram(nidx_subtraction);

-    dh::LaunchN(page->Cuts().TotalBins(), [=] __device__(size_t idx) {
-      d_node_hist_subtraction[idx] =
-          d_node_hist_parent[idx] - d_node_hist_histogram[idx];
+    dh::LaunchN(cuts_->TotalBins(), [=] __device__(size_t idx) {
+      d_node_hist_subtraction[idx] = d_node_hist_parent[idx] - d_node_hist_histogram[idx];
     });
     return true;
   }
@@ -366,7 +371,8 @@ struct GPUHistMakerDevice {
     }
   };

-  void UpdatePosition(std::vector<GPUExpandEntry> const& candidates, RegTree* p_tree) {
+  void UpdatePosition(DMatrix* p_fmat, std::vector<GPUExpandEntry> const& candidates,
+                      RegTree* p_tree) {
     if (candidates.empty()) {
       return;
     }
@@ -390,30 +396,33 @@ struct GPUHistMakerDevice {
       CHECK_EQ(split_type == FeatureType::kCategorical, e.split.is_cat);
     }

-    auto d_matrix = page->GetDeviceAccessor(ctx_->Device());
+    for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
+      auto d_matrix = page.Impl()->GetDeviceAccessor(ctx_->Device());
       if (info_.IsColumnSplit()) {
         UpdatePositionColumnSplit(d_matrix, split_data, nidx, left_nidx, right_nidx);
         monitor.Stop(__func__);
         return;
       }
       auto go_left = GoLeftOp{d_matrix};
       row_partitioner_->UpdatePositionBatch(
           nidx, left_nidx, right_nidx, split_data,
           [=] __device__(cuda_impl::RowIndexT ridx, int /*nidx_in_batch*/,
                          const NodeSplitData& data) { return go_left(ridx, data); });
+    }
     monitor.Stop(__func__);
   }

   // After tree update is finished, update the position of all training
   // instances to their final leaf. This information is used later to update the
   // prediction cache
-  void FinalisePosition(RegTree const* p_tree, DMatrix* p_fmat, ObjInfo task,
+  void FinalisePosition(DMatrix* p_fmat, RegTree const* p_tree, ObjInfo task, bst_idx_t n_samples,
                         HostDeviceVector<bst_node_t>* p_out_position) {
     if (!p_fmat->SingleColBlock() && task.UpdateTreeLeaf()) {
       LOG(FATAL) << "Current objective function can not be used with external memory.";
     }
-    if (p_fmat->Info().num_row_ != row_partitioner_->GetRows().size()) {
+    if (p_fmat->Info().num_row_ != n_samples) {
       // Subsampling with external memory. Not supported.
       p_out_position->Resize(0);
       positions_.clear();
@@ -438,37 +447,40 @@ struct GPUHistMakerDevice {
     }

     dh::caching_device_vector<uint32_t> categories;
-    dh::CopyToD(p_tree->GetSplitCategories(), &categories);
+    dh::CopyTo(p_tree->GetSplitCategories(), &categories);
     auto const& cat_segments = p_tree->GetSplitCategoriesPtr();
     auto d_categories = dh::ToSpan(categories);
-    auto d_matrix = page->GetDeviceAccessor(ctx_->Device());
+    for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
+      auto d_matrix = page.Impl()->GetDeviceAccessor(ctx_->Device());

       std::vector<NodeSplitData> split_data(p_tree->NumNodes());
       auto const& tree = *p_tree;
       for (std::size_t i = 0, n = split_data.size(); i < n; ++i) {
         RegTree::Node split_node = tree[i];
         auto split_type = p_tree->NodeSplitType(i);
         auto node_cats = common::GetNodeCats(d_categories, cat_segments[i]);
         split_data[i] = NodeSplitData{std::move(split_node), split_type, node_cats};
       }

       auto go_left_op = GoLeftOp{d_matrix};
       dh::caching_device_vector<NodeSplitData> d_split_data;
       dh::CopyToD(split_data, &d_split_data);
       auto s_split_data = dh::ToSpan(d_split_data);

       row_partitioner_->FinalisePosition(d_out_position,
                                          [=] __device__(bst_idx_t row_id, bst_node_t nidx) {
                                            auto split_data = s_split_data[nidx];
                                            auto node = split_data.split_node;
                                            while (!node.IsLeaf()) {
                                              auto go_left = go_left_op(row_id, split_data);
                                              nidx = go_left ? node.LeftChild() : node.RightChild();
                                              node = s_split_data[nidx].split_node;
                                            }
                                            return encode_op(row_id, nidx);
                                          });
+    }
     dh::CopyTo(d_out_position, &positions_);
   }
@@ -508,7 +520,7 @@ struct GPUHistMakerDevice {
     auto rc = collective::GlobalSum(
         ctx_, info_,
         linalg::MakeVec(reinterpret_cast<ReduceT*>(d_node_hist),
-                        page->Cuts().TotalBins() * 2 * num_histograms, ctx_->Device()));
+                        cuts_->TotalBins() * 2 * num_histograms, ctx_->Device()));
     SafeColl(rc);

     monitor.Stop("AllReduce");
@@ -517,7 +529,8 @@ struct GPUHistMakerDevice {
   /**
    * \brief Build GPU local histograms for the left and right child of some parent node
    */
-  void BuildHistLeftRight(std::vector<GPUExpandEntry> const& candidates, const RegTree& tree) {
+  void BuildHistLeftRight(DMatrix* p_fmat, std::vector<GPUExpandEntry> const& candidates,
+                          const RegTree& tree) {
     if (candidates.empty()) {
       return;
     }
@@ -544,8 +557,10 @@ struct GPUHistMakerDevice {
     // Guaranteed contiguous memory
     hist.AllocateHistograms(ctx_, all_new);

-    for (auto nidx : hist_nidx) {
-      this->BuildHist(nidx);
+    for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
+      for (auto nidx : hist_nidx) {
+        this->BuildHist(page.Impl(), nidx);
+      }
     }

     // Reduce all in one go
@@ -560,7 +575,9 @@ struct GPUHistMakerDevice {
       if (!this->SubtractionTrick(parent_nidx, build_hist_nidx, subtraction_trick_nidx)) {
         // Calculate other histogram manually
-        this->BuildHist(subtraction_trick_nidx);
+        for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
+          this->BuildHist(page.Impl(), subtraction_trick_nidx);
+        }
         this->AllReduceHist(subtraction_trick_nidx, 1);
       }
     }
@@ -595,7 +612,7 @@ struct GPUHistMakerDevice {
     std::vector<common::CatBitField::value_type> split_cats;
     auto h_cats = this->evaluator_.GetHostNodeCats(candidate.nid);
-    auto n_bins_feature = page->Cuts().FeatureBins(candidate.split.findex);
+    auto n_bins_feature = cuts_->FeatureBins(candidate.split.findex);
     split_cats.resize(common::CatBitField::ComputeStorageSize(n_bins_feature), 0);
     CHECK_LE(split_cats.size(), h_cats.size());
     std::copy(h_cats.data(), h_cats.data() + split_cats.size(), split_cats.data());
@@ -618,7 +635,7 @@ struct GPUHistMakerDevice {
                        parent.RightChild());
   }

-  GPUExpandEntry InitRoot(RegTree* p_tree) {
+  GPUExpandEntry InitRoot(DMatrix* p_fmat, RegTree* p_tree) {
     constexpr bst_node_t kRootNIdx = 0;
     dh::XGBCachingDeviceAllocator<char> alloc;
     auto quantiser = *this->quantiser;
@@ -635,7 +652,9 @@ struct GPUHistMakerDevice {
     collective::SafeColl(rc);

     hist.AllocateHistograms(ctx_, {kRootNIdx});
-    this->BuildHist(kRootNIdx);
+    for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
+      this->BuildHist(page.Impl(), kRootNIdx);
+    }
     this->AllReduceHist(kRootNIdx, 1);

     // Remember root stats
@@ -646,24 +665,25 @@ struct GPUHistMakerDevice {
     (*p_tree)[kRootNIdx].SetLeaf(param.learning_rate * weight);

     // Generate first split
-    auto root_entry = this->EvaluateRootSplit(root_sum_quantised);
+    auto root_entry = this->EvaluateRootSplit(p_fmat, root_sum_quantised);
     return root_entry;
   }

   void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat, ObjInfo const* task,
                   RegTree* p_tree, HostDeviceVector<bst_node_t>* p_out_position) {
     bool const is_single_block = p_fmat->SingleColBlock();
+    bst_idx_t const n_samples = p_fmat->Info().num_row_;

     auto& tree = *p_tree;
     // Process maximum 32 nodes at a time
     Driver<GPUExpandEntry> driver(param, 32);

     monitor.Start("Reset");
-    this->Reset(gpair_all, p_fmat, p_fmat->Info().num_col_);
+    p_fmat = this->Reset(gpair_all, p_fmat);
     monitor.Stop("Reset");

     monitor.Start("InitRoot");
-    driver.Push({this->InitRoot(p_tree)});
+    driver.Push({this->InitRoot(p_fmat, p_tree)});
     monitor.Stop("InitRoot");

     // The set of leaves that can be expanded asynchronously
@@ -683,11 +703,11 @@ struct GPUHistMakerDevice {
       // Update all the nodes if working with external memory, this saves us from working
       // with the finalize position call, which adds an additional iteration and requires
      // special handling for row index.
-      this->UpdatePosition(is_single_block ? filtered_expand_set : expand_set, p_tree);
+      this->UpdatePosition(p_fmat, is_single_block ? filtered_expand_set : expand_set, p_tree);

-      this->BuildHistLeftRight(filtered_expand_set, tree);
+      this->BuildHistLeftRight(p_fmat, filtered_expand_set, tree);

-      this->EvaluateSplits(filtered_expand_set, *p_tree, new_candidates);
+      this->EvaluateSplits(p_fmat, filtered_expand_set, *p_tree, new_candidates);
       dh::DefaultStream().Sync();

       driver.Push(new_candidates.begin(), new_candidates.end());
@@ -701,7 +721,7 @@ struct GPUHistMakerDevice {
     if (is_single_block) {
       CHECK_GE(p_tree->NumNodes(), this->row_partitioner_->GetNumNodes());
     }
-    this->FinalisePosition(p_tree, p_fmat, *task, p_out_position);
+    this->FinalisePosition(p_fmat, p_tree, *task, n_samples, p_out_position);
   }
 };
@@ -750,9 +770,8 @@ class GPUHistMaker : public TreeUpdater {
     monitor_.Stop(__func__);
   }

-  void InitDataOnce(TrainParam const* param, DMatrix* dmat) {
+  void InitDataOnce(TrainParam const* param, DMatrix* p_fmat) {
     CHECK_GE(ctx_->Ordinal(), 0) << "Must have at least one device";
-    info_ = &dmat->Info();

     // Synchronise the column sampling seed
     uint32_t column_sampling_seed = common::GlobalRandom()();
@@ -761,13 +780,19 @@ class GPUHistMaker : public TreeUpdater {
     SafeColl(rc);
     this->column_sampler_ = std::make_shared<common::ColumnSampler>(column_sampling_seed);

-    dh::safe_cuda(cudaSetDevice(ctx_->Ordinal()));
-    info_->feature_types.SetDevice(ctx_->Device());
-    maker = std::make_unique<GPUHistMakerDevice>(
-        ctx_, !dmat->SingleColBlock(), info_->feature_types.ConstDeviceSpan(), info_->num_row_,
-        *param, column_sampler_, info_->num_col_, HistBatch(*param), dmat->Info());
-    p_last_fmat_ = dmat;
+    std::shared_ptr<common::HistogramCuts const> cuts;
+    auto batch = HistBatch(*param);
+    for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, HistBatch(*param))) {
+      cuts = page.Impl()->CutsShared();
+    }
+    dh::safe_cuda(cudaSetDevice(ctx_->Ordinal()));
+    p_fmat->Info().feature_types.SetDevice(ctx_->Device());
+    maker = std::make_unique<GPUHistMakerDevice>(ctx_, cuts, !p_fmat->SingleColBlock(),
+                                                 p_fmat->Info().feature_types.ConstDeviceSpan(),
+                                                 *param, column_sampler_, batch, p_fmat->Info());
+    p_last_fmat_ = p_fmat;
     initialised_ = true;
   }
@ -801,8 +826,6 @@ class GPUHistMaker : public TreeUpdater {
return result; return result;
} }
MetaInfo* info_{}; // NOLINT
std::unique_ptr<GPUHistMakerDevice> maker; // NOLINT std::unique_ptr<GPUHistMakerDevice> maker; // NOLINT
[[nodiscard]] char const* Name() const override { return "grow_gpu_hist"; } [[nodiscard]] char const* Name() const override { return "grow_gpu_hist"; }
@ -873,9 +896,15 @@ class GPUGlobalApproxMaker : public TreeUpdater {
auto const& info = p_fmat->Info(); auto const& info = p_fmat->Info();
info.feature_types.SetDevice(ctx_->Device()); info.feature_types.SetDevice(ctx_->Device());
maker_ = std::make_unique<GPUHistMakerDevice>( std::shared_ptr<common::HistogramCuts const> cuts;
ctx_, !p_fmat->SingleColBlock(), info.feature_types.ConstDeviceSpan(), info.num_row_, auto batch = ApproxBatch(*param, hess, *task_);
*param, column_sampler_, info.num_col_, ApproxBatch(*param, hess, *task_), p_fmat->Info()); for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, batch)) {
cuts = page.Impl()->CutsShared();
}
batch.regen = false; // Regen only at the beginning of the iteration.
maker_ = std::make_unique<GPUHistMakerDevice>(ctx_, cuts, !p_fmat->SingleColBlock(),
info.feature_types.ConstDeviceSpan(), *param,
column_sampler_, batch, p_fmat->Info());
std::size_t t_idx{0}; std::size_t t_idx{0};
for (xgboost::RegTree* tree : trees) { for (xgboost::RegTree* tree : trees) {
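Both updaters above now bootstrap the device maker the same way. As an illustrative sketch only (not part of the patch; ctx, p_fmat, batch_param and column_sampler stand in for the concrete variables used in each updater), the shared piece is: pull the quantile cuts from the Ellpack batches once, then hand cuts, batch parameter and MetaInfo to GPUHistMakerDevice instead of raw row and column counts, presumably because the row count can now change whenever the sampler returns a fresh DMatrix.

    // Sketch only: the cuts are shared by every Ellpack page of a DMatrix,
    // so the last page visited holds the same HistogramCuts object as the first.
    std::shared_ptr<common::HistogramCuts const> cuts;
    for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx, batch_param)) {
      cuts = page.Impl()->CutsShared();
    }
    auto maker = std::make_unique<GPUHistMakerDevice>(
        ctx, cuts, !p_fmat->SingleColBlock(), p_fmat->Info().feature_types.ConstDeviceSpan(),
        *param, column_sampler, batch_param, p_fmat->Info());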

View File

@@ -2,8 +2,9 @@
  * Copyright 2021-2024, XGBoost Contributors
  */
 #include <gtest/gtest.h>
 #include <thrust/equal.h>     // for equal
-#include <thrust/sequence.h>  // for sequence
+#include <thrust/iterator/constant_iterator.h>  // for make_constant_iterator
+#include <thrust/sequence.h>                    // for sequence
 #include "../../../src/common/cuda_context.cuh"
 #include "../../../src/common/linalg_op.cuh"
@@ -83,6 +84,14 @@ void TestSlice() {
     }
   });
 }
+void TestWriteAccess(CUDAContext const* cuctx, linalg::TensorView<double, 3> t) {
+  thrust::for_each(cuctx->CTP(), linalg::tbegin(t), linalg::tend(t),
+                   [=] XGBOOST_DEVICE(double& v) { v = 0; });
+  auto eq = thrust::equal(cuctx->CTP(), linalg::tcbegin(t), linalg::tcend(t),
+                          thrust::make_constant_iterator<double>(0.0), thrust::equal_to<>{});
+  ASSERT_TRUE(eq);
+}
 }  // anonymous namespace
 TEST(Linalg, GPUElementWise) { TestElementWiseKernel(); }
@@ -106,5 +115,7 @@ TEST(Linalg, GPUIter) {
   bool eq = thrust::equal(cuctx->CTP(), data.cbegin(), data.cend(), linalg::tcbegin(t));
   ASSERT_TRUE(eq);
+  TestWriteAccess(cuctx, t);
 }
 }  // namespace xgboost::linalg
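linalg::tbegin and linalg::tend are the mutable counterparts of the const iterators already used in this test. A minimal usage sketch, assuming a CUDAContext const* cuctx and a device TensorView<float, 2> v are already in hand (hypothetical names, mirroring TestWriteAccess above):

    // Overwrite every element of the view in place, then check it through the const iterators.
    thrust::for_each(cuctx->CTP(), linalg::tbegin(v), linalg::tend(v),
                     [] XGBOOST_DEVICE(float& e) { e = 1.0f; });
    bool ok = thrust::equal(cuctx->CTP(), linalg::tcbegin(v), linalg::tcend(v),
                            thrust::make_constant_iterator(1.0f), thrust::equal_to<>{});
    ASSERT_TRUE(ok);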

View File

@@ -1,5 +1,5 @@
 /**
- * Copyright 2019-2023, XGBoost contributors
+ * Copyright 2019-2024, XGBoost contributors
  */
 #include <xgboost/base.h>
@@ -15,7 +15,6 @@
 #include "gtest/gtest.h"
 namespace xgboost {
 TEST(EllpackPage, EmptyDMatrix) {
   constexpr int kNRows = 0, kNCols = 0, kMaxBin = 256;
   constexpr float kSparsity = 0;
@@ -242,7 +241,7 @@ TEST(EllpackPage, Compact) {
 namespace {
 class EllpackPageTest : public testing::TestWithParam<float> {
  protected:
-  void Run(float sparsity) {
+  void TestFromGHistIndex(float sparsity) const {
     // Only testing with small sample size as the cuts might be different between host and
     // device.
     size_t n_samples{128}, n_features{13};
@@ -273,9 +272,25 @@ class EllpackPageTest : public testing::TestWithParam<float> {
       }
     }
   }
+  void TestNumNonMissing(float sparsity) const {
+    size_t n_samples{1024}, n_features{13};
+    auto ctx = MakeCUDACtx(0);
+    auto p_fmat = RandomDataGenerator{n_samples, n_features, sparsity}.GenerateDMatrix(true);
+    auto nnz = p_fmat->Info().num_nonzero_;
+    for (auto const& page : p_fmat->GetBatches<EllpackPage>(
+             &ctx, BatchParam{17, tree::TrainParam::DftSparseThreshold()})) {
+      auto ellpack_nnz =
+          page.Impl()->NumNonMissing(&ctx, p_fmat->Info().feature_types.ConstDeviceSpan());
+      ASSERT_EQ(nnz, ellpack_nnz);
+    }
+  }
 };
 }  // namespace
-TEST_P(EllpackPageTest, FromGHistIndex) { this->Run(GetParam()); }
+TEST_P(EllpackPageTest, FromGHistIndex) { this->TestFromGHistIndex(GetParam()); }
+TEST_P(EllpackPageTest, NumNonMissing) { this->TestNumNonMissing(this->GetParam()); }
 INSTANTIATE_TEST_SUITE_P(EllpackPage, EllpackPageTest, testing::Values(.0f, .2f, .4f, .8f));
 }  // namespace xgboost
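For reference, the NumNonMissing query exercised above can be read per page; a hedged sketch that sums it across batches (assumes each page reports only its own rows, which is how the single-batch case above behaves; ctx, p_fmat and batch_param are placeholders):

    bst_idx_t nnz = 0;
    auto ft = p_fmat->Info().feature_types.ConstDeviceSpan();
    for (auto const& page : p_fmat->GetBatches<EllpackPage>(&ctx, batch_param)) {
      nnz += page.Impl()->NumNonMissing(&ctx, ft);  // non-missing entries in this page
    }
    ASSERT_EQ(nnz, p_fmat->Info().num_nonzero_);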

View File

@@ -355,4 +355,70 @@ TEST(MetaInfo, HostExtend) {
 }
 TEST(MetaInfo, CPUStridedData) { TestMetaInfoStridedData(DeviceOrd::CPU()); }
+namespace {
+class TestMetaInfo : public ::testing::TestWithParam<std::tuple<bst_target_t, bool>> {
+ public:
+  void Run(Context const *ctx, bst_target_t n_targets) {
+    MetaInfo info;
+    info.num_row_ = 128;
+    info.num_col_ = 3;
+    info.feature_names.resize(info.num_col_, "a");
+    info.labels.Reshape(info.num_row_, n_targets);
+    HostDeviceVector<bst_idx_t> ridx(info.num_row_ / 2, 0);
+    ridx.SetDevice(ctx->Device());
+    auto h_ridx = ridx.HostSpan();
+    for (std::size_t i = 0, j = 0; i < ridx.Size(); i++, j += 2) {
+      h_ridx[i] = j;
+    }
+    {
+      info.weights_.Resize(info.num_row_);
+      auto h_w = info.weights_.HostSpan();
+      std::iota(h_w.begin(), h_w.end(), 0);
+    }
+    auto out = info.Slice(ctx, ctx->IsCPU() ? h_ridx : ridx.ConstDeviceSpan(), /*nnz=*/256);
+    ASSERT_EQ(info.labels.Device(), ctx->Device());
+    auto h_y = info.labels.HostView();
+    auto h_y_out = out.labels.HostView();
+    ASSERT_EQ(h_y_out.Shape(0), ridx.Size());
+    ASSERT_EQ(h_y_out.Shape(1), n_targets);
+    auto h_w = info.weights_.ConstHostSpan();
+    auto h_w_out = out.weights_.ConstHostSpan();
+    ASSERT_EQ(h_w_out.size(), ridx.Size());
+    for (std::size_t i = 0; i < ridx.Size(); ++i) {
+      for (bst_target_t t = 0; t < n_targets; ++t) {
+        ASSERT_EQ(h_y_out(i, t), h_y(h_ridx[i], t));
+      }
+      ASSERT_EQ(h_w_out[i], h_w[h_ridx[i]]);
+    }
+    for (auto v : info.feature_names) {
+      ASSERT_EQ(v, "a");
+    }
+  }
+};
+}  // anonymous namespace
+TEST_P(TestMetaInfo, Slice) {
+  Context ctx;
+  auto [n_targets, is_cuda] = this->GetParam();
+  if (is_cuda) {
+    ctx = MakeCUDACtx(0);
+  }
+  this->Run(&ctx, n_targets);
+}
+INSTANTIATE_TEST_SUITE_P(Cpu, TestMetaInfo,
+                         ::testing::Values(std::tuple{1u, false}, std::tuple{3u, false}));
+#if defined(XGBOOST_USE_CUDA)
+INSTANTIATE_TEST_SUITE_P(Gpu, TestMetaInfo,
+                         ::testing::Values(std::tuple{1u, true}, std::tuple{3u, true}));
+#endif  // defined(XGBOOST_USE_CUDA)
 }  // namespace xgboost
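Outside the test harness, the reworked MetaInfo::Slice is driven the same way: build a row-index vector on the target device and pass the number of non-missing values alongside it. A sketch under assumed setup (ctx is a CUDA Context, info is a populated MetaInfo, nnz is known by the caller):

    HostDeviceVector<bst_idx_t> ridx(info.num_row_ / 2, 0);
    ridx.SetDevice(ctx.Device());
    auto h_ridx = ridx.HostSpan();
    for (std::size_t i = 0; i < h_ridx.size(); ++i) {
      h_ridx[i] = i * 2;  // keep every other row
    }
    // nnz must describe the selected rows; here it is assumed to be known up front.
    MetaInfo out = info.Slice(&ctx, ridx.ConstDeviceSpan(), nnz);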

View File

@@ -1,5 +1,5 @@
-/*!
- * Copyright 2021 by XGBoost Contributors
+/**
+ * Copyright 2021-2024, XGBoost Contributors
  */
 #ifndef XGBOOST_TESTS_CPP_DATA_TEST_METAINFO_H_
 #define XGBOOST_TESTS_CPP_DATA_TEST_METAINFO_H_
@@ -11,7 +11,6 @@
 #include <numeric>
 #include "../../../src/common/linalg_op.h"
-#include "../../../src/data/array_interface.h"
 namespace xgboost {
 inline void TestMetaInfoStridedData(DeviceOrd device) {

View File

@@ -39,11 +39,11 @@ void VerifySampling(size_t page_size, float subsample, int sampling_method,
   if (fixed_size_sampling) {
     EXPECT_EQ(sample.sample_rows, kRows);
-    EXPECT_EQ(sample.page->n_rows, kRows);
+    EXPECT_EQ(sample.p_fmat->Info().num_row_, kRows);
     EXPECT_EQ(sample.gpair.size(), kRows);
   } else {
     EXPECT_NEAR(sample.sample_rows, sample_rows, kRows * 0.03);
-    EXPECT_NEAR(sample.page->n_rows, sample_rows, kRows * 0.03f);
+    EXPECT_NEAR(sample.p_fmat->Info().num_row_, sample_rows, kRows * 0.03f);
     EXPECT_NEAR(sample.gpair.size(), sample_rows, kRows * 0.03f);
   }
@@ -88,25 +88,28 @@ TEST(GradientBasedSampler, NoSamplingExternalMemory) {
   GradientBasedSampler sampler(&ctx, kRows, param, kSubsample, TrainParam::kUniform, true);
   auto sample = sampler.Sample(&ctx, gpair.DeviceSpan(), dmat.get());
-  auto sampled_page = sample.page;
+  auto p_fmat = sample.p_fmat;
   EXPECT_EQ(sample.sample_rows, kRows);
   EXPECT_EQ(sample.gpair.size(), gpair.Size());
   EXPECT_EQ(sample.gpair.data(), gpair.DevicePointer());
-  EXPECT_EQ(sampled_page->n_rows, kRows);
-  std::vector<common::CompressedByteT> h_gidx_buffer;
-  auto h_accessor = sampled_page->GetHostAccessor(&ctx, &h_gidx_buffer);
+  EXPECT_EQ(p_fmat->Info().num_row_, kRows);
+  ASSERT_EQ(p_fmat->NumBatches(), 1);
+  for (auto const& sampled_page : p_fmat->GetBatches<EllpackPage>(&ctx, param)) {
+    std::vector<common::CompressedByteT> h_gidx_buffer;
+    auto h_accessor = sampled_page.Impl()->GetHostAccessor(&ctx, &h_gidx_buffer);
     std::size_t offset = 0;
     for (auto& batch : dmat->GetBatches<EllpackPage>(&ctx, param)) {
       auto page = batch.Impl();
       std::vector<common::CompressedByteT> h_page_gidx_buffer;
       auto page_accessor = page->GetHostAccessor(&ctx, &h_page_gidx_buffer);
       size_t num_elements = page->n_rows * page->row_stride;
       for (size_t i = 0; i < num_elements; i++) {
         EXPECT_EQ(h_accessor.gidx_iter[i + offset], page_accessor.gidx_iter[i]);
       }
       offset += num_elements;
     }
+  }
 }
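With the sampler returning a full DMatrix, downstream code no longer touches a raw Ellpack page directly. A short sketch of the consuming side (hedged; the variable names follow the test above, and iterating EllpackPage batches is the only access path assumed):

    auto sample = sampler.Sample(&ctx, gpair.DeviceSpan(), dmat.get());
    auto p_fmat = sample.p_fmat;  // a DMatrix, possibly holding only the sampled rows
    for (auto const& page : p_fmat->GetBatches<EllpackPage>(&ctx, param)) {
      auto const* impl = page.Impl();
      // histogram building, partitioning, etc. proceed page by page from here
      (void)impl;
    }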