diff --git a/.gitignore b/.gitignore index 4a780c305..8a2df2a9b 100644 --- a/.gitignore +++ b/.gitignore @@ -63,6 +63,7 @@ java/xgboost4j-demo/data/ java/xgboost4j-demo/tmp/ java/xgboost4j-demo/model/ nb-configuration* + # Eclipse .project .cproject @@ -154,3 +155,6 @@ model*.json *.rds Rplots.pdf *.zip + +# nsys +*.nsys-rep \ No newline at end of file diff --git a/include/xgboost/data.h b/include/xgboost/data.h index 0821ce648..fc7f2c79b 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -110,8 +110,15 @@ class MetaInfo { * @brief Validate all metainfo. */ void Validate(DeviceOrd device) const; - - MetaInfo Slice(common::Span ridxs) const; + /** + * @brief Slice the meta info. + * + * The device of ridxs is specified by the ctx object. + * + * @param ridxs Index of selected rows. + * @param nnz The number of non-missing values. + */ + MetaInfo Slice(Context const* ctx, common::Span ridxs, bst_idx_t nnz) const; MetaInfo Copy() const; /** diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index 00b2a65f8..c8b13ffe1 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -508,6 +508,11 @@ xgboost::common::Span ToSpan(DeviceUVector &vec) { return {vec.data(), vec.size()}; } +template +xgboost::common::Span> ToSpan(DeviceUVector const &vec) { + return {vec.data(), vec.size()}; +} + // thrust begin, similiar to std::begin template thrust::device_ptr tbegin(xgboost::HostDeviceVector& vector) { // NOLINT diff --git a/src/common/linalg_op.cuh b/src/common/linalg_op.cuh index 21fad2dc0..0920f99ad 100644 --- a/src/common/linalg_op.cuh +++ b/src/common/linalg_op.cuh @@ -76,7 +76,7 @@ struct IterOp { // returns a thrust iterator for a tensor view. template auto tcbegin(TensorView v) { // NOLINT - return dh::MakeTransformIterator( + return thrust::make_transform_iterator( thrust::make_counting_iterator(0ul), detail::IterOp>, kDim>{v}); } @@ -85,5 +85,16 @@ template auto tcend(TensorView v) { // NOLINT return tcbegin(v) + v.Size(); } + +template +auto tbegin(TensorView v) { // NOLINT + return thrust::make_transform_iterator(thrust::make_counting_iterator(0ul), + detail::IterOp, kDim>{v}); +} + +template +auto tend(TensorView v) { // NOLINT + return tbegin(v) + v.Size(); +} } // namespace xgboost::linalg #endif // XGBOOST_COMMON_LINALG_OP_CUH_ diff --git a/src/data/data.cc b/src/data/data.cc index fcbd6cae2..b71820a96 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -351,8 +351,10 @@ void MetaInfo::LoadBinary(dmlc::Stream *fi) { this->has_categorical_ = LoadFeatureType(feature_type_names, &feature_types.HostVector()); } +namespace { template -std::vector Gather(const std::vector &in, common::Span ridxs, size_t stride = 1) { +std::vector Gather(const std::vector& in, common::Span ridxs, + size_t stride = 1) { if (in.empty()) { return {}; } @@ -361,16 +363,56 @@ std::vector Gather(const std::vector &in, common::Span ridxs, s for (auto i = 0ull; i < size; i++) { auto ridx = ridxs[i]; for (size_t j = 0; j < stride; ++j) { - out[i * stride +j] = in[ridx * stride + j]; + out[i * stride + j] = in[ridx * stride + j]; } } return out; } +} // namespace -MetaInfo MetaInfo::Slice(common::Span ridxs) const { +namespace cuda_impl { +void SliceMetaInfo(Context const* ctx, MetaInfo const& info, common::Span ridx, + MetaInfo* p_out); +#if !defined(XGBOOST_USE_CUDA) +void SliceMetaInfo(Context const*, MetaInfo const&, common::Span, MetaInfo*) { + common::AssertGPUSupport(); +} +#endif +} // namespace cuda_impl + +MetaInfo MetaInfo::Slice(Context 
const* ctx, common::Span ridxs, + bst_idx_t nnz) const { + /** + * Shape + */ MetaInfo out; out.num_row_ = ridxs.size(); out.num_col_ = this->num_col_; + out.num_nonzero_ = nnz; + + /** + * Feature Info + */ + out.feature_weights.SetDevice(ctx->Device()); + out.feature_weights.Resize(this->feature_weights.Size()); + out.feature_weights.Copy(this->feature_weights); + + out.feature_names = this->feature_names; + + out.feature_types.SetDevice(ctx->Device()); + out.feature_types.Resize(this->feature_types.Size()); + out.feature_types.Copy(this->feature_types); + + out.feature_type_names = this->feature_type_names; + + /** + * Sample Info + */ + if (ctx->IsCUDA()) { + cuda_impl::SliceMetaInfo(ctx, *this, ridxs, &out); + return out; + } + // Groups is maintained by a higher level Python function. We should aim at deprecating // the slice function. if (this->labels.Size() != this->num_row_) { @@ -386,13 +428,11 @@ MetaInfo MetaInfo::Slice(common::Span ridxs) const { }); } - out.labels_upper_bound_.HostVector() = - Gather(this->labels_upper_bound_.HostVector(), ridxs); - out.labels_lower_bound_.HostVector() = - Gather(this->labels_lower_bound_.HostVector(), ridxs); + out.labels_upper_bound_.HostVector() = Gather(this->labels_upper_bound_.HostVector(), ridxs); + out.labels_lower_bound_.HostVector() = Gather(this->labels_lower_bound_.HostVector(), ridxs); // weights if (this->weights_.Size() + 1 == this->group_ptr_.size()) { - auto& h_weights = out.weights_.HostVector(); + auto& h_weights = out.weights_.HostVector(); // Assuming all groups are available. out.weights_.HostVector() = h_weights; } else { @@ -414,14 +454,6 @@ MetaInfo MetaInfo::Slice(common::Span ridxs) const { }); } - out.feature_weights.Resize(this->feature_weights.Size()); - out.feature_weights.Copy(this->feature_weights); - - out.feature_names = this->feature_names; - out.feature_types.Resize(this->feature_types.Size()); - out.feature_types.Copy(this->feature_types); - out.feature_type_names = this->feature_type_names; - return out; } diff --git a/src/data/data.cu b/src/data/data.cu index 670af48c7..37950803b 100644 --- a/src/data/data.cu +++ b/src/data/data.cu @@ -1,9 +1,11 @@ /** - * Copyright 2019-2022 by XGBoost Contributors + * Copyright 2019-2024, XGBoost Contributors * * \file data.cu * \brief Handles setting metainfo from array interface. 
*/ +#include // for gather + #include "../common/cuda_context.cuh" #include "../common/device_helpers.cuh" #include "../common/linalg_op.cuh" @@ -169,6 +171,62 @@ void MetaInfo::SetInfoFromCUDA(Context const& ctx, StringView key, Json array) { } } +namespace { +void Gather(Context const* ctx, linalg::MatrixView in, + common::Span ridx, linalg::Matrix* p_out) { + if (in.Empty()) { + return; + } + auto& out = *p_out; + out.Reshape(ridx.size(), in.Shape(1)); + auto d_out = out.View(ctx->Device()); + + auto cuctx = ctx->CUDACtx(); + auto map_it = thrust::make_transform_iterator(thrust::make_counting_iterator(0ull), + [=] XGBOOST_DEVICE(bst_idx_t i) { + auto [r, c] = linalg::UnravelIndex(i, in.Shape()); + return (ridx[r] * in.Shape(1)) + c; + }); + CHECK_NE(in.Shape(1), 0); + thrust::gather(cuctx->TP(), map_it, map_it + out.Size(), linalg::tcbegin(in), + linalg::tbegin(d_out)); +} + +template +void Gather(Context const* ctx, HostDeviceVector const& in, common::Span ridx, + HostDeviceVector* p_out) { + if (in.Empty()) { + return; + } + in.SetDevice(ctx->Device()); + + auto& out = *p_out; + out.SetDevice(ctx->Device()); + out.Resize(ridx.size()); + auto d_out = out.DeviceSpan(); + + auto cuctx = ctx->CUDACtx(); + auto d_in = in.ConstDeviceSpan(); + thrust::gather(cuctx->TP(), dh::tcbegin(ridx), dh::tcend(ridx), dh::tcbegin(d_in), + dh::tbegin(d_out)); +} +} // anonymous namespace + +namespace cuda_impl { +void SliceMetaInfo(Context const* ctx, MetaInfo const& info, common::Span ridx, + MetaInfo* p_out) { + auto& out = *p_out; + + Gather(ctx, info.labels.View(ctx->Device()), ridx, &p_out->labels); + Gather(ctx, info.base_margin_.View(ctx->Device()), ridx, &p_out->base_margin_); + + Gather(ctx, info.labels_lower_bound_, ridx, &out.labels_lower_bound_); + Gather(ctx, info.labels_upper_bound_, ridx, &out.labels_upper_bound_); + + Gather(ctx, info.weights_, ridx, &out.weights_); +} +} // namespace cuda_impl + template DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread, const std::string& cache_prefix, DataSplitMode data_split_mode) { diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu index 515625c24..727ef4774 100644 --- a/src/data/ellpack_page.cu +++ b/src/data/ellpack_page.cu @@ -1,12 +1,13 @@ /** * Copyright 2019-2024, XGBoost contributors */ +#include // for proclaim_return_type #include #include -#include // for copy -#include // for move -#include // for vector +#include // for copy +#include // for move +#include // for vector #include "../common/categorical.h" #include "../common/cuda_context.cuh" @@ -576,4 +577,17 @@ EllpackDeviceAccessor EllpackPageImpl::GetHostAccessor( common::CompressedIterator(h_gidx_buffer->data(), NumSymbols()), feature_types}; } + +[[nodiscard]] bst_idx_t EllpackPageImpl::NumNonMissing( + Context const* ctx, common::Span feature_types) const { + auto d_acc = this->GetDeviceAccessor(ctx->Device(), feature_types); + using T = typename decltype(d_acc.gidx_iter)::value_type; + auto it = thrust::make_transform_iterator( + thrust::make_counting_iterator(0ull), + cuda::proclaim_return_type([=] __device__(std::size_t i) { return d_acc.gidx_iter[i]; })); + auto nnz = thrust::count_if(ctx->CUDACtx()->CTP(), it, it + d_acc.row_stride * d_acc.n_rows, + cuda::proclaim_return_type( + [=] __device__(T gidx) { return gidx != d_acc.NullValue(); })); + return nnz; +} } // namespace xgboost diff --git a/src/data/ellpack_page.cuh b/src/data/ellpack_page.cuh index 9cc2a5130..b9a67ba22 100644 --- a/src/data/ellpack_page.cuh +++ b/src/data/ellpack_page.cuh 
@@ -236,6 +236,11 @@ class EllpackPageImpl { [[nodiscard]] EllpackDeviceAccessor GetHostAccessor( Context const* ctx, std::vector* h_gidx_buffer, common::Span feature_types = {}) const; + /** + * @brief Calculate the number of non-missing values. + */ + [[nodiscard]] bst_idx_t NumNonMissing(Context const* ctx, + common::Span feature_types) const; private: /** diff --git a/src/data/iterative_dmatrix.cu b/src/data/iterative_dmatrix.cu index 0b15604e3..31bac8548 100644 --- a/src/data/iterative_dmatrix.cu +++ b/src/data/iterative_dmatrix.cu @@ -101,6 +101,17 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p, // Synchronise worker columns } +IterativeDMatrix::IterativeDMatrix(std::shared_ptr ellpack, MetaInfo const& info, + BatchParam batch) { + this->ellpack_ = ellpack; + CHECK_EQ(this->Info().num_row_, 0); + CHECK_EQ(this->Info().num_col_, 0); + this->Info().Extend(info, true, true); + this->Info().num_nonzero_ = info.num_nonzero_; + CHECK_EQ(this->Info().num_row_, info.num_row_); + this->batch_ = batch; +} + BatchSet IterativeDMatrix::GetEllpackBatches(Context const* ctx, BatchParam const& param) { if (param.Initialized()) { diff --git a/src/data/iterative_dmatrix.h b/src/data/iterative_dmatrix.h index acec4708e..281d08248 100644 --- a/src/data/iterative_dmatrix.h +++ b/src/data/iterative_dmatrix.h @@ -48,6 +48,11 @@ class IterativeDMatrix : public QuantileDMatrix { std::shared_ptr ref, DataIterResetCallback *reset, XGDMatrixCallbackNext *next, float missing, int nthread, bst_bin_t max_bin); + /** + * @param Directly construct a QDM from an existing one. + */ + IterativeDMatrix(std::shared_ptr ellpack, MetaInfo const &info, BatchParam batch); + ~IterativeDMatrix() override = default; bool EllpackExists() const override { return static_cast(ellpack_); } diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc index a9bac5062..0edc0bbc7 100644 --- a/src/data/simple_dmatrix.cc +++ b/src/data/simple_dmatrix.cc @@ -31,6 +31,9 @@ const MetaInfo& SimpleDMatrix::Info() const { return info_; } DMatrix* SimpleDMatrix::Slice(common::Span ridxs) { auto out = new SimpleDMatrix; SparsePage& out_page = *out->sparse_page_; + // Convert to uint64 to avoid a breaking change in the C API. The performance impact is + // small since we have to iteratve through the sparse page. + std::vector h_ridx(ridxs.data(), ridxs.data() + ridxs.size()); for (auto const& page : this->GetBatches()) { auto batch = page.GetView(); auto& h_data = out_page.data.HostVector(); @@ -42,8 +45,8 @@ DMatrix* SimpleDMatrix::Slice(common::Span ridxs) { std::copy(inst.begin(), inst.end(), std::back_inserter(h_data)); h_offset.emplace_back(rptr); } - out->Info() = this->Info().Slice(ridxs); - out->Info().num_nonzero_ = h_offset.back(); + auto ctx = this->fmat_ctx_.MakeCPU(); + out->Info() = this->Info().Slice(&ctx, h_ridx, h_offset.back()); } out->fmat_ctx_ = this->fmat_ctx_; return out; diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index 3235e9ec3..44980ac06 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -14,12 +14,12 @@ #include "../../common/cuda_context.cuh" // for CUDAContext #include "../../common/random.h" +#include "../../data/ellpack_page.cuh" // for EllpackPageImpl +#include "../../data/iterative_dmatrix.h" // for IterativeDMatrix #include "../param.h" #include "gradient_based_sampler.cuh" -namespace xgboost { -namespace tree { - +namespace xgboost::tree { /*! 
\brief A functor that returns random weights. */ class RandomWeight : public thrust::unary_function { public: @@ -58,12 +58,14 @@ struct IsNonZero : public thrust::unary_function { }; /*! \brief A functor that clears the row indexes with empty gradient. */ -struct ClearEmptyRows : public thrust::binary_function { +struct ClearEmptyRows : public thrust::binary_function { + static constexpr bst_idx_t InvalidRow() { return std::numeric_limits::max(); } + XGBOOST_DEVICE size_t operator()(const GradientPair& gpair, size_t row_index) const { if (gpair.GetGrad() != 0 || gpair.GetHess() != 0) { return row_index; } else { - return std::numeric_limits::max(); + return InvalidRow(); } } }; @@ -148,10 +150,9 @@ class PoissonSampling : public thrust::binary_function gpair, +GradientBasedSample NoSampling::Sample(Context const*, common::Span gpair, DMatrix* dmat) { - auto page = (*dmat->GetBatches(ctx, batch_param_).begin()).Impl(); - return {dmat->Info().num_row_, page, gpair}; + return {dmat->Info().num_row_, dmat, gpair}; } ExternalMemoryNoSampling::ExternalMemoryNoSampling(BatchParam batch_param) @@ -159,37 +160,39 @@ ExternalMemoryNoSampling::ExternalMemoryNoSampling(BatchParam batch_param) GradientBasedSample ExternalMemoryNoSampling::Sample(Context const* ctx, common::Span gpair, - DMatrix* dmat) { + DMatrix* p_fmat) { + std::shared_ptr new_page; if (!page_concatenated_) { // Concatenate all the external memory ELLPACK pages into a single in-memory page. - page_.reset(nullptr); bst_idx_t offset = 0; - for (auto& batch : dmat->GetBatches(ctx, batch_param_)) { + for (auto& batch : p_fmat->GetBatches(ctx, batch_param_)) { auto page = batch.Impl(); - if (!page_) { - page_ = std::make_unique(ctx, page->CutsShared(), page->is_dense, - page->row_stride, dmat->Info().num_row_); + if (!new_page) { + new_page = std::make_shared(); + *new_page->Impl() = EllpackPageImpl(ctx, page->CutsShared(), page->is_dense, + page->row_stride, p_fmat->Info().num_row_); } - bst_idx_t num_elements = page_->Copy(ctx, page, offset); + bst_idx_t num_elements = new_page->Impl()->Copy(ctx, page, offset); offset += num_elements; } page_concatenated_ = true; + this->p_fmat_new_ = + std::make_unique(new_page, p_fmat->Info(), batch_param_); } - return {dmat->Info().num_row_, page_.get(), gpair}; + return {p_fmat->Info().num_row_, this->p_fmat_new_.get(), gpair}; } UniformSampling::UniformSampling(BatchParam batch_param, float subsample) - : batch_param_{std::move(batch_param)}, subsample_(subsample) {} + : batch_param_{std::move(batch_param)}, subsample_{subsample} {} GradientBasedSample UniformSampling::Sample(Context const* ctx, common::Span gpair, - DMatrix* dmat) { + DMatrix* p_fmat) { // Set gradient pair to 0 with p = 1 - subsample auto cuctx = ctx->CUDACtx(); thrust::replace_if(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), thrust::counting_iterator(0), BernoulliTrial(common::GlobalRandom()(), subsample_), GradientPair()); - auto page = (*dmat->GetBatches(ctx, batch_param_).begin()).Impl(); - return {dmat->Info().num_row_, page, gpair}; + return {p_fmat->Info().num_row_, p_fmat, gpair}; } ExternalMemoryUniformSampling::ExternalMemoryUniformSampling(size_t n_rows, @@ -203,13 +206,17 @@ GradientBasedSample ExternalMemoryUniformSampling::Sample(Context const* ctx, common::Span gpair, DMatrix* dmat) { auto cuctx = ctx->CUDACtx(); + + std::shared_ptr new_page = std::make_shared(); + auto page = new_page->Impl(); + // Set gradient pair to 0 with p = 1 - subsample thrust::replace_if(cuctx->CTP(), dh::tbegin(gpair), 
dh::tend(gpair), thrust::counting_iterator(0), BernoulliTrial(common::GlobalRandom()(), subsample_), GradientPair{}); // Count the sampled rows. - size_t sample_rows = + bst_idx_t sample_rows = thrust::count_if(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), IsNonZero{}); // Compact gradient pairs. @@ -227,17 +234,25 @@ GradientBasedSample ExternalMemoryUniformSampling::Sample(Context const* ctx, auto batch_iterator = dmat->GetBatches(ctx, batch_param_); auto first_page = (*batch_iterator.begin()).Impl(); // Create a new ELLPACK page with empty rows. - page_.reset(); // Release the device memory first before reallocating - page_.reset(new EllpackPageImpl(ctx, first_page->CutsShared(), first_page->is_dense, - first_page->row_stride, sample_rows)); + *page = EllpackPageImpl{ctx, first_page->CutsShared(), first_page->is_dense, + first_page->row_stride, sample_rows}; // Compact the ELLPACK pages into the single sample page. - thrust::fill(cuctx->CTP(), page_->gidx_buffer.begin(), page_->gidx_buffer.end(), 0); + thrust::fill(cuctx->CTP(), page->gidx_buffer.begin(), page->gidx_buffer.end(), 0); for (auto& batch : batch_iterator) { - page_->Compact(ctx, batch.Impl(), dh::ToSpan(sample_row_index_)); + page->Compact(ctx, batch.Impl(), dh::ToSpan(sample_row_index_)); } - - return {sample_rows, page_.get(), dh::ToSpan(gpair_)}; + // Select the metainfo + dmat->Info().feature_types.SetDevice(ctx->Device()); + auto nnz = page->NumNonMissing(ctx, dmat->Info().feature_types.ConstDeviceSpan()); + compact_row_index_.resize(sample_rows); + thrust::copy_if( + cuctx->TP(), sample_row_index_.cbegin(), sample_row_index_.cend(), compact_row_index_.begin(), + [] XGBOOST_DEVICE(std::size_t idx) { return idx != ClearEmptyRows::InvalidRow(); }); + // Create the new DMatrix + this->p_fmat_new_ = std::make_unique( + new_page, dmat->Info().Slice(ctx, dh::ToSpan(compact_row_index_), nnz), batch_param_); + return {sample_rows, this->p_fmat_new_.get(), dh::ToSpan(gpair_)}; } GradientBasedSampling::GradientBasedSampling(std::size_t n_rows, BatchParam batch_param, @@ -254,14 +269,12 @@ GradientBasedSample GradientBasedSampling::Sample(Context const* ctx, size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex( ctx, gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_); - auto page = (*dmat->GetBatches(ctx, batch_param_).begin()).Impl(); - // Perform Poisson sampling in place. thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), thrust::counting_iterator(0), dh::tbegin(gpair), PoissonSampling(dh::ToSpan(threshold_), threshold_index, RandomWeight(common::GlobalRandom()()))); - return {n_rows, page, gpair}; + return {n_rows, dmat, gpair}; } ExternalMemoryGradientBasedSampling::ExternalMemoryGradientBasedSampling(size_t n_rows, @@ -277,6 +290,8 @@ GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(Context const* c common::Span gpair, DMatrix* dmat) { auto cuctx = ctx->CUDACtx(); + std::shared_ptr new_page = std::make_shared(); + auto page = new_page->Impl(); bst_idx_t n_rows = dmat->Info().num_row_; size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex( ctx, gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_); @@ -293,24 +308,33 @@ GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(Context const* c thrust::copy_if(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), gpair_.begin(), IsNonZero()); // Index the sample rows. 
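// Aside: the indexing below is a mask -> exclusive-scan -> sentinel compaction. A minimal
// host-side sketch of the same pattern, assuming a 0/1 keep-mask per row (illustrative names
// only, not XGBoost APIs; the real code runs this with thrust on the device):
#include <cstddef>
#include <limits>
#include <numeric>
#include <vector>

std::vector<std::size_t> CompactRowMap(std::vector<int> const& keep) {
  constexpr std::size_t kInvalid = std::numeric_limits<std::size_t>::max();
  std::vector<std::size_t> dst(keep.size());
  // exclusive scan: dst[i] == number of kept rows before row i, i.e. its slot in the sample page
  std::exclusive_scan(keep.cbegin(), keep.cend(), dst.begin(), std::size_t{0});
  for (std::size_t i = 0; i < keep.size(); ++i) {
    if (keep[i] == 0) {
      dst[i] = kInvalid;  // dropped rows map to a sentinel, mirroring ClearEmptyRows::InvalidRow()
    }
  }
  return dst;
}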
thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin(), - IsNonZero()); + IsNonZero{}); thrust::exclusive_scan(cuctx->CTP(), sample_row_index_.begin(), sample_row_index_.end(), sample_row_index_.begin()); thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin(), - sample_row_index_.begin(), ClearEmptyRows()); + sample_row_index_.begin(), ClearEmptyRows{}); auto batch_iterator = dmat->GetBatches(ctx, batch_param_); auto first_page = (*batch_iterator.begin()).Impl(); // Create a new ELLPACK page with empty rows. - page_.reset(); // Release the device memory first before reallocating - page_.reset(new EllpackPageImpl{ctx, first_page->CutsShared(), dmat->IsDense(), - first_page->row_stride, sample_rows}); - // Compact the ELLPACK pages into the single sample page. - thrust::fill(cuctx->CTP(), page_->gidx_buffer.begin(), page_->gidx_buffer.end(), 0); - for (auto& batch : batch_iterator) { - page_->Compact(ctx, batch.Impl(), dh::ToSpan(sample_row_index_)); - } - return {sample_rows, page_.get(), dh::ToSpan(gpair_)}; + *page = EllpackPageImpl{ctx, first_page->CutsShared(), dmat->IsDense(), first_page->row_stride, + sample_rows}; + // Compact the ELLPACK pages into the single sample page. + thrust::fill(cuctx->CTP(), page->gidx_buffer.begin(), page->gidx_buffer.end(), 0); + for (auto& batch : batch_iterator) { + page->Compact(ctx, batch.Impl(), dh::ToSpan(sample_row_index_)); + } + // Select the metainfo + dmat->Info().feature_types.SetDevice(ctx->Device()); + auto nnz = page->NumNonMissing(ctx, dmat->Info().feature_types.ConstDeviceSpan()); + compact_row_index_.resize(sample_rows); + thrust::copy_if( + cuctx->TP(), sample_row_index_.cbegin(), sample_row_index_.cend(), compact_row_index_.begin(), + [] XGBOOST_DEVICE(std::size_t idx) { return idx != ClearEmptyRows::InvalidRow(); }); + // Create the new DMatrix + this->p_fmat_new_ = std::make_unique( + new_page, dmat->Info().Slice(ctx, dh::ToSpan(compact_row_index_), nnz), batch_param_); + return {sample_rows, this->p_fmat_new_.get(), dh::ToSpan(gpair_)}; } GradientBasedSampler::GradientBasedSampler(Context const* /*ctx*/, size_t n_rows, @@ -378,5 +402,4 @@ size_t GradientBasedSampler::CalculateThresholdIndex(Context const* ctx, thrust::min_element(cuctx->CTP(), dh::tbegin(grad_sum), dh::tend(grad_sum)); return thrust::distance(dh::tbegin(grad_sum), min) + 1; } -}; // namespace tree -}; // namespace xgboost +}; // namespace xgboost::tree diff --git a/src/tree/gpu_hist/gradient_based_sampler.cuh b/src/tree/gpu_hist/gradient_based_sampler.cuh index 79008b1ae..22de2c1fb 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cuh +++ b/src/tree/gpu_hist/gradient_based_sampler.cuh @@ -5,7 +5,7 @@ #include // for size_t #include "../../common/device_vector.cuh" // for device_vector, caching_device_vector -#include "../../data/ellpack_page.cuh" // for EllpackPageImpl +#include "../../common/timer.h" // for Monitor #include "xgboost/base.h" // for GradientPair #include "xgboost/data.h" // for BatchParam #include "xgboost/span.h" // for Span @@ -13,11 +13,11 @@ namespace xgboost::tree { struct GradientBasedSample { /*!\brief Number of sampled rows. */ - std::size_t sample_rows; + bst_idx_t sample_rows; /*!\brief Sampled rows in ELLPACK format. */ - EllpackPageImpl const* page; + DMatrix* p_fmat; /*!\brief Gradient pairs for the sampled rows. 
*/ - common::Span gpair; + common::Span gpair; }; class SamplingStrategy { @@ -48,7 +48,7 @@ class ExternalMemoryNoSampling : public SamplingStrategy { private: BatchParam batch_param_; - std::unique_ptr page_{nullptr}; + std::unique_ptr p_fmat_new_{nullptr}; bool page_concatenated_{false}; }; @@ -74,9 +74,10 @@ class ExternalMemoryUniformSampling : public SamplingStrategy { private: BatchParam batch_param_; float subsample_; - std::unique_ptr page_; + std::unique_ptr p_fmat_new_{nullptr}; dh::device_vector gpair_{}; - dh::caching_device_vector sample_row_index_; + dh::caching_device_vector sample_row_index_; + dh::device_vector compact_row_index_; }; /*! \brief Gradient-based sampling in in-memory mode.. */ @@ -105,9 +106,10 @@ class ExternalMemoryGradientBasedSampling : public SamplingStrategy { float subsample_; dh::device_vector threshold_; dh::device_vector grad_sum_; - std::unique_ptr page_; + std::unique_ptr p_fmat_new_{nullptr}; dh::device_vector gpair_; - dh::device_vector sample_row_index_; + dh::device_vector sample_row_index_; + dh::device_vector compact_row_index_; }; /*! \brief Draw a sample of rows from a DMatrix. diff --git a/src/tree/updater_gpu_common.cuh b/src/tree/updater_gpu_common.cuh index 31f93d18a..f60d45196 100644 --- a/src/tree/updater_gpu_common.cuh +++ b/src/tree/updater_gpu_common.cuh @@ -119,9 +119,9 @@ struct DeviceSplitCandidate { }; namespace cuda_impl { -inline BatchParam HistBatch(TrainParam const& param, bool prefetch_copy = true) { +inline BatchParam HistBatch(TrainParam const& param) { auto p = BatchParam{param.max_bin, TrainParam::DftSparseThreshold()}; - p.prefetch_copy = prefetch_copy; + p.prefetch_copy = true; p.n_prefetch_batches = 1; return p; } @@ -134,6 +134,14 @@ inline BatchParam ApproxBatch(TrainParam const& p, common::Span hes ObjInfo const& task) { return BatchParam{p.max_bin, hess, !task.const_hess}; } + +// Empty parameter to prevent regen, only used to control external memory prefetching. +inline BatchParam StaticBatch(bool prefetch_copy) { + BatchParam p; + p.prefetch_copy = prefetch_copy; + p.n_prefetch_batches = 1; + return p; +} } // namespace cuda_impl template diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 5d364aa82..8ca6ef71c 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -52,6 +52,11 @@ DMLC_REGISTRY_FILE_TAG(updater_gpu_hist); using cuda_impl::ApproxBatch; using cuda_impl::HistBatch; +// Both the approx and hist initializes the DMatrix before creating the actual +// implementation (InitDataOnce). Therefore, the `GPUHistMakerDevice` can use an empty +// parameter to avoid any regen. +using cuda_impl::StaticBatch; + // GPU tree updater implementation. 
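// Aside: StaticBatch() above relies on an uninitialized BatchParam not being able to trigger a
// regen of the quantized pages; only prefetching is configured. A simplified stand-in, assuming
// the "initialized" check is keyed on max_bin (toy type, not the real xgboost::BatchParam):
struct ToyBatchParam {
  int max_bin{0};             // 0 => "not initialized" => cached pages are reused as-is
  bool prefetch_copy{false};  // copy pages into a prefetch buffer for external memory
  int n_prefetch_batches{0};
  bool Initialized() const { return max_bin != 0; }
};

inline ToyBatchParam ToyStaticBatch(bool prefetch_copy) {
  ToyBatchParam p;
  p.prefetch_copy = prefetch_copy;
  p.n_prefetch_batches = 1;
  return p;
}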
struct GPUHistMakerDevice { private: @@ -64,6 +69,7 @@ struct GPUHistMakerDevice { // node idx for each sample dh::device_vector positions_; std::unique_ptr row_partitioner_; + std::shared_ptr cuts_{nullptr}; public: // Extra data for each node that is passed to the update position function @@ -75,13 +81,12 @@ struct GPUHistMakerDevice { static_assert(std::is_trivially_copyable_v); public: - EllpackPageImpl const* page{nullptr}; common::Span feature_types; DeviceHistogramStorage<> hist{}; dh::device_vector d_gpair; // storage for gpair; - common::Span gpair; + common::Span gpair; dh::device_vector monotone_constraints; @@ -99,19 +104,21 @@ struct GPUHistMakerDevice { std::unique_ptr feature_groups; - GPUHistMakerDevice(Context const* ctx, bool is_external_memory, - common::Span _feature_types, bst_idx_t _n_rows, + GPUHistMakerDevice(Context const* ctx, std::shared_ptr cuts, + bool is_external_memory, common::Span _feature_types, TrainParam _param, std::shared_ptr column_sampler, - uint32_t n_features, BatchParam batch_param, MetaInfo const& info) - : evaluator_{_param, n_features, ctx->Device()}, + BatchParam batch_param, MetaInfo const& info) + : evaluator_{_param, static_cast(info.num_col_), ctx->Device()}, ctx_(ctx), feature_types{_feature_types}, param(std::move(_param)), column_sampler_(std::move(column_sampler)), - interaction_constraints(param, n_features), - info_{info} { - sampler = std::make_unique(ctx, _n_rows, batch_param, param.subsample, - param.sampling_method, is_external_memory); + interaction_constraints(param, info.num_col_), + info_{info}, + cuts_{std::move(cuts)} { + sampler = + std::make_unique(ctx, info.num_row_, batch_param, param.subsample, + param.sampling_method, is_external_memory); if (!param.monotone_constraints.empty()) { // Copy assigning an empty vector causes an exception in MSVC debug builds monotone_constraints = param.monotone_constraints; @@ -123,19 +130,19 @@ struct GPUHistMakerDevice { ~GPUHistMakerDevice() = default; - void InitFeatureGroupsOnce() { + void InitFeatureGroupsOnce(MetaInfo const& info) { if (!feature_groups) { - CHECK(page); - feature_groups = std::make_unique(page->Cuts(), page->is_dense, + CHECK(cuts_); + feature_groups = std::make_unique(*cuts_, info.IsDense(), dh::MaxSharedMemoryOptin(ctx_->Ordinal()), sizeof(GradientPairPrecise)); } } // Reset values for each update iteration - void Reset(HostDeviceVector* dh_gpair, DMatrix* dmat, int64_t num_columns) { - auto const& info = dmat->Info(); - this->column_sampler_->Init(ctx_, num_columns, info.feature_weights.HostVector(), + [[nodiscard]] DMatrix* Reset(HostDeviceVector* dh_gpair, DMatrix* p_fmat) { + auto const& info = p_fmat->Info(); + this->column_sampler_->Init(ctx_, p_fmat->Info().num_col_, info.feature_weights.HostVector(), param.colsample_bynode, param.colsample_bylevel, param.colsample_bytree); dh::safe_cuda(cudaSetDevice(ctx_->Ordinal())); @@ -148,54 +155,54 @@ struct GPUHistMakerDevice { dh::safe_cuda(cudaMemcpyAsync(d_gpair.data().get(), dh_gpair->ConstDevicePointer(), dh_gpair->Size() * sizeof(GradientPair), cudaMemcpyDeviceToDevice)); - auto sample = sampler->Sample(ctx_, dh::ToSpan(d_gpair), dmat); - page = sample.page; - gpair = sample.gpair; + auto sample = sampler->Sample(ctx_, dh::ToSpan(d_gpair), p_fmat); + this->gpair = sample.gpair; + p_fmat = sample.p_fmat; + CHECK(p_fmat->SingleColBlock()); - this->evaluator_.Reset(page->Cuts(), feature_types, dmat->Info().num_col_, param, - dmat->Info().IsColumnSplit(), ctx_->Device()); + this->evaluator_.Reset(*cuts_, 
feature_types, p_fmat->Info().num_col_, param, + p_fmat->Info().IsColumnSplit(), ctx_->Device()); - quantiser = std::make_unique(ctx_, this->gpair, dmat->Info()); + quantiser = std::make_unique(ctx_, this->gpair, p_fmat->Info()); if (!row_partitioner_) { row_partitioner_ = std::make_unique(); } - row_partitioner_->Reset(ctx_, sample.sample_rows, page->base_rowid); - CHECK_EQ(page->base_rowid, 0); + row_partitioner_->Reset(ctx_, sample.sample_rows, 0); // Init histogram - hist.Init(ctx_->Device(), page->Cuts().TotalBins()); + hist.Init(ctx_->Device(), this->cuts_->TotalBins()); hist.Reset(ctx_); - this->InitFeatureGroupsOnce(); + this->InitFeatureGroupsOnce(info); this->histogram_.Reset(ctx_, feature_groups->DeviceAccessor(ctx_->Device()), false); + return p_fmat; } - GPUExpandEntry EvaluateRootSplit(GradientPairInt64 root_sum) { + GPUExpandEntry EvaluateRootSplit(DMatrix const * p_fmat, GradientPairInt64 root_sum) { int nidx = RegTree::kRoot; GPUTrainingParam gpu_param(param); auto sampled_features = column_sampler_->GetFeatureSet(0); sampled_features->SetDevice(ctx_->Device()); common::Span feature_set = interaction_constraints.Query(sampled_features->DeviceSpan(), nidx); - auto matrix = page->GetDeviceAccessor(ctx_->Device()); EvaluateSplitInputs inputs{nidx, 0, root_sum, feature_set, hist.GetNodeHistogram(nidx)}; EvaluateSplitSharedInputs shared_inputs{ gpu_param, *quantiser, feature_types, - matrix.feature_segments, - matrix.gidx_fvalue_map, - matrix.min_fvalue, - matrix.is_dense && !collective::IsDistributed() + cuts_->cut_ptrs_.ConstDeviceSpan(), + cuts_->cut_values_.ConstDeviceSpan(), + cuts_->min_vals_.ConstDeviceSpan(), + p_fmat->IsDense() && !collective::IsDistributed() }; auto split = this->evaluator_.EvaluateSingleSplit(ctx_, inputs, shared_inputs); return split; } - void EvaluateSplits(const std::vector& candidates, const RegTree& tree, - common::Span pinned_candidates_out) { + void EvaluateSplits(DMatrix const* p_fmat, const std::vector& candidates, + const RegTree& tree, common::Span pinned_candidates_out) { if (candidates.empty()) { return; } @@ -204,12 +211,11 @@ struct GPUHistMakerDevice { dh::TemporaryArray splits_out(2 * candidates.size()); std::vector nidx(2 * candidates.size()); auto h_node_inputs = pinned2.GetSpan(2 * candidates.size()); - auto matrix = page->GetDeviceAccessor(ctx_->Device()); - EvaluateSplitSharedInputs shared_inputs{GPUTrainingParam{param}, *quantiser, feature_types, - matrix.feature_segments, matrix.gidx_fvalue_map, - matrix.min_fvalue, - // is_dense represents the local data - matrix.is_dense && !collective::IsDistributed()}; + EvaluateSplitSharedInputs shared_inputs{ + GPUTrainingParam{param}, *quantiser, feature_types, cuts_->cut_ptrs_.ConstDeviceSpan(), + cuts_->cut_values_.ConstDeviceSpan(), cuts_->min_vals_.ConstDeviceSpan(), + // is_dense represents the local data + p_fmat->IsDense() && !collective::IsDistributed()}; dh::TemporaryArray entries(2 * candidates.size()); // Store the feature set ptrs so they dont go out of scope before the kernel is called std::vector>> feature_sets; @@ -254,7 +260,7 @@ struct GPUHistMakerDevice { this->monitor.Stop(__func__); } - void BuildHist(int nidx) { + void BuildHist(EllpackPageImpl const* page, int nidx) { auto d_node_hist = hist.GetNodeHistogram(nidx); auto d_ridx = row_partitioner_->GetRows(nidx); this->histogram_.BuildHistogram(ctx_->CUDACtx(), page->GetDeviceAccessor(ctx_->Device()), @@ -272,9 +278,8 @@ struct GPUHistMakerDevice { auto d_node_hist_histogram = hist.GetNodeHistogram(nidx_histogram); 
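// Aside: a minimal host-side sketch of the subtraction trick used here: the sibling's histogram
// is obtained as (parent - built child), bin by bin, instead of scanning the data again. Plain
// doubles for illustration; the kernel below operates on quantised GradientPairInt64 bins.
#include <cstddef>
#include <vector>

void SubtractionTrickSketch(std::vector<double> const& parent, std::vector<double> const& built,
                            std::vector<double>* sibling) {
  sibling->resize(parent.size());
  for (std::size_t i = 0; i < parent.size(); ++i) {
    (*sibling)[i] = parent[i] - built[i];  // one subtraction per histogram bin
  }
}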
auto d_node_hist_subtraction = hist.GetNodeHistogram(nidx_subtraction); - dh::LaunchN(page->Cuts().TotalBins(), [=] __device__(size_t idx) { - d_node_hist_subtraction[idx] = - d_node_hist_parent[idx] - d_node_hist_histogram[idx]; + dh::LaunchN(cuts_->TotalBins(), [=] __device__(size_t idx) { + d_node_hist_subtraction[idx] = d_node_hist_parent[idx] - d_node_hist_histogram[idx]; }); return true; } @@ -366,7 +371,8 @@ struct GPUHistMakerDevice { } }; - void UpdatePosition(std::vector const& candidates, RegTree* p_tree) { + void UpdatePosition(DMatrix* p_fmat, std::vector const& candidates, + RegTree* p_tree) { if (candidates.empty()) { return; } @@ -390,30 +396,33 @@ struct GPUHistMakerDevice { CHECK_EQ(split_type == FeatureType::kCategorical, e.split.is_cat); } - auto d_matrix = page->GetDeviceAccessor(ctx_->Device()); + for (auto const& page : p_fmat->GetBatches(ctx_, StaticBatch(true))) { + auto d_matrix = page.Impl()->GetDeviceAccessor(ctx_->Device()); - if (info_.IsColumnSplit()) { - UpdatePositionColumnSplit(d_matrix, split_data, nidx, left_nidx, right_nidx); - monitor.Stop(__func__); - return; + if (info_.IsColumnSplit()) { + UpdatePositionColumnSplit(d_matrix, split_data, nidx, left_nidx, right_nidx); + monitor.Stop(__func__); + return; + } + auto go_left = GoLeftOp{d_matrix}; + row_partitioner_->UpdatePositionBatch( + nidx, left_nidx, right_nidx, split_data, + [=] __device__(cuda_impl::RowIndexT ridx, int /*nidx_in_batch*/, + const NodeSplitData& data) { return go_left(ridx, data); }); } - auto go_left = GoLeftOp{d_matrix}; - row_partitioner_->UpdatePositionBatch( - nidx, left_nidx, right_nidx, split_data, - [=] __device__(cuda_impl::RowIndexT ridx, int /*nidx_in_batch*/, - const NodeSplitData& data) { return go_left(ridx, data); }); + monitor.Stop(__func__); } // After tree update is finished, update the position of all training // instances to their final leaf. This information is used later to update the // prediction cache - void FinalisePosition(RegTree const* p_tree, DMatrix* p_fmat, ObjInfo task, + void FinalisePosition(DMatrix* p_fmat, RegTree const* p_tree, ObjInfo task, bst_idx_t n_samples, HostDeviceVector* p_out_position) { if (!p_fmat->SingleColBlock() && task.UpdateTreeLeaf()) { LOG(FATAL) << "Current objective function can not be used with external memory."; } - if (p_fmat->Info().num_row_ != row_partitioner_->GetRows().size()) { + if (p_fmat->Info().num_row_ != n_samples) { // Subsampling with external memory. Not supported. 
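// Aside: further down, FinalisePosition walks every row from the root to its final leaf so the
// leaf index can be cached for prediction. A minimal sketch of that walk, assuming a toy node
// type (illustrative only; the real code uses RegTree::Node and GoLeftOp on the ELLPACK accessor):
#include <vector>

struct ToyNode {
  bool is_leaf{false};
  int left{-1}, right{-1};
  int split_feature{0};
  float split_value{0.f};
};

inline int WalkToLeaf(std::vector<ToyNode> const& tree, std::vector<float> const& row) {
  int nidx = 0;  // start at the root
  while (!tree[nidx].is_leaf) {
    bool go_left = row[tree[nidx].split_feature] < tree[nidx].split_value;
    nidx = go_left ? tree[nidx].left : tree[nidx].right;
  }
  return nidx;
}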
p_out_position->Resize(0); positions_.clear(); @@ -438,37 +447,40 @@ struct GPUHistMakerDevice { } dh::caching_device_vector categories; - dh::CopyToD(p_tree->GetSplitCategories(), &categories); + dh::CopyTo(p_tree->GetSplitCategories(), &categories); auto const& cat_segments = p_tree->GetSplitCategoriesPtr(); auto d_categories = dh::ToSpan(categories); - auto d_matrix = page->GetDeviceAccessor(ctx_->Device()); + for (auto const& page : p_fmat->GetBatches(ctx_, StaticBatch(true))) { + auto d_matrix = page.Impl()->GetDeviceAccessor(ctx_->Device()); - std::vector split_data(p_tree->NumNodes()); - auto const& tree = *p_tree; - for (std::size_t i = 0, n = split_data.size(); i < n; ++i) { - RegTree::Node split_node = tree[i]; - auto split_type = p_tree->NodeSplitType(i); - auto node_cats = common::GetNodeCats(d_categories, cat_segments[i]); - split_data[i] = NodeSplitData{std::move(split_node), split_type, node_cats}; + std::vector split_data(p_tree->NumNodes()); + auto const& tree = *p_tree; + for (std::size_t i = 0, n = split_data.size(); i < n; ++i) { + RegTree::Node split_node = tree[i]; + auto split_type = p_tree->NodeSplitType(i); + auto node_cats = common::GetNodeCats(d_categories, cat_segments[i]); + split_data[i] = NodeSplitData{std::move(split_node), split_type, node_cats}; + } + + auto go_left_op = GoLeftOp{d_matrix}; + dh::caching_device_vector d_split_data; + dh::CopyToD(split_data, &d_split_data); + auto s_split_data = dh::ToSpan(d_split_data); + + row_partitioner_->FinalisePosition(d_out_position, + [=] __device__(bst_idx_t row_id, bst_node_t nidx) { + auto split_data = s_split_data[nidx]; + auto node = split_data.split_node; + while (!node.IsLeaf()) { + auto go_left = go_left_op(row_id, split_data); + nidx = go_left ? node.LeftChild() : node.RightChild(); + node = s_split_data[nidx].split_node; + } + return encode_op(row_id, nidx); + }); } - auto go_left_op = GoLeftOp{d_matrix}; - dh::caching_device_vector d_split_data; - dh::CopyToD(split_data, &d_split_data); - auto s_split_data = dh::ToSpan(d_split_data); - - row_partitioner_->FinalisePosition(d_out_position, - [=] __device__(bst_idx_t row_id, bst_node_t nidx) { - auto split_data = s_split_data[nidx]; - auto node = split_data.split_node; - while (!node.IsLeaf()) { - auto go_left = go_left_op(row_id, split_data); - nidx = go_left ? 
node.LeftChild() : node.RightChild(); - node = s_split_data[nidx].split_node; - } - return encode_op(row_id, nidx); - }); dh::CopyTo(d_out_position, &positions_); } @@ -508,7 +520,7 @@ struct GPUHistMakerDevice { auto rc = collective::GlobalSum( ctx_, info_, linalg::MakeVec(reinterpret_cast(d_node_hist), - page->Cuts().TotalBins() * 2 * num_histograms, ctx_->Device())); + cuts_->TotalBins() * 2 * num_histograms, ctx_->Device())); SafeColl(rc); monitor.Stop("AllReduce"); @@ -517,7 +529,8 @@ struct GPUHistMakerDevice { /** * \brief Build GPU local histograms for the left and right child of some parent node */ - void BuildHistLeftRight(std::vector const& candidates, const RegTree& tree) { + void BuildHistLeftRight(DMatrix* p_fmat, std::vector const& candidates, + const RegTree& tree) { if (candidates.empty()) { return; } @@ -544,8 +557,10 @@ struct GPUHistMakerDevice { // Guaranteed contiguous memory hist.AllocateHistograms(ctx_, all_new); - for (auto nidx : hist_nidx) { - this->BuildHist(nidx); + for (auto const& page : p_fmat->GetBatches(ctx_, StaticBatch(true))) { + for (auto nidx : hist_nidx) { + this->BuildHist(page.Impl(), nidx); + } } // Reduce all in one go @@ -560,7 +575,9 @@ struct GPUHistMakerDevice { if (!this->SubtractionTrick(parent_nidx, build_hist_nidx, subtraction_trick_nidx)) { // Calculate other histogram manually - this->BuildHist(subtraction_trick_nidx); + for (auto const& page : p_fmat->GetBatches(ctx_, StaticBatch(true))) { + this->BuildHist(page.Impl(), subtraction_trick_nidx); + } this->AllReduceHist(subtraction_trick_nidx, 1); } } @@ -595,7 +612,7 @@ struct GPUHistMakerDevice { std::vector split_cats; auto h_cats = this->evaluator_.GetHostNodeCats(candidate.nid); - auto n_bins_feature = page->Cuts().FeatureBins(candidate.split.findex); + auto n_bins_feature = cuts_->FeatureBins(candidate.split.findex); split_cats.resize(common::CatBitField::ComputeStorageSize(n_bins_feature), 0); CHECK_LE(split_cats.size(), h_cats.size()); std::copy(h_cats.data(), h_cats.data() + split_cats.size(), split_cats.data()); @@ -618,7 +635,7 @@ struct GPUHistMakerDevice { parent.RightChild()); } - GPUExpandEntry InitRoot(RegTree* p_tree) { + GPUExpandEntry InitRoot(DMatrix* p_fmat, RegTree* p_tree) { constexpr bst_node_t kRootNIdx = 0; dh::XGBCachingDeviceAllocator alloc; auto quantiser = *this->quantiser; @@ -635,7 +652,9 @@ struct GPUHistMakerDevice { collective::SafeColl(rc); hist.AllocateHistograms(ctx_, {kRootNIdx}); - this->BuildHist(kRootNIdx); + for (auto const& page : p_fmat->GetBatches(ctx_, StaticBatch(true))) { + this->BuildHist(page.Impl(), kRootNIdx); + } this->AllReduceHist(kRootNIdx, 1); // Remember root stats @@ -646,24 +665,25 @@ struct GPUHistMakerDevice { (*p_tree)[kRootNIdx].SetLeaf(param.learning_rate * weight); // Generate first split - auto root_entry = this->EvaluateRootSplit(root_sum_quantised); + auto root_entry = this->EvaluateRootSplit(p_fmat, root_sum_quantised); return root_entry; } void UpdateTree(HostDeviceVector* gpair_all, DMatrix* p_fmat, ObjInfo const* task, RegTree* p_tree, HostDeviceVector* p_out_position) { bool const is_single_block = p_fmat->SingleColBlock(); + bst_idx_t const n_samples = p_fmat->Info().num_row_; auto& tree = *p_tree; // Process maximum 32 nodes at a time Driver driver(param, 32); monitor.Start("Reset"); - this->Reset(gpair_all, p_fmat, p_fmat->Info().num_col_); + p_fmat = this->Reset(gpair_all, p_fmat); monitor.Stop("Reset"); monitor.Start("InitRoot"); - driver.Push({this->InitRoot(p_tree)}); + driver.Push({this->InitRoot(p_fmat, 
p_tree)}); monitor.Stop("InitRoot"); // The set of leaves that can be expanded asynchronously @@ -683,11 +703,11 @@ struct GPUHistMakerDevice { // Update all the nodes if working with external memory, this saves us from working // with the finalize position call, which adds an additional iteration and requires // special handling for row index. - this->UpdatePosition(is_single_block ? filtered_expand_set : expand_set, p_tree); + this->UpdatePosition(p_fmat, is_single_block ? filtered_expand_set : expand_set, p_tree); - this->BuildHistLeftRight(filtered_expand_set, tree); + this->BuildHistLeftRight(p_fmat, filtered_expand_set, tree); - this->EvaluateSplits(filtered_expand_set, *p_tree, new_candidates); + this->EvaluateSplits(p_fmat, filtered_expand_set, *p_tree, new_candidates); dh::DefaultStream().Sync(); driver.Push(new_candidates.begin(), new_candidates.end()); @@ -701,7 +721,7 @@ struct GPUHistMakerDevice { if (is_single_block) { CHECK_GE(p_tree->NumNodes(), this->row_partitioner_->GetNumNodes()); } - this->FinalisePosition(p_tree, p_fmat, *task, p_out_position); + this->FinalisePosition(p_fmat, p_tree, *task, n_samples, p_out_position); } }; @@ -750,9 +770,8 @@ class GPUHistMaker : public TreeUpdater { monitor_.Stop(__func__); } - void InitDataOnce(TrainParam const* param, DMatrix* dmat) { + void InitDataOnce(TrainParam const* param, DMatrix* p_fmat) { CHECK_GE(ctx_->Ordinal(), 0) << "Must have at least one device"; - info_ = &dmat->Info(); // Synchronise the column sampling seed uint32_t column_sampling_seed = common::GlobalRandom()(); @@ -761,13 +780,19 @@ class GPUHistMaker : public TreeUpdater { SafeColl(rc); this->column_sampler_ = std::make_shared(column_sampling_seed); - dh::safe_cuda(cudaSetDevice(ctx_->Ordinal())); - info_->feature_types.SetDevice(ctx_->Device()); - maker = std::make_unique( - ctx_, !dmat->SingleColBlock(), info_->feature_types.ConstDeviceSpan(), info_->num_row_, - *param, column_sampler_, info_->num_col_, HistBatch(*param), dmat->Info()); + std::shared_ptr cuts; + auto batch = HistBatch(*param); + for (auto const& page : p_fmat->GetBatches(ctx_, HistBatch(*param))) { + cuts = page.Impl()->CutsShared(); + } - p_last_fmat_ = dmat; + dh::safe_cuda(cudaSetDevice(ctx_->Ordinal())); + p_fmat->Info().feature_types.SetDevice(ctx_->Device()); + maker = std::make_unique(ctx_, cuts, !p_fmat->SingleColBlock(), + p_fmat->Info().feature_types.ConstDeviceSpan(), + *param, column_sampler_, batch, p_fmat->Info()); + + p_last_fmat_ = p_fmat; initialised_ = true; } @@ -801,8 +826,6 @@ class GPUHistMaker : public TreeUpdater { return result; } - MetaInfo* info_{}; // NOLINT - std::unique_ptr maker; // NOLINT [[nodiscard]] char const* Name() const override { return "grow_gpu_hist"; } @@ -873,9 +896,15 @@ class GPUGlobalApproxMaker : public TreeUpdater { auto const& info = p_fmat->Info(); info.feature_types.SetDevice(ctx_->Device()); - maker_ = std::make_unique( - ctx_, !p_fmat->SingleColBlock(), info.feature_types.ConstDeviceSpan(), info.num_row_, - *param, column_sampler_, info.num_col_, ApproxBatch(*param, hess, *task_), p_fmat->Info()); + std::shared_ptr cuts; + auto batch = ApproxBatch(*param, hess, *task_); + for (auto const& page : p_fmat->GetBatches(ctx_, batch)) { + cuts = page.Impl()->CutsShared(); + } + batch.regen = false; // Regen only at the beginning of the iteration. 
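// Aside: both updaters now fetch the shared quantile cuts once and hand them to
// GPUHistMakerDevice, so split evaluation no longer needs a live page accessor. A sketch of the
// CSR-style layout the evaluator indexes into, assuming the usual ptr/value arrangement (toy
// type; cut_ptrs_/cut_values_/min_vals_ are the real HistogramCuts members referenced above):
#include <cstddef>
#include <cstdint>
#include <vector>

struct ToyCuts {
  std::vector<std::uint32_t> cut_ptrs;  // size n_features + 1, bin offsets per feature
  std::vector<float> cut_values;        // concatenated bin upper bounds
  std::vector<float> min_vals;          // per-feature minimum value
  std::uint32_t TotalBins() const { return cut_ptrs.back(); }
  std::uint32_t FeatureBins(std::size_t f) const { return cut_ptrs[f + 1] - cut_ptrs[f]; }
};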
+ maker_ = std::make_unique(ctx_, cuts, !p_fmat->SingleColBlock(), + info.feature_types.ConstDeviceSpan(), *param, + column_sampler_, batch, p_fmat->Info()); std::size_t t_idx{0}; for (xgboost::RegTree* tree : trees) { diff --git a/tests/cpp/common/test_linalg.cu b/tests/cpp/common/test_linalg.cu index bf217842b..6ba398bde 100644 --- a/tests/cpp/common/test_linalg.cu +++ b/tests/cpp/common/test_linalg.cu @@ -2,8 +2,9 @@ * Copyright 2021-2024, XGBoost Contributors */ #include -#include // for equal -#include // for sequence +#include // for equal +#include // for make_constant_iterator +#include // for sequence #include "../../../src/common/cuda_context.cuh" #include "../../../src/common/linalg_op.cuh" @@ -83,6 +84,14 @@ void TestSlice() { } }); } + +void TestWriteAccess(CUDAContext const* cuctx, linalg::TensorView t) { + thrust::for_each(cuctx->CTP(), linalg::tbegin(t), linalg::tend(t), + [=] XGBOOST_DEVICE(double& v) { v = 0; }); + auto eq = thrust::equal(cuctx->CTP(), linalg::tcbegin(t), linalg::tcend(t), + thrust::make_constant_iterator(0.0), thrust::equal_to<>{}); + ASSERT_TRUE(eq); +} } // anonymous namespace TEST(Linalg, GPUElementWise) { TestElementWiseKernel(); } @@ -106,5 +115,7 @@ TEST(Linalg, GPUIter) { bool eq = thrust::equal(cuctx->CTP(), data.cbegin(), data.cend(), linalg::tcbegin(t)); ASSERT_TRUE(eq); + + TestWriteAccess(cuctx, t); } } // namespace xgboost::linalg diff --git a/tests/cpp/data/test_ellpack_page.cu b/tests/cpp/data/test_ellpack_page.cu index f3957a002..0dc4f8e8a 100644 --- a/tests/cpp/data/test_ellpack_page.cu +++ b/tests/cpp/data/test_ellpack_page.cu @@ -1,5 +1,5 @@ /** - * Copyright 2019-2023, XGBoost contributors + * Copyright 2019-2024, XGBoost contributors */ #include @@ -15,7 +15,6 @@ #include "gtest/gtest.h" namespace xgboost { - TEST(EllpackPage, EmptyDMatrix) { constexpr int kNRows = 0, kNCols = 0, kMaxBin = 256; constexpr float kSparsity = 0; @@ -242,7 +241,7 @@ TEST(EllpackPage, Compact) { namespace { class EllpackPageTest : public testing::TestWithParam { protected: - void Run(float sparsity) { + void TestFromGHistIndex(float sparsity) const { // Only testing with small sample size as the cuts might be different between host and // device. 
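// Aside: the NumNonMissing test added below checks that counting ELLPACK entries different from
// the "null" bin reproduces MetaInfo::num_nonzero_. A host-side sketch of that count, assuming an
// uncompressed gidx buffer (the device version uses thrust::count_if over the compressed
// gidx iterator):
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

inline std::size_t CountNonMissing(std::vector<std::uint32_t> const& gidx,
                                   std::uint32_t null_value) {
  return static_cast<std::size_t>(
      std::count_if(gidx.cbegin(), gidx.cend(),
                    [null_value](std::uint32_t g) { return g != null_value; }));
}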
size_t n_samples{128}, n_features{13}; @@ -273,9 +272,25 @@ class EllpackPageTest : public testing::TestWithParam { } } } + + void TestNumNonMissing(float sparsity) const { + size_t n_samples{1024}, n_features{13}; + auto ctx = MakeCUDACtx(0); + auto p_fmat = RandomDataGenerator{n_samples, n_features, sparsity}.GenerateDMatrix(true); + auto nnz = p_fmat->Info().num_nonzero_; + for (auto const& page : p_fmat->GetBatches( + &ctx, BatchParam{17, tree::TrainParam::DftSparseThreshold()})) { + auto ellpack_nnz = + page.Impl()->NumNonMissing(&ctx, p_fmat->Info().feature_types.ConstDeviceSpan()); + ASSERT_EQ(nnz, ellpack_nnz); + } + } }; } // namespace -TEST_P(EllpackPageTest, FromGHistIndex) { this->Run(GetParam()); } +TEST_P(EllpackPageTest, FromGHistIndex) { this->TestFromGHistIndex(GetParam()); } + +TEST_P(EllpackPageTest, NumNonMissing) { this->TestNumNonMissing(this->GetParam()); } + INSTANTIATE_TEST_SUITE_P(EllpackPage, EllpackPageTest, testing::Values(.0f, .2f, .4f, .8f)); } // namespace xgboost diff --git a/tests/cpp/data/test_metainfo.cc b/tests/cpp/data/test_metainfo.cc index 837ca7768..dffa0bfed 100644 --- a/tests/cpp/data/test_metainfo.cc +++ b/tests/cpp/data/test_metainfo.cc @@ -355,4 +355,70 @@ TEST(MetaInfo, HostExtend) { } TEST(MetaInfo, CPUStridedData) { TestMetaInfoStridedData(DeviceOrd::CPU()); } + +namespace { +class TestMetaInfo : public ::testing::TestWithParam> { + public: + void Run(Context const *ctx, bst_target_t n_targets) { + MetaInfo info; + info.num_row_ = 128; + info.num_col_ = 3; + info.feature_names.resize(info.num_col_, "a"); + info.labels.Reshape(info.num_row_, n_targets); + + HostDeviceVector ridx(info.num_row_ / 2, 0); + ridx.SetDevice(ctx->Device()); + auto h_ridx = ridx.HostSpan(); + for (std::size_t i = 0, j = 0; i < ridx.Size(); i++, j += 2) { + h_ridx[i] = j; + } + + { + info.weights_.Resize(info.num_row_); + auto h_w = info.weights_.HostSpan(); + std::iota(h_w.begin(), h_w.end(), 0); + } + + auto out = info.Slice(ctx, ctx->IsCPU() ? h_ridx : ridx.ConstDeviceSpan(), /*nnz=*/256); + + ASSERT_EQ(info.labels.Device(), ctx->Device()); + auto h_y = info.labels.HostView(); + auto h_y_out = out.labels.HostView(); + ASSERT_EQ(h_y_out.Shape(0), ridx.Size()); + ASSERT_EQ(h_y_out.Shape(1), n_targets); + + auto h_w = info.weights_.ConstHostSpan(); + auto h_w_out = out.weights_.ConstHostSpan(); + ASSERT_EQ(h_w_out.size(), ridx.Size()); + + for (std::size_t i = 0; i < ridx.Size(); ++i) { + for (bst_target_t t = 0; t < n_targets; ++t) { + ASSERT_EQ(h_y_out(i, t), h_y(h_ridx[i], t)); + } + ASSERT_EQ(h_w_out[i], h_w[h_ridx[i]]); + } + + for (auto v : info.feature_names) { + ASSERT_EQ(v, "a"); + } + } +}; +} // anonymous namespace + +TEST_P(TestMetaInfo, Slice) { + Context ctx; + auto [n_targets, is_cuda] = this->GetParam(); + if (is_cuda) { + ctx = MakeCUDACtx(0); + } + this->Run(&ctx, n_targets); +} + +INSTANTIATE_TEST_SUITE_P(Cpu, TestMetaInfo, + ::testing::Values(std::tuple{1u, false}, std::tuple{3u, false})); + +#if defined(XGBOOST_USE_CUDA) +INSTANTIATE_TEST_SUITE_P(Gpu, TestMetaInfo, + ::testing::Values(std::tuple{1u, true}, std::tuple{3u, true})); +#endif // defined(XGBOOST_USE_CUDA) } // namespace xgboost diff --git a/tests/cpp/data/test_metainfo.h b/tests/cpp/data/test_metainfo.h index 92cd6cb91..53da10dcc 100644 --- a/tests/cpp/data/test_metainfo.h +++ b/tests/cpp/data/test_metainfo.h @@ -1,5 +1,5 @@ -/*! 
- * Copyright 2021 by XGBoost Contributors +/** + * Copyright 2021-2024, XGBoost Contributors */ #ifndef XGBOOST_TESTS_CPP_DATA_TEST_METAINFO_H_ #define XGBOOST_TESTS_CPP_DATA_TEST_METAINFO_H_ @@ -11,7 +11,6 @@ #include #include "../../../src/common/linalg_op.h" -#include "../../../src/data/array_interface.h" namespace xgboost { inline void TestMetaInfoStridedData(DeviceOrd device) { diff --git a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu index b1e86e2eb..c86489102 100644 --- a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu +++ b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu @@ -39,11 +39,11 @@ void VerifySampling(size_t page_size, float subsample, int sampling_method, if (fixed_size_sampling) { EXPECT_EQ(sample.sample_rows, kRows); - EXPECT_EQ(sample.page->n_rows, kRows); + EXPECT_EQ(sample.p_fmat->Info().num_row_, kRows); EXPECT_EQ(sample.gpair.size(), kRows); } else { EXPECT_NEAR(sample.sample_rows, sample_rows, kRows * 0.03); - EXPECT_NEAR(sample.page->n_rows, sample_rows, kRows * 0.03f); + EXPECT_NEAR(sample.p_fmat->Info().num_row_, sample_rows, kRows * 0.03f); EXPECT_NEAR(sample.gpair.size(), sample_rows, kRows * 0.03f); } @@ -88,25 +88,28 @@ TEST(GradientBasedSampler, NoSamplingExternalMemory) { GradientBasedSampler sampler(&ctx, kRows, param, kSubsample, TrainParam::kUniform, true); auto sample = sampler.Sample(&ctx, gpair.DeviceSpan(), dmat.get()); - auto sampled_page = sample.page; + auto p_fmat = sample.p_fmat; EXPECT_EQ(sample.sample_rows, kRows); EXPECT_EQ(sample.gpair.size(), gpair.Size()); EXPECT_EQ(sample.gpair.data(), gpair.DevicePointer()); - EXPECT_EQ(sampled_page->n_rows, kRows); + EXPECT_EQ(p_fmat->Info().num_row_, kRows); - std::vector h_gidx_buffer; - auto h_accessor = sampled_page->GetHostAccessor(&ctx, &h_gidx_buffer); + ASSERT_EQ(p_fmat->NumBatches(), 1); + for (auto const& sampled_page : p_fmat->GetBatches(&ctx, param)) { + std::vector h_gidx_buffer; + auto h_accessor = sampled_page.Impl()->GetHostAccessor(&ctx, &h_gidx_buffer); - std::size_t offset = 0; - for (auto& batch : dmat->GetBatches(&ctx, param)) { - auto page = batch.Impl(); - std::vector h_page_gidx_buffer; - auto page_accessor = page->GetHostAccessor(&ctx, &h_page_gidx_buffer); - size_t num_elements = page->n_rows * page->row_stride; - for (size_t i = 0; i < num_elements; i++) { - EXPECT_EQ(h_accessor.gidx_iter[i + offset], page_accessor.gidx_iter[i]); + std::size_t offset = 0; + for (auto& batch : dmat->GetBatches(&ctx, param)) { + auto page = batch.Impl(); + std::vector h_page_gidx_buffer; + auto page_accessor = page->GetHostAccessor(&ctx, &h_page_gidx_buffer); + size_t num_elements = page->n_rows * page->row_stride; + for (size_t i = 0; i < num_elements; i++) { + EXPECT_EQ(h_accessor.gidx_iter[i + offset], page_accessor.gidx_iter[i]); + } + offset += num_elements; } - offset += num_elements; } }
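// Aside: a host-side sketch of the row-gather that data.cu now uses for matrix-shaped meta info
// (labels, base_margin_) when slicing: output element (i, c) reads input element (ridx[i], c),
// i.e. map[i * n_cols + c] = ridx[i] * n_cols + c. The device version builds this map with a
// transform iterator and thrust::gather; names here are illustrative only.
#include <cstddef>
#include <vector>

inline std::vector<float> GatherRows(std::vector<float> const& in, std::size_t n_cols,
                                     std::vector<std::size_t> const& ridx) {
  std::vector<float> out(ridx.size() * n_cols);
  for (std::size_t i = 0; i < ridx.size(); ++i) {
    for (std::size_t c = 0; c < n_cols; ++c) {
      out[i * n_cols + c] = in[ridx[i] * n_cols + c];
    }
  }
  return out;
}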