From 4fe67f10b403712339dc6b275a492077fa8eced8 Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Thu, 29 Aug 2024 01:35:17 +0800
Subject: [PATCH] [EM] Have one partitioner for each batch. (#10760)

- Initialize one partitioner for each batch.
- Collect partition size during initialization.
- Support base ridx in the finalization.
---
 src/common/device_helpers.cuh                |   5 -
 src/common/threading_utils.cc                |   4 +-
 src/tree/gpu_hist/gradient_based_sampler.cu  |  14 +-
 src/tree/gpu_hist/gradient_based_sampler.cuh |   6 +-
 src/tree/gpu_hist/row_partitioner.cu         |   2 +-
 src/tree/gpu_hist/row_partitioner.cuh        |  18 +-
 src/tree/updater_gpu_hist.cu                 | 323 ++++++++++--------
 .../gpu_hist/test_gradient_based_sampler.cu  |   4 +-
 tests/python-gpu/test_gpu_data_iterator.py   |   2 +-
 tests/python/test_data_iterator.py           |  14 +-
 10 files changed, 211 insertions(+), 181 deletions(-)

diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh
index c8b13ffe1..7d35beb72 100644
--- a/src/common/device_helpers.cuh
+++ b/src/common/device_helpers.cuh
@@ -387,11 +387,6 @@ void CopyTo(Src const &src, Dst *dst) {
                             src.size() * sizeof(SVT), cudaMemcpyDefault));
 }
 
-template <typename HContainer, typename DContainer>
-void CopyToD(HContainer const &h, DContainer *d) {
-  CopyTo(h, d);
-}
-
 // Keep track of pinned memory allocation
 struct PinnedMemory {
   void *temp_storage{nullptr};
diff --git a/src/common/threading_utils.cc b/src/common/threading_utils.cc
index 625c98d1c..f7296b7f9 100644
--- a/src/common/threading_utils.cc
+++ b/src/common/threading_utils.cc
@@ -124,7 +124,7 @@ void NameThread(std::thread* t, StringView name) {
   char old[16];
   auto ret = pthread_getname_np(handle, old, 16);
   if (ret != 0) {
-    LOG(WARNING) << "Failed to get the name from thread";
+    LOG(DEBUG) << "Failed to get the name from thread";
   }
   auto new_name = std::string{old} + ">" + name.c_str();  // NOLINT
   if (new_name.size() > 15) {
@@ -132,7 +132,7 @@ void NameThread(std::thread* t, StringView name) {
   }
   ret = pthread_setname_np(handle, new_name.c_str());
   if (ret != 0) {
-    LOG(WARNING) << "Failed to name thread:" << ret << " :" << new_name;
+    LOG(DEBUG) << "Failed to name thread:" << ret << " :" << new_name;
   }
 #else
   (void)name;
diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu
index 44980ac06..50a00149b 100644
--- a/src/tree/gpu_hist/gradient_based_sampler.cu
+++ b/src/tree/gpu_hist/gradient_based_sampler.cu
@@ -152,7 +152,7 @@ NoSampling::NoSampling(BatchParam batch_param) : batch_param_(std::move(batch_pa
 GradientBasedSample NoSampling::Sample(Context const*, common::Span<GradientPair> gpair,
                                        DMatrix* dmat) {
-  return {dmat->Info().num_row_, dmat, gpair};
+  return {dmat, gpair};
 }
 
 ExternalMemoryNoSampling::ExternalMemoryNoSampling(BatchParam batch_param)
@@ -179,7 +179,7 @@ GradientBasedSample ExternalMemoryNoSampling::Sample(Context const* ctx,
     this->p_fmat_new_ =
         std::make_unique<data::IterativeDMatrix>(new_page, p_fmat->Info(), batch_param_);
   }
-  return {p_fmat->Info().num_row_, this->p_fmat_new_.get(), gpair};
+  return {this->p_fmat_new_.get(), gpair};
 }
 
 UniformSampling::UniformSampling(BatchParam batch_param, float subsample)
@@ -192,7 +192,7 @@ GradientBasedSample UniformSampling::Sample(Context const* ctx, common::Span<GradientPair> gpair,
   thrust::replace_if(ctx->CUDACtx()->CTP(), dh::tbegin(gpair), dh::tend(gpair),
                      thrust::counting_iterator<std::size_t>(0),
                      BernoulliTrial(common::GlobalRandom()(), subsample_), GradientPair());
-  return {p_fmat->Info().num_row_, p_fmat, gpair};
+  return {p_fmat, gpair};
 }
 
 ExternalMemoryUniformSampling::ExternalMemoryUniformSampling(size_t n_rows,
@@ -252,7 +252,8 @@ GradientBasedSample ExternalMemoryUniformSampling::Sample(Context const* ctx,
   // Create the new DMatrix
   this->p_fmat_new_ = std::make_unique<data::IterativeDMatrix>(
       new_page, dmat->Info().Slice(ctx, dh::ToSpan(compact_row_index_), nnz), batch_param_);
-  return {sample_rows, this->p_fmat_new_.get(), dh::ToSpan(gpair_)};
+  CHECK_EQ(sample_rows, this->p_fmat_new_->Info().num_row_);
+  return {this->p_fmat_new_.get(), dh::ToSpan(gpair_)};
 }
 
 GradientBasedSampling::GradientBasedSampling(std::size_t n_rows, BatchParam batch_param,
@@ -274,7 +275,7 @@ GradientBasedSample GradientBasedSampling::Sample(Context const* ctx,
                     thrust::counting_iterator<std::size_t>(0), dh::tbegin(gpair),
                     PoissonSampling(dh::ToSpan(threshold_), threshold_index,
                                     RandomWeight(common::GlobalRandom()())));
-  return {n_rows, dmat, gpair};
+  return {dmat, gpair};
 }
 
 ExternalMemoryGradientBasedSampling::ExternalMemoryGradientBasedSampling(size_t n_rows,
@@ -334,7 +335,8 @@ GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(Context const* c
   // Create the new DMatrix
   this->p_fmat_new_ = std::make_unique<data::IterativeDMatrix>(
       new_page, dmat->Info().Slice(ctx, dh::ToSpan(compact_row_index_), nnz), batch_param_);
-  return {sample_rows, this->p_fmat_new_.get(), dh::ToSpan(gpair_)};
+  CHECK_EQ(sample_rows, this->p_fmat_new_->Info().num_row_);
+  return {this->p_fmat_new_.get(), dh::ToSpan(gpair_)};
 }
 
 GradientBasedSampler::GradientBasedSampler(Context const* /*ctx*/, size_t n_rows,
diff --git a/src/tree/gpu_hist/gradient_based_sampler.cuh b/src/tree/gpu_hist/gradient_based_sampler.cuh
index 22de2c1fb..d7e24dafc 100644
--- a/src/tree/gpu_hist/gradient_based_sampler.cuh
+++ b/src/tree/gpu_hist/gradient_based_sampler.cuh
@@ -12,11 +12,9 @@
 namespace xgboost::tree {
 
 struct GradientBasedSample {
-  /*!\brief Number of sampled rows. */
-  bst_idx_t sample_rows;
-  /*!\brief Sampled rows in ELLPACK format. */
+  /** @brief Sampled rows in ELLPACK format. */
   DMatrix* p_fmat;
-  /*!\brief Gradient pairs for the sampled rows. */
+  /** @brief Gradient pairs for the sampled rows. */
   common::Span<GradientPair> gpair;
 };
 
diff --git a/src/tree/gpu_hist/row_partitioner.cu b/src/tree/gpu_hist/row_partitioner.cu
index c768c89df..bec500078 100644
--- a/src/tree/gpu_hist/row_partitioner.cu
+++ b/src/tree/gpu_hist/row_partitioner.cu
@@ -31,7 +31,7 @@ common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows(bst_node_t
   return dh::ToSpan(ridx_).subspan(segment.begin, segment.Size());
 }
 
-common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows() {
+common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows() const {
   return dh::ToSpan(ridx_);
 }
 
diff --git a/src/tree/gpu_hist/row_partitioner.cuh b/src/tree/gpu_hist/row_partitioner.cuh
index c754f84c0..3c8dec58e 100644
--- a/src/tree/gpu_hist/row_partitioner.cuh
+++ b/src/tree/gpu_hist/row_partitioner.cuh
@@ -200,11 +200,11 @@ XGBOOST_DEV_INLINE int GetPositionFromSegments(std::size_t idx,
 template <int kBlockSize, typename RowIndexT, typename OpT>
 __global__ __launch_bounds__(kBlockSize) void FinalisePositionKernel(
-    const common::Span<const NodePositionInfo> d_node_info,
+    const common::Span<const NodePositionInfo> d_node_info, bst_idx_t base_ridx,
     const common::Span<const RowIndexT> d_ridx, common::Span<bst_node_t> d_out_position, OpT op) {
   for (auto idx : dh::GridStrideRange(0, d_ridx.size())) {
     auto position = GetPositionFromSegments(idx, d_node_info.data());
-    RowIndexT ridx = d_ridx[idx];
+    RowIndexT ridx = d_ridx[idx] - base_ridx;
     bst_node_t new_position = op(ridx, position);
     d_out_position[ridx] = new_position;
   }
 }
@@ -264,7 +264,12 @@ class RowPartitioner {
   /**
    * \brief Gets all training rows in the set.
    */
-  common::Span<const RowIndexT> GetRows();
+  common::Span<const RowIndexT> GetRows() const;
+  /**
+   * @brief Get the number of rows in this partitioner.
+   */
+  std::size_t Size() const { return this->GetRows().size(); }
+
   [[nodiscard]] bst_node_t GetNumNodes() const { return n_nodes_; }
 
   /**
@@ -351,7 +356,8 @@ class RowPartitioner {
    * argument and return the new position for this training instance.
    */
   template <typename FinalisePositionOpT>
-  void FinalisePosition(common::Span<bst_node_t> d_out_position, FinalisePositionOpT op) const {
+  void FinalisePosition(common::Span<bst_node_t> d_out_position, bst_idx_t base_ridx,
+                        FinalisePositionOpT op) const {
     dh::TemporaryArray<NodePositionInfo> d_node_info_storage(ridx_segments_.size());
     dh::safe_cuda(cudaMemcpyAsync(d_node_info_storage.data().get(), ridx_segments_.data(),
                                   sizeof(NodePositionInfo) * ridx_segments_.size(),
@@ -361,8 +367,8 @@ class RowPartitioner {
     const int kItemsThread = 8;
     const int grid_size = xgboost::common::DivRoundUp(ridx_.size(), kBlockSize * kItemsThread);
     common::Span<RowIndexT const> d_ridx{ridx_.data(), ridx_.size()};
-    FinalisePositionKernel<kBlockSize>
-        <<<grid_size, kBlockSize>>>(dh::ToSpan(d_node_info_storage), d_ridx, d_out_position, op);
+    FinalisePositionKernel<kBlockSize><<<grid_size, kBlockSize>>>(
+        dh::ToSpan(d_node_info_storage), base_ridx, d_ridx, d_out_position, op);
   }
 };
 };  // namespace xgboost::tree
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index 8ca6ef71c..03b0e5a42 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -1,26 +1,25 @@
 /**
  * Copyright 2017-2024, XGBoost contributors
  */
-#include
-#include
-#include
+#include <thrust/functional.h>  // for plus
+#include <thrust/transform.h>   // for transform
 
-#include
-#include
-#include <cstddef>  // for size_t
-#include <memory>   // for unique_ptr, make_unique
-#include <utility>  // for move
-#include
+#include <algorithm>  // for max
+#include <cmath>      // for isnan
+#include <cstddef>    // for size_t
+#include <memory>     // for unique_ptr, make_unique
+#include <utility>    // for move
+#include <vector>     // for vector
 
 #include "../collective/aggregator.h"
-#include "../collective/broadcast.h"
-#include "../common/bitfield.h"
-#include "../common/categorical.h"
+#include "../collective/broadcast.h"   // for Broadcast
+#include "../common/categorical.h"     // for KCatBitField
 #include "../common/cuda_context.cuh"  // for CUDAContext
 #include "../common/cuda_rt_utils.h"   // for CheckComputeCapability
 #include "../common/device_helpers.cuh"
-#include "../common/hist_util.h"
-#include "../common/random.h"  // for ColumnSampler, GlobalRandom
+#include "../common/device_vector.cuh"  // for device_vector
+#include "../common/hist_util.h"        // for HistogramCuts
+#include "../common/random.h"           // for ColumnSampler, GlobalRandom
 #include "../common/timer.h"
 #include "../data/ellpack_page.cuh"
 #include "../data/ellpack_page.h"
@@ -31,20 +30,20 @@
 #include "gpu_hist/feature_groups.cuh"
 #include "gpu_hist/gradient_based_sampler.cuh"
 #include "gpu_hist/histogram.cuh"
-#include "gpu_hist/row_partitioner.cuh"
-#include "hist/param.h"
-#include "param.h"
-#include "sample_position.h"       // for SamplePosition
-#include "updater_gpu_common.cuh"  // for HistBatch
-#include "xgboost/base.h"
-#include "xgboost/context.h"
-#include "xgboost/data.h"
-#include "xgboost/host_device_vector.h"
-#include "xgboost/json.h"
-#include "xgboost/span.h"
-#include "xgboost/task.h"  // for ObjInfo
-#include "xgboost/tree_model.h"
-#include "xgboost/tree_updater.h"
+#include "gpu_hist/row_partitioner.cuh"  // for RowPartitioner
+#include "hist/param.h"                  // for HistMakerTrainParam
+#include "param.h"                       // for TrainParam
+#include "sample_position.h"             // for SamplePosition
+#include "updater_gpu_common.cuh"        // for HistBatch
+#include "xgboost/base.h"                // for bst_idx_t
+#include "xgboost/context.h"             // for Context
+#include "xgboost/data.h"                // for DMatrix
+#include "xgboost/host_device_vector.h"  // for HostDeviceVector
+#include "xgboost/json.h"                // for Json
+#include "xgboost/span.h"                // for Span
+#include "xgboost/task.h"                // for ObjInfo
+#include "xgboost/tree_model.h"          // for RegTree
+#include "xgboost/tree_updater.h"        // for TreeUpdater
 
 namespace xgboost::tree {
 DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);
@@ -57,32 +56,31 @@ using cuda_impl::HistBatch;
 // parameter to avoid any regen.
 using cuda_impl::StaticBatch;
 
+// Extra data for each node that is passed to the update position function
+struct NodeSplitData {
+  RegTree::Node split_node;
+  FeatureType split_type;
+  common::KCatBitField node_cats;
+};
+static_assert(std::is_trivially_copyable_v<NodeSplitData>);
+
 // GPU tree updater implementation.
 struct GPUHistMakerDevice {
  private:
  GPUHistEvaluator evaluator_;
  Context const* ctx_;
  std::shared_ptr<common::ColumnSampler> column_sampler_;
- MetaInfo const& info_;
+ // Set of row partitioners, one for each batch (external memory). When the training is
+ // in-core, there's only one partitioner.
+ std::vector<std::unique_ptr<RowPartitioner>> partitioners_;
 
  DeviceHistogramBuilder histogram_;
+ std::vector<bst_idx_t> batch_ptr_;
  // node idx for each sample
  dh::device_vector<bst_node_t> positions_;
- std::unique_ptr<RowPartitioner> row_partitioner_;
 
  std::shared_ptr<common::HistogramCuts> cuts_{nullptr};
 
 public:
- // Extra data for each node that is passed to the update position function
- struct NodeSplitData {
-   RegTree::Node split_node;
-   FeatureType split_type;
-   common::KCatBitField node_cats;
- };
- static_assert(std::is_trivially_copyable_v<NodeSplitData>);
-
- public:
- common::Span<FeatureType const> feature_types;
-
 DeviceHistogramStorage<> hist{};
 
 dh::device_vector<GradientPair> d_gpair;  // storage for gpair;
@@ -104,21 +102,20 @@ struct GPUHistMakerDevice {
 
   std::unique_ptr<FeatureGroups> feature_groups;
 
-  GPUHistMakerDevice(Context const* ctx, std::shared_ptr<common::HistogramCuts> cuts,
-                     bool is_external_memory, common::Span<FeatureType const> _feature_types,
-                     TrainParam _param, std::shared_ptr<common::ColumnSampler> column_sampler,
-                     BatchParam batch_param, MetaInfo const& info)
+  GPUHistMakerDevice(Context const* ctx, TrainParam _param,
+                     std::shared_ptr<common::ColumnSampler> column_sampler, BatchParam batch_param,
+                     MetaInfo const& info, std::vector<bst_idx_t> batch_ptr,
+                     std::shared_ptr<common::HistogramCuts> cuts)
       : evaluator_{_param, static_cast<bst_feature_t>(info.num_col_), ctx->Device()},
         ctx_(ctx),
-        feature_types{_feature_types},
         param(std::move(_param)),
         column_sampler_(std::move(column_sampler)),
-        interaction_constraints(param, info.num_col_),
-        info_{info},
+        interaction_constraints(param, static_cast<bst_feature_t>(info.num_col_)),
+        batch_ptr_{std::move(batch_ptr)},
        cuts_{std::move(cuts)} {
     sampler = std::make_unique<GradientBasedSampler>(ctx, info.num_row_, batch_param, param.subsample,
-                                                     param.sampling_method, is_external_memory);
+                                                     param.sampling_method, batch_ptr_.size() > 2);
     if (!param.monotone_constraints.empty()) {
       // Copy assigning an empty vector causes an exception in MSVC debug builds
       monotone_constraints = param.monotone_constraints;
@@ -149,27 +146,45 @@ struct GPUHistMakerDevice {
 
     this->interaction_constraints.Reset();
 
-    if (d_gpair.size() != dh_gpair->Size()) {
-      d_gpair.resize(dh_gpair->Size());
-    }
-    dh::safe_cuda(cudaMemcpyAsync(d_gpair.data().get(), dh_gpair->ConstDevicePointer(),
-                                  dh_gpair->Size() * sizeof(GradientPair),
-                                  cudaMemcpyDeviceToDevice));
-    auto sample = sampler->Sample(ctx_, dh::ToSpan(d_gpair), p_fmat);
+    // Sampling
+    dh::CopyTo(dh_gpair->ConstDeviceSpan(), &this->d_gpair);  // backup the gradient
+    auto sample = this->sampler->Sample(ctx_, dh::ToSpan(d_gpair), p_fmat);
     this->gpair = sample.gpair;
-    p_fmat = sample.p_fmat;
-    CHECK(p_fmat->SingleColBlock());
+    p_fmat = sample.p_fmat;  // Update p_fmat before allocating partitioners
+    p_fmat->Info().feature_types.SetDevice(ctx_->Device());
+
+    std::size_t n_batches = p_fmat->NumBatches();
+    bool is_concat = (n_batches + 1) != this->batch_ptr_.size();
+    std::vector<bst_idx_t> batch_ptr{batch_ptr_};
+    if (is_concat) {
+      // Concatenate the batch ptrs as well.
+      batch_ptr = {static_cast<bst_idx_t>(0), p_fmat->Info().num_row_};
+    }
+    // Initialize partitions
+    if (!partitioners_.empty()) {
+      CHECK_EQ(partitioners_.size(), n_batches);
+    }
+    for (std::size_t k = 0; k < n_batches; ++k) {
+      if (partitioners_.size() != n_batches) {
+        // First run.
+        partitioners_.emplace_back(std::make_unique<RowPartitioner>());
+      }
+      auto base_ridx = batch_ptr[k];
+      auto n_samples = batch_ptr.at(k + 1) - base_ridx;
+      partitioners_[k]->Reset(ctx_, n_samples, base_ridx);
+    }
+    CHECK_EQ(partitioners_.size(), n_batches);
+    if (is_concat) {
+      CHECK_EQ(partitioners_.size(), 1);
+      CHECK_EQ(partitioners_.front()->Size(), p_fmat->Info().num_row_);
+    }
 
-    this->evaluator_.Reset(*cuts_, feature_types, p_fmat->Info().num_col_, param,
-                           p_fmat->Info().IsColumnSplit(), ctx_->Device());
+    // Other initializations
+    this->evaluator_.Reset(*cuts_, p_fmat->Info().feature_types.ConstDeviceSpan(),
+                           p_fmat->Info().num_col_, this->param, p_fmat->Info().IsColumnSplit(),
+                           this->ctx_->Device());
 
     quantiser = std::make_unique<GradientQuantiser>(ctx_, this->gpair, p_fmat->Info());
 
-    if (!row_partitioner_) {
-      row_partitioner_ = std::make_unique<RowPartitioner>();
-    }
-    row_partitioner_->Reset(ctx_, sample.sample_rows, 0);
-
     // Init histogram
     hist.Init(ctx_->Device(), this->cuts_->TotalBins());
     hist.Reset(ctx_);
@@ -180,23 +195,21 @@ struct GPUHistMakerDevice {
     return p_fmat;
   }
 
-  GPUExpandEntry EvaluateRootSplit(DMatrix const * p_fmat, GradientPairInt64 root_sum) {
-    int nidx = RegTree::kRoot;
+  GPUExpandEntry EvaluateRootSplit(DMatrix const* p_fmat, GradientPairInt64 root_sum) {
+    bst_node_t nidx = RegTree::kRoot;
     GPUTrainingParam gpu_param(param);
     auto sampled_features = column_sampler_->GetFeatureSet(0);
     sampled_features->SetDevice(ctx_->Device());
     common::Span<bst_feature_t> feature_set =
         interaction_constraints.Query(sampled_features->DeviceSpan(), nidx);
     EvaluateSplitInputs inputs{nidx, 0, root_sum, feature_set, hist.GetNodeHistogram(nidx)};
-    EvaluateSplitSharedInputs shared_inputs{
-        gpu_param,
-        *quantiser,
-        feature_types,
-        cuts_->cut_ptrs_.ConstDeviceSpan(),
-        cuts_->cut_values_.ConstDeviceSpan(),
-        cuts_->min_vals_.ConstDeviceSpan(),
-        p_fmat->IsDense() && !collective::IsDistributed()
-    };
+    EvaluateSplitSharedInputs shared_inputs{gpu_param,
+                                            *quantiser,
+                                            p_fmat->Info().feature_types.ConstDeviceSpan(),
+                                            cuts_->cut_ptrs_.ConstDeviceSpan(),
+                                            cuts_->cut_values_.ConstDeviceSpan(),
+                                            cuts_->min_vals_.ConstDeviceSpan(),
+                                            p_fmat->IsDense() && !collective::IsDistributed()};
     auto split = this->evaluator_.EvaluateSingleSplit(ctx_, inputs, shared_inputs);
     return split;
   }
@@ -212,8 +225,9 @@ struct GPUHistMakerDevice {
     std::vector<bst_node_t> nidx(2 * candidates.size());
     auto h_node_inputs = pinned2.GetSpan<EvaluateSplitInputs>(2 * candidates.size());
     EvaluateSplitSharedInputs shared_inputs{
-        GPUTrainingParam{param}, *quantiser, feature_types, cuts_->cut_ptrs_.ConstDeviceSpan(),
-        cuts_->cut_values_.ConstDeviceSpan(), cuts_->min_vals_.ConstDeviceSpan(),
+        GPUTrainingParam{param}, *quantiser, p_fmat->Info().feature_types.ConstDeviceSpan(),
+        cuts_->cut_ptrs_.ConstDeviceSpan(), cuts_->cut_values_.ConstDeviceSpan(),
+        cuts_->min_vals_.ConstDeviceSpan(),
         // is_dense represents the local data
         p_fmat->IsDense() && !collective::IsDistributed()};
     dh::TemporaryArray<GPUExpandEntry> entries(2 * candidates.size());
@@ -262,7 +276,7 @@ struct GPUHistMakerDevice {
 
   void BuildHist(EllpackPageImpl const* page, int nidx) {
     auto d_node_hist = hist.GetNodeHistogram(nidx);
-    auto d_ridx = row_partitioner_->GetRows(nidx);
+    auto d_ridx = partitioners_.front()->GetRows(nidx);
     this->histogram_.BuildHistogram(ctx_->CUDACtx(), page->GetDeviceAccessor(ctx_->Device()),
                                     feature_groups->DeviceAccessor(ctx_->Device()), gpair, d_ridx,
                                     d_node_hist, *quantiser);
@@ -335,7 +349,7 @@ struct GPUHistMakerDevice {
     };
     collective::SafeColl(rc);
 
-    row_partitioner_->UpdatePositionBatch(
+    partitioners_.front()->UpdatePositionBatch(
         nidx, left_nidx, right_nidx, split_data,
         [=] __device__(bst_uint ridx, int nidx_in_batch, NodeSplitData const& data) {
           auto const index = ridx * num_candidates + nidx_in_batch;
@@ -396,16 +410,17 @@ struct GPUHistMakerDevice {
       CHECK_EQ(split_type == FeatureType::kCategorical, e.split.is_cat);
     }
 
+    CHECK_EQ(p_fmat->NumBatches(), 1);
     for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
       auto d_matrix = page.Impl()->GetDeviceAccessor(ctx_->Device());
-      if (info_.IsColumnSplit()) {
+      if (p_fmat->Info().IsColumnSplit()) {
         UpdatePositionColumnSplit(d_matrix, split_data, nidx, left_nidx, right_nidx);
         monitor.Stop(__func__);
         return;
       }
       auto go_left = GoLeftOp{d_matrix};
-      row_partitioner_->UpdatePositionBatch(
+      partitioners_.front()->UpdatePositionBatch(
           nidx, left_nidx, right_nidx, split_data,
           [=] __device__(cuda_impl::RowIndexT ridx, int /*nidx_in_batch*/,
                          const NodeSplitData& data) { return go_left(ridx, data); });
@@ -423,25 +438,30 @@ struct GPUHistMakerDevice {
       LOG(FATAL) << "Current objective function can not be used with external memory.";
     }
     if (p_fmat->Info().num_row_ != n_samples) {
-      // Subsampling with external memory. Not supported.
+      // External memory with concatenation. Not supported.
       p_out_position->Resize(0);
      positions_.clear();
       return;
     }
 
     p_out_position->SetDevice(ctx_->Device());
-    p_out_position->Resize(row_partitioner_->GetRows().size());
+    p_out_position->Resize(p_fmat->Info().num_row_);
 
     auto d_out_position = p_out_position->DeviceSpan();
     auto d_gpair = this->gpair;
-    auto encode_op = [=] __device__(bst_idx_t row_id, bst_node_t nidx) {
-      bool is_invalid = d_gpair[row_id].GetHess() - .0f == 0.f;
+    auto encode_op = [=] __device__(bst_idx_t ridx, bst_node_t nidx) {
+      bool is_invalid = d_gpair[ridx].GetHess() - .0f == 0.f;
       return SamplePosition::Encode(nidx, !is_invalid);
     };  // NOLINT
 
     if (!p_fmat->SingleColBlock()) {
-      CHECK_EQ(row_partitioner_->GetNumNodes(), p_tree->NumNodes());
-      row_partitioner_->FinalisePosition(d_out_position, encode_op);
+      for (std::size_t k = 0; k < partitioners_.size(); ++k) {
+        auto& part = partitioners_.at(k);
+        CHECK_EQ(part->GetNumNodes(), p_tree->NumNodes());
+        auto base_ridx = batch_ptr_[k];
+        auto n_samples = batch_ptr_.at(k + 1) - base_ridx;
+        part->FinalisePosition(d_out_position.subspan(base_ridx, n_samples), base_ridx, encode_op);
+      }
       dh::CopyTo(d_out_position, &positions_);
       return;
     }
@@ -465,23 +485,22 @@ struct GPUHistMakerDevice {
     auto go_left_op = GoLeftOp{d_matrix};
     dh::caching_device_vector<NodeSplitData> d_split_data;
-    dh::CopyToD(split_data, &d_split_data);
+    dh::CopyTo(split_data, &d_split_data);
     auto s_split_data = dh::ToSpan(d_split_data);
 
-    row_partitioner_->FinalisePosition(d_out_position,
-                                       [=] __device__(bst_idx_t row_id, bst_node_t nidx) {
-                                         auto split_data = s_split_data[nidx];
-                                         auto node = split_data.split_node;
-                                         while (!node.IsLeaf()) {
-                                           auto go_left = go_left_op(row_id, split_data);
-                                           nidx = go_left ? node.LeftChild() : node.RightChild();
-                                           node = s_split_data[nidx].split_node;
-                                         }
-                                         return encode_op(row_id, nidx);
-                                       });
+    partitioners_.front()->FinalisePosition(
+        d_out_position, page.BaseRowId(), [=] __device__(bst_idx_t row_id, bst_node_t nidx) {
+          auto split_data = s_split_data[nidx];
+          auto node = split_data.split_node;
+          while (!node.IsLeaf()) {
+            auto go_left = go_left_op(row_id, split_data);
+            nidx = go_left ? node.LeftChild() : node.RightChild();
+            node = s_split_data[nidx].split_node;
+          }
+          return encode_op(row_id, nidx);
+        });
+
+      dh::CopyTo(d_out_position, &positions_);
     }
-
-    dh::CopyTo(d_out_position, &positions_);
   }
 
   bool UpdatePredictionCache(linalg::MatrixView<float> out_preds_d, RegTree const* p_tree) {
@@ -513,17 +532,16 @@ struct GPUHistMakerDevice {
   }
 
   // num histograms is the number of contiguous histograms in memory to reduce over
-  void AllReduceHist(bst_node_t nidx, int num_histograms) {
-    monitor.Start("AllReduce");
-    auto d_node_hist = hist.GetNodeHistogram(nidx).data();
-    using ReduceT = typename std::remove_pointer<decltype(d_node_hist)>::type::ValueT;
+  void AllReduceHist(MetaInfo const& info, bst_node_t nidx, int num_histograms) {
+    monitor.Start(__func__);
+    auto d_node_hist = hist.GetNodeHistogram(nidx);
+    using ReduceT = typename std::remove_pointer<decltype(d_node_hist.data())>::type::ValueT;
     auto rc = collective::GlobalSum(
-        ctx_, info_,
-        linalg::MakeVec(reinterpret_cast<ReduceT*>(d_node_hist),
-                        cuts_->TotalBins() * 2 * num_histograms, ctx_->Device()));
+        ctx_, info,
+        linalg::MakeVec(reinterpret_cast<ReduceT*>(d_node_hist.data()),
+                        d_node_hist.size() * 2 * num_histograms, ctx_->Device()));
     SafeColl(rc);
-
-    monitor.Stop("AllReduce");
+    monitor.Stop(__func__);
   }
 
   /**
@@ -566,7 +584,7 @@ struct GPUHistMakerDevice {
       // Reduce all in one go
       // This gives much better latency in a distributed setting
      // when processing a large batch
-      this->AllReduceHist(hist_nidx.at(0), hist_nidx.size());
+      this->AllReduceHist(p_fmat->Info(), hist_nidx.at(0), hist_nidx.size());
 
       for (size_t i = 0; i < subtraction_nidx.size(); i++) {
         auto build_hist_nidx = hist_nidx.at(i);
@@ -578,7 +596,7 @@ struct GPUHistMakerDevice {
         for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
           this->BuildHist(page.Impl(), subtraction_trick_nidx);
         }
-        this->AllReduceHist(subtraction_trick_nidx, 1);
+        this->AllReduceHist(p_fmat->Info(), subtraction_trick_nidx, 1);
       }
     }
     this->monitor.Stop(__func__);
   }
@@ -588,18 +606,16 @@ struct GPUHistMakerDevice {
     RegTree& tree = *p_tree;
 
     // Sanity check - have we created a leaf with no training instances?
-    if (!collective::IsDistributed() && row_partitioner_) {
-      CHECK(row_partitioner_->GetRows(candidate.nid).size() > 0)
+    if (!collective::IsDistributed() && partitioners_.size() == 1) {
+      CHECK(partitioners_.front()->GetRows(candidate.nid).size() > 0)
           << "No training instances in this leaf!";
     }
 
     auto base_weight = candidate.base_weight;
     auto left_weight = candidate.left_weight * param.learning_rate;
     auto right_weight = candidate.right_weight * param.learning_rate;
-    auto parent_hess = quantiser
-                           ->ToFloatingPoint(candidate.split.left_sum +
-                                             candidate.split.right_sum)
-                           .GetHess();
+    auto parent_hess =
+        quantiser->ToFloatingPoint(candidate.split.left_sum + candidate.split.right_sum).GetHess();
     auto left_hess = quantiser->ToFloatingPoint(candidate.split.left_sum).GetHess();
     auto right_hess = quantiser->ToFloatingPoint(candidate.split.right_sum).GetHess();
@@ -640,22 +656,21 @@ struct GPUHistMakerDevice {
     dh::XGBCachingDeviceAllocator<char> alloc;
     auto quantiser = *this->quantiser;
     auto gpair_it = dh::MakeTransformIterator<GradientPairInt64>(
-        dh::tbegin(gpair), [=] __device__(auto const &gpair) {
-          return quantiser.ToFixedPoint(gpair);
-        });
+        dh::tbegin(gpair),
+        [=] __device__(auto const& gpair) { return quantiser.ToFixedPoint(gpair); });
     GradientPairInt64 root_sum_quantised =
-        dh::Reduce(ctx_->CUDACtx()->CTP(), gpair_it, gpair_it + gpair.size(),
-                   GradientPairInt64{}, thrust::plus{});
+        dh::Reduce(ctx_->CUDACtx()->CTP(), gpair_it, gpair_it + gpair.size(), GradientPairInt64{},
+                   thrust::plus{});
     using ReduceT = typename decltype(root_sum_quantised)::ValueT;
     auto rc = collective::GlobalSum(
-        ctx_, info_, linalg::MakeVec(reinterpret_cast<ReduceT*>(&root_sum_quantised), 2));
+        ctx_, p_fmat->Info(), linalg::MakeVec(reinterpret_cast<ReduceT*>(&root_sum_quantised), 2));
     collective::SafeColl(rc);
 
     hist.AllocateHistograms(ctx_, {kRootNIdx});
     for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
       this->BuildHist(page.Impl(), kRootNIdx);
     }
-    this->AllReduceHist(kRootNIdx, 1);
+    this->AllReduceHist(p_fmat->Info(), kRootNIdx, 1);
 
     // Remember root stats
     auto root_sum = quantiser.ToFloatingPoint(root_sum_quantised);
@@ -719,12 +734,30 @@ struct GPUHistMakerDevice {
     // restrictions like min loss change after evaluation. Therefore, the check condition
     // is greater than or equal to.
     if (is_single_block) {
-      CHECK_GE(p_tree->NumNodes(), this->row_partitioner_->GetNumNodes());
+      CHECK_GE(p_tree->NumNodes(), this->partitioners_.front()->GetNumNodes());
     }
     this->FinalisePosition(p_fmat, p_tree, *task, n_samples, p_out_position);
   }
 };
 
+std::shared_ptr<common::HistogramCuts> InitBatchCuts(Context const* ctx, DMatrix* p_fmat,
+                                                     BatchParam batch,
+                                                     std::vector<bst_idx_t>* p_batch_ptr) {
+  std::vector<bst_idx_t>& batch_ptr = *p_batch_ptr;
+  batch_ptr = {0};
+  std::shared_ptr<common::HistogramCuts> cuts;
+
+  for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx, batch)) {
+    batch_ptr.push_back(page.Size());
+    cuts = page.Impl()->CutsShared();
+    CHECK(cuts->cut_values_.DeviceCanRead());
+  }
+  CHECK(cuts);
+  CHECK_EQ(p_fmat->NumBatches(), batch_ptr.size() - 1);
+  std::partial_sum(batch_ptr.cbegin(), batch_ptr.cend(), batch_ptr.begin());
+  return cuts;
+}
+
 class GPUHistMaker : public TreeUpdater {
   using GradientSumT = GradientPairPrecise;
 
@@ -774,23 +807,20 @@ class GPUHistMaker : public TreeUpdater {
     CHECK_GE(ctx_->Ordinal(), 0) << "Must have at least one device";
 
     // Synchronise the column sampling seed
-    uint32_t column_sampling_seed = common::GlobalRandom()();
-    auto rc = collective::Broadcast(
-        ctx_, linalg::MakeVec(&column_sampling_seed, sizeof(column_sampling_seed)), 0);
-    SafeColl(rc);
+    std::uint32_t column_sampling_seed = common::GlobalRandom()();
+    SafeColl(collective::Broadcast(
+        ctx_, linalg::MakeVec(&column_sampling_seed, sizeof(column_sampling_seed)), 0));
     this->column_sampler_ = std::make_shared<common::ColumnSampler>(column_sampling_seed);
 
-    std::shared_ptr<common::HistogramCuts> cuts;
-    auto batch = HistBatch(*param);
-    for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, HistBatch(*param))) {
-      cuts = page.Impl()->CutsShared();
-    }
-
     dh::safe_cuda(cudaSetDevice(ctx_->Ordinal()));
     p_fmat->Info().feature_types.SetDevice(ctx_->Device());
-    maker = std::make_unique<GPUHistMakerDevice>(ctx_, cuts, !p_fmat->SingleColBlock(),
-                                                 p_fmat->Info().feature_types.ConstDeviceSpan(),
-                                                 *param, column_sampler_, batch, p_fmat->Info());
+
+    std::vector<bst_idx_t> batch_ptr;
+    auto batch = HistBatch(*param);
+    auto cuts = InitBatchCuts(ctx_, p_fmat, batch, &batch_ptr);
+
+    this->maker = std::make_unique<GPUHistMakerDevice>(ctx_, *param, column_sampler_, batch,
+                                                       p_fmat->Info(), batch_ptr, cuts);
 
     p_last_fmat_ = p_fmat;
     initialised_ = true;
@@ -896,15 +926,14 @@ class GPUGlobalApproxMaker : public TreeUpdater {
     auto const& info = p_fmat->Info();
     info.feature_types.SetDevice(ctx_->Device());
-    std::shared_ptr<common::HistogramCuts> cuts;
+
+    std::vector<bst_idx_t> batch_ptr;
     auto batch = ApproxBatch(*param, hess, *task_);
-    for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, batch)) {
-      cuts = page.Impl()->CutsShared();
-    }
+    auto cuts = InitBatchCuts(ctx_, p_fmat, batch, &batch_ptr);
     batch.regen = false;  // Regen only at the beginning of the iteration.
-    maker_ = std::make_unique<GPUHistMakerDevice>(ctx_, cuts, !p_fmat->SingleColBlock(),
-                                                  info.feature_types.ConstDeviceSpan(), *param,
-                                                  column_sampler_, batch, p_fmat->Info());
+
+    this->maker_ = std::make_unique<GPUHistMakerDevice>(ctx_, *param, column_sampler_, batch,
+                                                        p_fmat->Info(), batch_ptr, cuts);
 
     std::size_t t_idx{0};
     for (xgboost::RegTree* tree : trees) {
diff --git a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu
index c86489102..bdb36c447 100644
--- a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu
+++ b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu
@@ -38,11 +38,9 @@ void VerifySampling(size_t page_size, float subsample, int sampling_method,
   auto sample = sampler.Sample(&ctx, gpair.DeviceSpan(), dmat.get());
 
   if (fixed_size_sampling) {
-    EXPECT_EQ(sample.sample_rows, kRows);
     EXPECT_EQ(sample.p_fmat->Info().num_row_, kRows);
     EXPECT_EQ(sample.gpair.size(), kRows);
   } else {
-    EXPECT_NEAR(sample.sample_rows, sample_rows, kRows * 0.03);
     EXPECT_NEAR(sample.p_fmat->Info().num_row_, sample_rows, kRows * 0.03f);
     EXPECT_NEAR(sample.gpair.size(), sample_rows, kRows * 0.03f);
   }
@@ -89,7 +87,7 @@ TEST(GradientBasedSampler, NoSamplingExternalMemory) {
   GradientBasedSampler sampler(&ctx, kRows, param, kSubsample, TrainParam::kUniform, true);
   auto sample = sampler.Sample(&ctx, gpair.DeviceSpan(), dmat.get());
   auto p_fmat = sample.p_fmat;
-  EXPECT_EQ(sample.sample_rows, kRows);
+  EXPECT_EQ(sample.p_fmat->Info().num_row_, kRows);
   EXPECT_EQ(sample.gpair.size(), gpair.Size());
   EXPECT_EQ(sample.gpair.data(), gpair.DevicePointer());
   EXPECT_EQ(p_fmat->Info().num_row_, kRows);
diff --git a/tests/python-gpu/test_gpu_data_iterator.py b/tests/python-gpu/test_gpu_data_iterator.py
index b42a152fe..9aa8cc242 100644
--- a/tests/python-gpu/test_gpu_data_iterator.py
+++ b/tests/python-gpu/test_gpu_data_iterator.py
@@ -12,7 +12,7 @@ from test_data_iterator import test_single_batch as cpu_single_batch
 
 def test_gpu_single_batch() -> None:
-    cpu_single_batch("gpu_hist")
+    cpu_single_batch("hist", "cuda")
 
 
 @pytest.mark.skipif(**no_cupy())
diff --git a/tests/python/test_data_iterator.py b/tests/python/test_data_iterator.py
index fbf05a236..a42ad0f75 100644
--- a/tests/python/test_data_iterator.py
+++ b/tests/python/test_data_iterator.py
@@ -17,7 +17,7 @@ from xgboost.testing.updater import check_quantile_loss_extmem
 pytestmark = tm.timeout(30)
 
 
-def test_single_batch(tree_method: str = "approx") -> None:
+def test_single_batch(tree_method: str = "approx", device: str = "cpu") -> None:
     from sklearn.datasets import load_breast_cancer
 
     n_rounds = 10
@@ -25,17 +25,19 @@ def test_single_batch(tree_method: str = "approx") -> None:
     X = X.astype(np.float32)
     y = y.astype(np.float32)
 
+    params = {"tree_method": tree_method, "device": device}
+
     Xy = xgb.DMatrix(SingleBatch(data=X, label=y))
-    from_it = xgb.train({"tree_method": tree_method}, Xy, num_boost_round=n_rounds)
+    from_it = xgb.train(params, Xy, num_boost_round=n_rounds)
 
     Xy = xgb.DMatrix(X, y)
-    from_dmat = xgb.train({"tree_method": tree_method}, Xy, num_boost_round=n_rounds)
+    from_dmat = xgb.train(params, Xy, num_boost_round=n_rounds)
     assert from_it.get_dump() == from_dmat.get_dump()
 
     X, y = load_breast_cancer(return_X_y=True, as_frame=True)
     X = X.astype(np.float32)
     Xy = xgb.DMatrix(SingleBatch(data=X, label=y))
-    from_pd = xgb.train({"tree_method": tree_method}, Xy, num_boost_round=n_rounds)
+    from_pd = xgb.train(params, Xy, num_boost_round=n_rounds)
 
     # remove feature info to generate exact same text representation.
from_pd.feature_names = None from_pd.feature_types = None @@ -45,11 +47,11 @@ def test_single_batch(tree_method: str = "approx") -> None: X, y = load_breast_cancer(return_X_y=True) X = csr_matrix(X) Xy = xgb.DMatrix(SingleBatch(data=X, label=y)) - from_it = xgb.train({"tree_method": tree_method}, Xy, num_boost_round=n_rounds) + from_it = xgb.train(params, Xy, num_boost_round=n_rounds) X, y = load_breast_cancer(return_X_y=True) Xy = xgb.DMatrix(SingleBatch(data=X, label=y), missing=0.0) - from_np = xgb.train({"tree_method": tree_method}, Xy, num_boost_round=n_rounds) + from_np = xgb.train(params, Xy, num_boost_round=n_rounds) assert from_np.get_dump() == from_it.get_dump()
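
--
Supplementary illustration (not part of the patch): the stand-alone sketch below shows the bookkeeping this change introduces, namely one partitioner per external-memory batch, with `batch_ptr` holding prefix-summed page offsets and `base_ridx` translating global row indices into batch-local positions. `MockPartitioner` is a hypothetical stand-in that only mirrors the shape of `RowPartitioner::Reset`/`Size` as used in `GPUHistMakerDevice::Reset` above.

```cpp
// Minimal sketch of per-batch partitioner bookkeeping, assuming a
// hypothetical MockPartitioner in place of the real RowPartitioner.
#include <cassert>
#include <cstdint>
#include <memory>
#include <numeric>
#include <vector>

using bst_idx_t = std::uint64_t;

struct MockPartitioner {
  bst_idx_t base_ridx{0};
  bst_idx_t n_samples{0};
  // Mirrors RowPartitioner::Reset(ctx, n_samples, base_ridx) from the patch.
  void Reset(bst_idx_t n, bst_idx_t base) {
    n_samples = n;
    base_ridx = base;
  }
  std::size_t Size() const { return n_samples; }
};

int main() {
  // Like InitBatchCuts: start from {0, size_0, size_1, ...} and turn the
  // per-page sizes into offsets with a prefix sum: {0, 3, 2, 8} -> {0, 3, 5, 13}.
  std::vector<bst_idx_t> batch_ptr{0, 3, 2, 8};
  std::partial_sum(batch_ptr.cbegin(), batch_ptr.cend(), batch_ptr.begin());

  // Like GPUHistMakerDevice::Reset: one partitioner per batch, each covering
  // the half-open row range [batch_ptr[k], batch_ptr[k + 1]).
  std::vector<std::unique_ptr<MockPartitioner>> partitioners;
  for (std::size_t k = 0; k + 1 < batch_ptr.size(); ++k) {
    partitioners.emplace_back(std::make_unique<MockPartitioner>());
    partitioners[k]->Reset(batch_ptr[k + 1] - batch_ptr[k], batch_ptr[k]);
  }
  assert(partitioners.size() == 3);

  // Like FinalisePositionKernel: a partitioner stores global row indices, so
  // subtracting base_ridx yields the offset into this batch's slice of the
  // global output, i.e. d_out_position.subspan(base_ridx, n_samples).
  bst_idx_t global_ridx = 6;  // a row that lives in the third batch, [5, 13)
  auto const& part = partitioners[2];
  bst_idx_t local_ridx = global_ridx - part->base_ridx;  // 6 - 5 == 1
  assert(local_ridx == 1);
  return 0;
}
```

In the concatenated case the patch collapses `batch_ptr` to `{0, num_row_}`, so a single partitioner with `base_ridx == 0` covers the whole dataset; the in-core path is then just the one-batch special case of the same scheme.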