From e1a2c1bbb366f30fedcaa230435de489a3f7a289 Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Sat, 31 Aug 2024 03:25:37 +0800
Subject: [PATCH] [EM] Merge GPU partitioning with histogram building. (#10766)

- Stop concatenating pages if there's no subsampling.
- Use a single iteration for histogram build and partitioning.
---
 python-package/xgboost/testing/updater.py    |  10 +-
 src/tree/gpu_hist/gradient_based_sampler.cu  |  22 +-
 src/tree/gpu_hist/gradient_based_sampler.cuh |   2 -
 src/tree/gpu_hist/row_partitioner.cu         |   4 +
 src/tree/updater_gpu_hist.cu                 | 197 +++++++++---------
 .../gpu_hist/test_gradient_based_sampler.cu  |  26 +--
 tests/python-gpu/test_gpu_data_iterator.py   |  16 +-
 7 files changed, 118 insertions(+), 159 deletions(-)

diff --git a/python-package/xgboost/testing/updater.py b/python-package/xgboost/testing/updater.py
index 3a8715a4d..cf46bd43f 100644
--- a/python-package/xgboost/testing/updater.py
+++ b/python-package/xgboost/testing/updater.py
@@ -222,10 +222,12 @@ def check_extmem_qdm(
     Xy = xgb.QuantileDMatrix(X, y, weight=w)
     booster = xgb.train({"device": device}, Xy, num_boost_round=8)
 
-    cut_it = Xy_it.get_quantile_cut()
-    cut = Xy.get_quantile_cut()
-    np.testing.assert_allclose(cut_it[0], cut[0])
-    np.testing.assert_allclose(cut_it[1], cut[1])
+    if device == "cpu":
+        # Getting cuts from an ELLPACK page without CPU-GPU interpolation is not yet supported.
+        cut_it = Xy_it.get_quantile_cut()
+        cut = Xy.get_quantile_cut()
+        np.testing.assert_allclose(cut_it[0], cut[0])
+        np.testing.assert_allclose(cut_it[1], cut[1])
 
     predt_it = booster_it.predict(Xy_it)
     predt = booster.predict(Xy)
diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu
index 50a00149b..077cc2c72 100644
--- a/src/tree/gpu_hist/gradient_based_sampler.cu
+++ b/src/tree/gpu_hist/gradient_based_sampler.cu
@@ -158,28 +158,10 @@ GradientBasedSample NoSampling::Sample(Context const*, common::Span<GradientPair> gpair,
                                        DMatrix* p_fmat) {
-  std::shared_ptr<EllpackPage> new_page;
-  if (!page_concatenated_) {
-    // Concatenate all the external memory ELLPACK pages into a single in-memory page.
-    bst_idx_t offset = 0;
-    for (auto& batch : p_fmat->GetBatches<EllpackPage>(ctx, batch_param_)) {
-      auto page = batch.Impl();
-      if (!new_page) {
-        new_page = std::make_shared<EllpackPage>();
-        *new_page->Impl() = EllpackPageImpl(ctx, page->CutsShared(), page->is_dense,
-                                            page->row_stride, p_fmat->Info().num_row_);
-      }
-      bst_idx_t num_elements = new_page->Impl()->Copy(ctx, page, offset);
-      offset += num_elements;
-    }
-    page_concatenated_ = true;
-    this->p_fmat_new_ =
-        std::make_unique(new_page, p_fmat->Info(), batch_param_);
-  }
-  return {this->p_fmat_new_.get(), gpair};
+  return {p_fmat, gpair};
 }
 
 UniformSampling::UniformSampling(BatchParam batch_param, float subsample)
diff --git a/src/tree/gpu_hist/gradient_based_sampler.cuh b/src/tree/gpu_hist/gradient_based_sampler.cuh
index d7e24dafc..ea3d10cd0 100644
--- a/src/tree/gpu_hist/gradient_based_sampler.cuh
+++ b/src/tree/gpu_hist/gradient_based_sampler.cuh
@@ -46,8 +46,6 @@ class ExternalMemoryNoSampling : public SamplingStrategy {
 
  private:
  BatchParam batch_param_;
-  std::unique_ptr<DMatrix> p_fmat_new_{nullptr};
-  bool page_concatenated_{false};
 };
 
 /*! \brief Uniform sampling in in-memory mode. */
diff --git a/src/tree/gpu_hist/row_partitioner.cu b/src/tree/gpu_hist/row_partitioner.cu
index bec500078..9bde18ed2 100644
--- a/src/tree/gpu_hist/row_partitioner.cu
+++ b/src/tree/gpu_hist/row_partitioner.cu
@@ -22,6 +22,10 @@ void RowPartitioner::Reset(Context const* ctx, bst_idx_t n_samples, bst_idx_t ba
                NodePositionInfo{Segment{0, static_cast<cuda_impl::RowIndexT>(n_samples)}});
   thrust::sequence(ctx->CUDACtx()->CTP(), ridx_.data(), ridx_.data() + ridx_.size(), base_rowid);
+
+  // Pre-allocate some host memory
+  this->pinned_.GetSpan(1 << 11);
+  this->pinned2_.GetSpan(1 << 13);
 }
 
 RowPartitioner::~RowPartitioner() = default;
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index e4e27b72a..fcb38c3e5 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -200,6 +200,7 @@ struct GPUHistMakerDevice {
 
   // Reset values for each update iteration
   [[nodiscard]] DMatrix* Reset(HostDeviceVector<GradientPair>* dh_gpair, DMatrix* p_fmat) {
+    this->monitor.Start(__func__);
     auto const& info = p_fmat->Info();
     this->column_sampler_->Init(ctx_, p_fmat->Info().num_col_, info.feature_weights.HostVector(),
                                 param.colsample_bynode, param.colsample_bylevel,
@@ -252,7 +253,7 @@ struct GPUHistMakerDevice {
 
     this->histogram_.Reset(ctx_, this->hist_param_->MaxCachedHistNodes(ctx_->Device()),
                            feature_groups->DeviceAccessor(ctx_->Device()), cuts_->TotalBins(), false);
-
+    this->monitor.Stop(__func__);
     return p_fmat;
   }
 
@@ -346,6 +347,38 @@ struct GPUHistMakerDevice {
     monitor.Stop(__func__);
   }
 
+  void ReduceHist(DMatrix* p_fmat, std::vector<GPUExpandEntry> const& candidates,
+                  std::vector<bst_node_t> const& build_nidx,
+                  std::vector<bst_node_t> const& subtraction_nidx) {
+    if (candidates.empty()) {
+      return;
+    }
+    this->monitor.Start(__func__);
+
+    // Reduce all in one go
+    // This gives much better latency in a distributed setting when processing a large batch
+    this->histogram_.AllReduceHist(ctx_, p_fmat->Info(), build_nidx.at(0), build_nidx.size());
+    // Perform subtraction for sibling nodes
+    auto need_build = this->histogram_.SubtractHist(candidates, build_nidx, subtraction_nidx);
+    if (need_build.empty()) {
+      this->monitor.Stop(__func__);
+      return;
+    }
+
+    // Build the nodes that cannot obtain the histogram using subtraction. This is the slow path.
+    std::int32_t k = 0;
+    for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
+      for (auto nidx : need_build) {
+        this->BuildHist(page, k, nidx);
+      }
+      ++k;
+    }
+    for (auto nidx : need_build) {
+      this->histogram_.AllReduceHist(ctx_, p_fmat->Info(), nidx, 1);
+    }
+    this->monitor.Stop(__func__);
+  }
+
   void UpdatePositionColumnSplit(EllpackDeviceAccessor d_matrix,
                                  std::vector<NodeSplitData> const& split_data,
                                  std::vector<bst_node_t> const& nidx,
@@ -434,56 +467,74 @@ struct GPUHistMakerDevice {
     }
   };
 
-  void UpdatePosition(DMatrix* p_fmat, std::vector<GPUExpandEntry> const& candidates,
-                      RegTree* p_tree) {
-    if (candidates.empty()) {
+  // Update position and build histogram.
+  void PartitionAndBuildHist(DMatrix* p_fmat, std::vector<GPUExpandEntry> const& expand_set,
+                             std::vector<GPUExpandEntry> const& candidates, RegTree const* p_tree) {
+    if (expand_set.empty()) {
       return;
     }
-    monitor.Start(__func__);
+    CHECK_LE(candidates.size(), expand_set.size());
 
-    auto [nidx, left_nidx, right_nidx, split_data] = this->CreatePartitionNodes(p_tree, candidates);
+    // Update all the nodes if working with external memory, this saves us from working
+    // with the finalize position call, which adds an additional iteration and requires
+    // special handling for row index.
+    bool const is_single_block = p_fmat->SingleColBlock();
 
-    for (size_t i = 0; i < candidates.size(); i++) {
-      auto const& e = candidates[i];
-      RegTree::Node const& split_node = (*p_tree)[e.nid];
-      auto split_type = p_tree->NodeSplitType(e.nid);
-      nidx[i] = e.nid;
-      left_nidx[i] = split_node.LeftChild();
-      right_nidx[i] = split_node.RightChild();
-      split_data[i] = NodeSplitData{split_node, split_type, evaluator_.GetDeviceNodeCats(e.nid)};
+    // Prepare for update partition
+    auto [nidx, left_nidx, right_nidx, split_data] =
+        this->CreatePartitionNodes(p_tree, is_single_block ? candidates : expand_set);
 
-      CHECK_EQ(split_type == FeatureType::kCategorical, e.split.is_cat);
-    }
+    // Prepare for build hist
+    std::vector<bst_node_t> build_nidx(candidates.size());
+    std::vector<bst_node_t> subtraction_nidx(candidates.size());
+    auto prefetch_copy =
+        AssignNodes(p_tree, this->quantiser.get(), candidates, build_nidx, subtraction_nidx);
 
-    CHECK_EQ(p_fmat->NumBatches(), 1);
-    for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
+    this->histogram_.AllocateHistograms(ctx_, build_nidx, subtraction_nidx);
+
+    monitor.Start("Partition-BuildHist");
+
+    std::int32_t k{0};
+    for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(prefetch_copy))) {
       auto d_matrix = page.Impl()->GetDeviceAccessor(ctx_->Device());
+      auto go_left = GoLeftOp{d_matrix};
 
+      // Partition histogram.
+      monitor.Start("UpdatePositionBatch");
       if (p_fmat->Info().IsColumnSplit()) {
         UpdatePositionColumnSplit(d_matrix, split_data, nidx, left_nidx, right_nidx);
-        monitor.Stop(__func__);
-        return;
+      } else {
+        partitioners_.at(k)->UpdatePositionBatch(
+            nidx, left_nidx, right_nidx, split_data,
+            [=] __device__(cuda_impl::RowIndexT ridx, int /*nidx_in_batch*/,
+                           const NodeSplitData& data) { return go_left(ridx, data); });
       }
-      auto go_left = GoLeftOp{d_matrix};
-      partitioners_.front()->UpdatePositionBatch(
-          nidx, left_nidx, right_nidx, split_data,
-          [=] __device__(cuda_impl::RowIndexT ridx, int /*nidx_in_batch*/,
-                         const NodeSplitData& data) { return go_left(ridx, data); });
+      monitor.Stop("UpdatePositionBatch");
+
+      for (auto nidx : build_nidx) {
+        this->BuildHist(page, k, nidx);
+      }
+
+      ++k;
     }
+    monitor.Stop("Partition-BuildHist");
+
+    this->ReduceHist(p_fmat, candidates, build_nidx, subtraction_nidx);
+
     monitor.Stop(__func__);
   }
 
   // After tree update is finished, update the position of all training
   // instances to their final leaf. This information is used later to update the
   // prediction cache
-  void FinalisePosition(DMatrix* p_fmat, RegTree const* p_tree, ObjInfo task, bst_idx_t n_samples,
+  void FinalisePosition(DMatrix* p_fmat, RegTree const* p_tree, ObjInfo task,
                         HostDeviceVector<bst_node_t>* p_out_position) {
     if (!p_fmat->SingleColBlock() && task.UpdateTreeLeaf()) {
       LOG(FATAL) << "Current objective function can not be used with external memory.";
     }
 
-    if (p_fmat->Info().num_row_ != n_samples) {
+    if (static_cast<std::size_t>(p_fmat->NumBatches() + 1) != this->batch_ptr_.size()) {
       // External memory with concatenation. Not supported.
       p_out_position->Resize(0);
       positions_.clear();
       return;
     }
 
@@ -577,60 +628,6 @@ struct GPUHistMakerDevice {
     return true;
   }
 
-  /**
-   * \brief Build GPU local histograms for the left and right child of some parent node
-   */
-  void BuildHistLeftRight(DMatrix* p_fmat, std::vector<GPUExpandEntry> const& candidates,
-                          const RegTree& tree) {
-    if (candidates.empty()) {
-      return;
-    }
-    this->monitor.Start(__func__);
-    // Some nodes we will manually compute histograms
-    // others we will do by subtraction
-    std::vector<bst_node_t> hist_nidx(candidates.size());
-    std::vector<bst_node_t> subtraction_nidx(candidates.size());
-    auto prefetch_copy =
-        AssignNodes(&tree, this->quantiser.get(), candidates, hist_nidx, subtraction_nidx);
-
-    std::vector<bst_node_t> all_new = hist_nidx;
-    all_new.insert(all_new.end(), subtraction_nidx.begin(), subtraction_nidx.end());
-    // Allocate the histograms
-    // Guaranteed contiguous memory
-    histogram_.AllocateHistograms(ctx_, all_new);
-
-    std::int32_t k = 0;
-    for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(prefetch_copy))) {
-      for (auto nidx : hist_nidx) {
-        this->BuildHist(page, k, nidx);
-      }
-      ++k;
-    }
-
-    // Reduce all in one go
-    // This gives much better latency in a distributed setting
-    // when processing a large batch
-    this->histogram_.AllReduceHist(ctx_, p_fmat->Info(), hist_nidx.at(0), hist_nidx.size());
-
-    for (size_t i = 0; i < subtraction_nidx.size(); i++) {
-      auto build_hist_nidx = hist_nidx.at(i);
-      auto subtraction_trick_nidx = subtraction_nidx.at(i);
-      auto parent_nidx = candidates.at(i).nid;
-
-      if (!this->histogram_.SubtractionTrick(parent_nidx, build_hist_nidx,
-                                             subtraction_trick_nidx)) {
-        // Calculate other histogram manually
-        std::int32_t k = 0;
-        for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
-          this->BuildHist(page, k, subtraction_trick_nidx);
-          ++k;
-        }
-        this->histogram_.AllReduceHist(ctx_, p_fmat->Info(), subtraction_trick_nidx, 1);
-      }
-    }
-    this->monitor.Stop(__func__);
-  }
-
   void ApplySplit(const GPUExpandEntry& candidate, RegTree* p_tree) {
     RegTree& tree = *p_tree;
 
@@ -681,8 +678,9 @@ struct GPUHistMakerDevice {
   }
 
   GPUExpandEntry InitRoot(DMatrix* p_fmat, RegTree* p_tree) {
-    constexpr bst_node_t kRootNIdx = 0;
-    dh::XGBCachingDeviceAllocator<char> alloc;
+    this->monitor.Start(__func__);
+
+    constexpr bst_node_t kRootNIdx = RegTree::kRoot;
     auto quantiser = *this->quantiser;
     auto gpair_it = dh::MakeTransformIterator<GradientPairInt64>(
         dh::tbegin(gpair),
@@ -697,6 +695,7 @@ struct GPUHistMakerDevice {
 
     histogram_.AllocateHistograms(ctx_, {kRootNIdx});
     std::int32_t k = 0;
+    CHECK_EQ(p_fmat->NumBatches(), this->partitioners_.size());
     for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
       this->BuildHist(page, k, kRootNIdx);
       ++k;
@@ -712,25 +711,18 @@ struct GPUHistMakerDevice {
 
     // Generate first split
     auto root_entry = this->EvaluateRootSplit(p_fmat, root_sum_quantised);
+
+    this->monitor.Stop(__func__);
     return root_entry;
   }
 
   void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat, ObjInfo const* task,
                   RegTree* p_tree, HostDeviceVector<bst_node_t>* p_out_position) {
-    bool const is_single_block = p_fmat->SingleColBlock();
-    bst_idx_t const n_samples = p_fmat->Info().num_row_;
-
-    auto& tree = *p_tree;
     // Process maximum 32 nodes at a time
     Driver<GPUExpandEntry> driver(param, 32);
 
-    monitor.Start("Reset");
     p_fmat = this->Reset(gpair_all, p_fmat);
-    monitor.Stop("Reset");
-
-    monitor.Start("InitRoot");
     driver.Push({this->InitRoot(p_fmat, p_tree)});
-    monitor.Stop("InitRoot");
 
     // The set of leaves that can be expanded asynchronously
     auto expand_set = driver.Pop();
@@ -740,20 +732,17 @@ struct GPUHistMakerDevice {
       }
       // Get the candidates we are allowed to expand further
       // e.g. We do not bother further processing nodes whose children are beyond max depth
-      std::vector<GPUExpandEntry> filtered_expand_set;
-      std::copy_if(expand_set.begin(), expand_set.end(), std::back_inserter(filtered_expand_set),
-                   [&](const auto& e) { return driver.IsChildValid(e); });
+      std::vector<GPUExpandEntry> valid_candidates;
+      std::copy_if(expand_set.begin(), expand_set.end(), std::back_inserter(valid_candidates),
+                   [&](auto const& e) { return driver.IsChildValid(e); });
 
+      // Allocate children nodes.
       auto new_candidates =
-          pinned.GetSpan<GPUExpandEntry>(filtered_expand_set.size() * 2, GPUExpandEntry{});
-      // Update all the nodes if working with external memory, this saves us from working
-      // with the finalize position call, which adds an additional iteration and requires
-      // special handling for row index.
-      this->UpdatePosition(p_fmat, is_single_block ? filtered_expand_set : expand_set, p_tree);
+          pinned.GetSpan<GPUExpandEntry>(valid_candidates.size() * 2, GPUExpandEntry());
 
-      this->BuildHistLeftRight(p_fmat, filtered_expand_set, tree);
+      this->PartitionAndBuildHist(p_fmat, expand_set, valid_candidates, p_tree);
 
-      this->EvaluateSplits(p_fmat, filtered_expand_set, *p_tree, new_candidates);
+      this->EvaluateSplits(p_fmat, valid_candidates, *p_tree, new_candidates);
 
       dh::DefaultStream().Sync();
       driver.Push(new_candidates.begin(), new_candidates.end());
@@ -764,10 +753,10 @@ struct GPUHistMakerDevice {
     // be spliable before evaluation but invalid after evaluation as we have more
     // restrictions like min loss change after evalaution. Therefore, the check condition
    // is greater than or equal to.
-    if (is_single_block) {
+    if (p_fmat->SingleColBlock()) {
       CHECK_GE(p_tree->NumNodes(), this->partitioners_.front()->GetNumNodes());
     }
-    this->FinalisePosition(p_fmat, p_tree, *task, n_samples, p_out_position);
+    this->FinalisePosition(p_fmat, p_tree, *task, p_out_position);
   }
 };
 
diff --git a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu
index bdb36c447..2c3bcdd88 100644
--- a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu
+++ b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu
@@ -67,7 +67,6 @@ TEST(GradientBasedSampler, NoSampling) {
   VerifySampling(kPageSize, kSubsample, kSamplingMethod);
 }
 
-// In external mode, when not sampling, we concatenate the pages together.
 TEST(GradientBasedSampler, NoSamplingExternalMemory) {
   constexpr size_t kRows = 2048;
   constexpr size_t kCols = 1;
@@ -81,34 +80,11 @@ TEST(GradientBasedSampler, NoSamplingExternalMemory) {
   gpair.SetDevice(ctx.Device());
 
   auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
-  auto page = (*dmat->GetBatches<EllpackPage>(&ctx, param).begin()).Impl();
-  EXPECT_NE(page->n_rows, kRows);
 
   GradientBasedSampler sampler(&ctx, kRows, param, kSubsample, TrainParam::kUniform, true);
   auto sample = sampler.Sample(&ctx, gpair.DeviceSpan(), dmat.get());
   auto p_fmat = sample.p_fmat;
-  EXPECT_EQ(sample.p_fmat->Info().num_row_, kRows);
-  EXPECT_EQ(sample.gpair.size(), gpair.Size());
-  EXPECT_EQ(sample.gpair.data(), gpair.DevicePointer());
-  EXPECT_EQ(p_fmat->Info().num_row_, kRows);
-
-  ASSERT_EQ(p_fmat->NumBatches(), 1);
-  for (auto const& sampled_page : p_fmat->GetBatches<EllpackPage>(&ctx, param)) {
-    std::vector<common::CompressedByteT> h_gidx_buffer;
-    auto h_accessor = sampled_page.Impl()->GetHostAccessor(&ctx, &h_gidx_buffer);
-
-    std::size_t offset = 0;
-    for (auto& batch : dmat->GetBatches<EllpackPage>(&ctx, param)) {
-      auto page = batch.Impl();
-      std::vector<common::CompressedByteT> h_page_gidx_buffer;
-      auto page_accessor = page->GetHostAccessor(&ctx, &h_page_gidx_buffer);
-      size_t num_elements = page->n_rows * page->row_stride;
-      for (size_t i = 0; i < num_elements; i++) {
-        EXPECT_EQ(h_accessor.gidx_iter[i + offset], page_accessor.gidx_iter[i]);
-      }
-      offset += num_elements;
-    }
-  }
+  ASSERT_EQ(p_fmat, dmat.get());
 }
 
 TEST(GradientBasedSampler, UniformSampling) {
diff --git a/tests/python-gpu/test_gpu_data_iterator.py b/tests/python-gpu/test_gpu_data_iterator.py
index 9aa8cc242..e039e0348 100644
--- a/tests/python-gpu/test_gpu_data_iterator.py
+++ b/tests/python-gpu/test_gpu_data_iterator.py
@@ -4,7 +4,7 @@
 import pytest
 from hypothesis import given, settings, strategies
 from xgboost.testing import no_cupy
-from xgboost.testing.updater import check_quantile_loss_extmem
+from xgboost.testing.updater import check_extmem_qdm, check_quantile_loss_extmem
 
 sys.path.append("tests/python")
 from test_data_iterator import run_data_iterator
@@ -59,6 +59,14 @@ def test_cpu_data_iterator() -> None:
     )
 
 
-def test_quantile_objective() -> None:
-    with pytest.raises(ValueError, match="external memory"):
-        check_quantile_loss_extmem(2, 2, 2, "hist", "cuda")
+@given(
+    strategies.integers(1, 2048),
+    strategies.integers(1, 8),
+    strategies.integers(1, 4),
+    strategies.booleans(),
+)
+@settings(deadline=None, max_examples=10, print_blob=True)
+def test_extmem_qdm(
+    n_samples_per_batch: int, n_features: int, n_batches: int, on_host: bool
+) -> None:
+    check_extmem_qdm(n_samples_per_batch, n_features, n_batches, "cuda", on_host)
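
For reference, outside the patch itself: the new test_extmem_qdm test drives training through the check_extmem_qdm helper, which builds QuantileDMatrix objects from a batch iterator. Below is a minimal, hypothetical sketch of that iterator-driven flow using only the public xgboost API; the BatchIter class, batch shapes, and parameters here are illustrative and not taken from the patch, and the real helper additionally varies the number of batches and the on_host placement.

import numpy as np
import xgboost as xgb


class BatchIter(xgb.DataIter):
    """Yield a fixed list of in-memory (X, y) batches, one at a time."""

    def __init__(self, batches):
        self._batches = batches
        self._it = 0
        super().__init__()

    def next(self, input_data):
        # Return 0 when the iterator is exhausted, 1 otherwise.
        if self._it == len(self._batches):
            return 0
        X, y = self._batches[self._it]
        input_data(data=X, label=y)
        self._it += 1
        return 1

    def reset(self):
        self._it = 0


rng = np.random.default_rng(0)
batches = [(rng.normal(size=(256, 4)), rng.normal(size=256)) for _ in range(3)]

# QuantileDMatrix consumes the iterator batch by batch instead of requiring a
# single concatenated matrix, which is the general pattern the GPU hist
# updater's external-memory handling in this patch builds on.
Xy = xgb.QuantileDMatrix(BatchIter(batches))
booster = xgb.train({"device": "cuda", "tree_method": "hist"}, Xy, num_boost_round=8)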