From d6ebcfb032aecb438de1cc8f4960de057286b27a Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Tue, 27 Aug 2024 04:16:57 +0800 Subject: [PATCH] [EM] Support CPU quantile objective for external memory. (#10751) --- python-package/xgboost/testing/updater.py | 31 ++++++++++++ src/common/error_msg.h | 2 + src/common/partition_builder.h | 12 ++--- src/common/quantile.cc | 6 ++- src/common/quantile.cuh | 4 +- src/common/quantile.h | 4 +- src/gbm/gbtree.cc | 2 +- src/tree/common_row_partitioner.h | 29 ++++++----- src/tree/updater_approx.cc | 4 +- src/tree/updater_quantile_hist.cc | 14 +++-- tests/cpp/tree/test_common_partitioner.cc | 59 ++++++++++++++++++++-- tests/python-gpu/test_gpu_data_iterator.py | 6 +++ tests/python/test_data_iterator.py | 26 ++++++++++ 13 files changed, 163 insertions(+), 36 deletions(-) diff --git a/python-package/xgboost/testing/updater.py b/python-package/xgboost/testing/updater.py index 7e360d42b..7063a7b01 100644 --- a/python-package/xgboost/testing/updater.py +++ b/python-package/xgboost/testing/updater.py @@ -163,6 +163,37 @@ def check_quantile_loss(tree_method: str, weighted: bool) -> None: np.testing.assert_allclose(predts[:, i], predt_multi[:, i]) +def check_quantile_loss_extmem( + n_samples_per_batch: int, + n_features: int, + n_batches: int, + tree_method: str, + device: str, +) -> None: + """Check external memory with the quantile objective.""" + it = tm.IteratorForTest( + *tm.make_batches(n_samples_per_batch, n_features, n_batches, device != "cpu"), + cache="cache", + on_host=False, + ) + Xy_it = xgb.DMatrix(it) + params = { + "tree_method": tree_method, + "objective": "reg:quantileerror", + "device": device, + "quantile_alpha": [0.2, 0.8], + } + booster_it = xgb.train(params, Xy_it) + X, y, w = it.as_arrays() + Xy = xgb.DMatrix(X, y, weight=w) + booster = xgb.train(params, Xy) + + predt_it = booster_it.predict(Xy_it) + predt = booster.predict(Xy) + + np.testing.assert_allclose(predt, predt_it) + + def check_cut( n_entries: int, indptr: np.ndarray, data: np.ndarray, dtypes: Any ) -> None: diff --git a/src/common/error_msg.h b/src/common/error_msg.h index 601e63526..02fc6f55c 100644 --- a/src/common/error_msg.h +++ b/src/common/error_msg.h @@ -41,6 +41,8 @@ constexpr StringView InconsistentMaxBin() { "and consistent with the Booster being trained."; } +constexpr StringView InvalidMaxBin() { return "`max_bin` must be equal to or greater than 2."; } + constexpr StringView UnknownDevice() { return "Unknown device type."; } inline void MaxFeatureSize(std::uint64_t n_features) { diff --git a/src/common/partition_builder.h b/src/common/partition_builder.h index 54febd750..0bc58f499 100644 --- a/src/common/partition_builder.h +++ b/src/common/partition_builder.h @@ -367,23 +367,21 @@ class PartitionBuilder { // Copy row partitions into global cache for reuse in objective template void LeafPartition(Context const* ctx, RegTree const& tree, RowSetCollection const& row_set, - std::vector* p_position, Invalidp invalidp) const { - auto& h_pos = *p_position; - h_pos.resize(row_set.Data()->size(), std::numeric_limits::max()); - + Span position, Invalidp invalidp) const { auto p_begin = row_set.Data()->data(); // For each node, walk through all the samples that fall in this node. - ParallelFor(row_set.Size(), ctx->Threads(), [&](size_t i) { + auto p_pos = position.data(); + ParallelFor(row_set.Size(), ctx->Threads(), [&](auto i) { auto const& node = row_set[i]; if (node.node_id < 0) { return; } CHECK(tree.IsLeaf(node.node_id)); if (node.begin()) { // guard for empty node. 
- size_t ptr_offset = node.end() - p_begin; + std::size_t ptr_offset = node.end() - p_begin; CHECK_LE(ptr_offset, row_set.Data()->size()) << node.node_id; for (auto idx = node.begin(); idx != node.end(); ++idx) { - h_pos[*idx] = tree::SamplePosition::Encode(node.node_id, !invalidp(*idx)); + p_pos[*idx] = tree::SamplePosition::Encode(node.node_id, !invalidp(*idx)); } } }); diff --git a/src/common/quantile.cc b/src/common/quantile.cc index 05e2f762c..eb02924aa 100644 --- a/src/common/quantile.cc +++ b/src/common/quantile.cc @@ -8,6 +8,7 @@ #include #include "../collective/aggregator.h" +#include "../common/error_msg.h" // for InvalidMaxBin #include "../data/adapter.h" #include "categorical.h" #include "hist_util.h" @@ -16,15 +17,16 @@ namespace xgboost::common { template SketchContainerImpl::SketchContainerImpl(Context const *ctx, std::vector columns_size, - int32_t max_bins, + bst_bin_t max_bin, Span feature_types, bool use_group) : feature_types_(feature_types.cbegin(), feature_types.cend()), columns_size_{std::move(columns_size)}, - max_bins_{max_bins}, + max_bins_{max_bin}, use_group_ind_{use_group}, n_threads_{ctx->Threads()} { monitor_.Init(__func__); + CHECK_GE(max_bin, 2) << error::InvalidMaxBin(); CHECK_NE(columns_size_.size(), 0); sketches_.resize(columns_size_.size()); CHECK_GE(n_threads_, 1); diff --git a/src/common/quantile.cuh b/src/common/quantile.cuh index 3dd393755..ae286c3b3 100644 --- a/src/common/quantile.cuh +++ b/src/common/quantile.cuh @@ -8,6 +8,7 @@ #include "categorical.h" #include "device_helpers.cuh" +#include "error_msg.h" // for InvalidMaxBin #include "quantile.h" #include "timer.h" #include "xgboost/data.h" @@ -96,7 +97,7 @@ class SketchContainer { * \param num_rows Total number of rows in known dataset (typically the rows in current worker). * \param device GPU ID. */ - SketchContainer(HostDeviceVector const& feature_types, int32_t max_bin, + SketchContainer(HostDeviceVector const& feature_types, bst_bin_t max_bin, bst_feature_t num_columns, bst_idx_t num_rows, DeviceOrd device) : num_rows_{num_rows}, num_columns_{num_columns}, num_bins_{max_bin}, device_{device} { CHECK(device.IsCUDA()); @@ -117,6 +118,7 @@ class SketchContainer { has_categorical_ = !d_feature_types.empty() && thrust::any_of(dh::tbegin(d_feature_types), dh::tend(d_feature_types), common::IsCatOp{}); + CHECK_GE(max_bin, 2) << error::InvalidMaxBin(); timer_.Init(__func__); } diff --git a/src/common/quantile.h b/src/common/quantile.h index 59bc3a4f7..e189b259b 100644 --- a/src/common/quantile.h +++ b/src/common/quantile.h @@ -802,10 +802,10 @@ class SketchContainerImpl { /* \brief Initialize necessary info. * * \param columns_size Size of each column. - * \param max_bins maximum number of bins for each feature. + * \param max_bin maximum number of bins for each feature. * \param use_group whether is assigned to group to data instance. 
*/ - SketchContainerImpl(Context const *ctx, std::vector columns_size, bst_bin_t max_bins, + SketchContainerImpl(Context const *ctx, std::vector columns_size, bst_bin_t max_bin, common::Span feature_types, bool use_group); static bool UseGroup(MetaInfo const &info) { diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index fe640ee00..9ada1ff01 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -218,7 +218,7 @@ void GBTree::DoBoost(DMatrix* p_fmat, linalg::Matrix* in_gpair, model_.learner_model_param->OutputLength()); CHECK_NE(n_groups, 0); - if (!p_fmat->SingleColBlock() && obj->Task().UpdateTreeLeaf()) { + if (!p_fmat->SingleColBlock() && obj->Task().UpdateTreeLeaf() && this->ctx_->IsCUDA()) { LOG(FATAL) << "Current objective doesn't support external memory."; } diff --git a/src/tree/common_row_partitioner.h b/src/tree/common_row_partitioner.h index 1bf185341..3e7c1123f 100644 --- a/src/tree/common_row_partitioner.h +++ b/src/tree/common_row_partitioner.h @@ -301,34 +301,37 @@ class CommonRowPartitioner { auto const& operator[](bst_node_t nidx) const { return row_set_collection_[nidx]; } void LeafPartition(Context const* ctx, RegTree const& tree, common::Span hess, - std::vector* p_out_position) const { - partition_builder_.LeafPartition(ctx, tree, this->Partitions(), p_out_position, - [&](size_t idx) -> bool { return hess[idx] - .0f == .0f; }); + common::Span out_position) const { + partition_builder_.LeafPartition( + ctx, tree, this->Partitions(), out_position, + [&](size_t idx) -> bool { return hess[idx - this->base_rowid] - .0f == .0f; }); } void LeafPartition(Context const* ctx, RegTree const& tree, linalg::TensorView gpair, - std::vector* p_out_position) const { + common::Span out_position) const { if (gpair.Shape(1) > 1) { partition_builder_.LeafPartition( - ctx, tree, this->Partitions(), p_out_position, [&](std::size_t idx) -> bool { - auto sample = gpair.Slice(idx, linalg::All()); + ctx, tree, this->Partitions(), out_position, [&](std::size_t idx) -> bool { + auto sample = gpair.Slice(idx - this->base_rowid, linalg::All()); return std::all_of(linalg::cbegin(sample), linalg::cend(sample), [](GradientPair const& g) { return g.GetHess() - .0f == .0f; }); }); } else { auto s = gpair.Slice(linalg::All(), 0); - partition_builder_.LeafPartition( - ctx, tree, this->Partitions(), p_out_position, - [&](std::size_t idx) -> bool { return s(idx).GetHess() - .0f == .0f; }); + partition_builder_.LeafPartition(ctx, tree, this->Partitions(), out_position, + [&](std::size_t idx) -> bool { + return s(idx - this->base_rowid).GetHess() - .0f == .0f; + }); } } void LeafPartition(Context const* ctx, RegTree const& tree, common::Span gpair, - std::vector* p_out_position) const { - partition_builder_.LeafPartition( - ctx, tree, this->Partitions(), p_out_position, - [&](std::size_t idx) -> bool { return gpair[idx].GetHess() - .0f == .0f; }); + common::Span out_position) const { + partition_builder_.LeafPartition(ctx, tree, this->Partitions(), out_position, + [&](std::size_t idx) -> bool { + return gpair[idx - this->base_rowid].GetHess() - .0f == .0f; + }); } private: diff --git a/src/tree/updater_approx.cc b/src/tree/updater_approx.cc index 781c2dcf4..fe5637f4a 100644 --- a/src/tree/updater_approx.cc +++ b/src/tree/updater_approx.cc @@ -154,8 +154,10 @@ class GlobalApproxBuilder { if (!task_->UpdateTreeLeaf()) { return; } + p_out_position->resize(hess.size()); for (auto const &part : partitioner_) { - part.LeafPartition(ctx_, tree, hess, p_out_position); + part.LeafPartition(ctx_, tree, hess, + 
common::Span{p_out_position->data(), p_out_position->size()}); } monitor_->Stop(__func__); } diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 0cd2671fb..724ecf87b 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -126,7 +126,7 @@ class MultiTargetHistBuilder { std::vector partitioner_; // Pointer to last updated tree, used for update prediction cache. RegTree const *p_last_tree_{nullptr}; - DMatrix const * p_last_fmat_{nullptr}; + DMatrix const *p_last_fmat_{nullptr}; ObjInfo const *task_{nullptr}; @@ -254,8 +254,10 @@ class MultiTargetHistBuilder { monitor_->Stop(__func__); return; } + p_out_position->resize(gpair.Shape(0)); for (auto const &part : partitioner_) { - part.LeafPartition(ctx_, tree, gpair, p_out_position); + part.LeafPartition(ctx_, tree, gpair, + common::Span{p_out_position->data(), p_out_position->size()}); } monitor_->Stop(__func__); } @@ -461,8 +463,10 @@ class HistUpdater { monitor_->Stop(__func__); return; } + p_out_position->resize(gpair.Shape(0)); for (auto const &part : partitioner_) { - part.LeafPartition(ctx_, tree, gpair, p_out_position); + part.LeafPartition(ctx_, tree, gpair, + common::Span{p_out_position->data(), p_out_position->size()}); } monitor_->Stop(__func__); } @@ -521,7 +525,9 @@ class QuantileHistMaker : public TreeUpdater { linalg::Matrix sample_out; auto h_sample_out = h_gpair; - auto need_copy = [&] { return trees.size() > 1 || n_targets > 1; }; + auto need_copy = [&] { + return trees.size() > 1 || n_targets > 1; + }; if (need_copy()) { // allocate buffer sample_out = decltype(sample_out){h_gpair.Shape(), ctx_->Device(), linalg::Order::kF}; diff --git a/tests/cpp/tree/test_common_partitioner.cc b/tests/cpp/tree/test_common_partitioner.cc index 116802c6a..2dcac6417 100644 --- a/tests/cpp/tree/test_common_partitioner.cc +++ b/tests/cpp/tree/test_common_partitioner.cc @@ -1,5 +1,5 @@ /** - * Copyright 2022-2023 by XGBoost contributors. + * Copyright 2022-2024, XGBoost contributors. 
*/ #include #include // for bst_node_t @@ -43,14 +43,15 @@ void TestLeafPartition(size_t n_samples) { std::vector h_nptr; float split_value{0}; + bst_feature_t const split_ind = 0; + for (auto const& page : Xy->GetBatches(&ctx, BatchParam{64, 0.2})) { - bst_feature_t const split_ind = 0; auto ptr = page.cut.Ptrs()[split_ind + 1]; split_value = page.cut.Values().at(ptr / 2); GetSplit(&tree, split_value, &candidates); partitioner.UpdatePosition(&ctx, page, candidates, &tree); - std::vector position; - partitioner.LeafPartition(&ctx, tree, hess, &position); + std::vector position(page.Size()); + partitioner.LeafPartition(&ctx, tree, hess, position); std::sort(position.begin(), position.end()); size_t beg = std::distance( position.begin(), @@ -76,13 +77,59 @@ void TestLeafPartition(size_t n_samples) { auto batch = page.GetView(); size_t left{0}; for (size_t i = 0; i < batch.Size(); ++i) { - if (not_sampled(i) && batch[i].front().fvalue < split_value) { + if (not_sampled(i) && batch[i][split_ind].fvalue < split_value) { left++; } } ASSERT_EQ(left, h_nptr[1] - h_nptr[0]); // equal to number of sampled assigned to left } } + +void TestExternalMemory() { + Context ctx; + bst_bin_t max_bin = 32; + auto p_fmat = + RandomDataGenerator{256, 16, 0.0f}.Batches(4).GenerateSparsePageDMatrix("temp", true); + std::vector partitioners; + + RegTree tree; + std::vector candidates{{0, 0}}; + + auto gpair = GenerateRandomGradients(p_fmat->Info().num_row_); + auto t_gpair = linalg::MakeTensorView(&ctx, gpair.ConstHostSpan(), p_fmat->Info().num_row_, 1); + std::vector position(p_fmat->Info().num_row_); + + auto param = BatchParam{max_bin, TrainParam::DftSparseThreshold()}; + float split_value{0.0f}; + bst_feature_t const split_ind = 0; + for (auto const& page : p_fmat->GetBatches(&ctx, param)) { + if (partitioners.empty()) { + auto ptr = page.cut.Ptrs()[split_ind + 1]; + split_value = page.cut.Values().at(ptr / 2); + GetSplit(&tree, split_value, &candidates); + } + + partitioners.emplace_back(&ctx, page.Size(), page.base_rowid, false); + partitioners.back().UpdatePosition(&ctx, page, candidates, &tree); + partitioners.back().LeafPartition(&ctx, tree, t_gpair, position); + } + + bst_idx_t n_left{0}; + for (auto const& page : p_fmat->GetBatches()) { + auto batch = page.GetView(); + for (size_t i = 0; i < batch.Size(); ++i) { + if (batch[i][split_ind].fvalue < split_value) { + n_left++; + } + } + } + auto n_left_pos = std::count_if(position.cbegin(), position.cend(), + [&](auto v) { return v == tree[RegTree::kRoot].LeftChild(); }); + ASSERT_EQ(n_left, n_left_pos); + std::sort(position.begin(), position.end()); + auto end_it = std::unique(position.begin(), position.end()); + ASSERT_EQ(std::distance(position.begin(), end_it), 2); +} } // anonymous namespace TEST(CommonRowPartitioner, LeafPartition) { @@ -90,4 +137,6 @@ TEST(CommonRowPartitioner, LeafPartition) { TestLeafPartition(n_samples); } } + +TEST(CommonRowPartitioner, LeafPartitionExternalMemory) { TestExternalMemory(); } } // namespace xgboost::tree diff --git a/tests/python-gpu/test_gpu_data_iterator.py b/tests/python-gpu/test_gpu_data_iterator.py index 3a432fe67..b42a152fe 100644 --- a/tests/python-gpu/test_gpu_data_iterator.py +++ b/tests/python-gpu/test_gpu_data_iterator.py @@ -4,6 +4,7 @@ import pytest from hypothesis import given, settings, strategies from xgboost.testing import no_cupy +from xgboost.testing.updater import check_quantile_loss_extmem sys.path.append("tests/python") from test_data_iterator import run_data_iterator @@ -56,3 +57,8 @@ def 
test_cpu_data_iterator() -> None: use_cupy=True, on_host=False, ) + + +def test_quantile_objective() -> None: + with pytest.raises(ValueError, match="external memory"): + check_quantile_loss_extmem(2, 2, 2, "hist", "cuda") diff --git a/tests/python/test_data_iterator.py b/tests/python/test_data_iterator.py index 1cc34f346..fbf05a236 100644 --- a/tests/python/test_data_iterator.py +++ b/tests/python/test_data_iterator.py @@ -12,6 +12,7 @@ import xgboost as xgb from xgboost import testing as tm from xgboost.data import SingleBatchInternalIter as SingleBatch from xgboost.testing import IteratorForTest, make_batches, non_increasing +from xgboost.testing.updater import check_quantile_loss_extmem pytestmark = tm.timeout(30) @@ -276,3 +277,28 @@ def test_cat_check() -> None: Xy = xgb.DMatrix(it, enable_categorical=True) with pytest.raises(ValueError, match="categorical features"): xgb.train({"booster": "gblinear"}, Xy) + + +@given( + strategies.integers(1, 64), + strategies.integers(1, 8), + strategies.integers(1, 4), +) +@settings(deadline=None, max_examples=10, print_blob=True) +def test_quantile_objective( + n_samples_per_batch: int, n_features: int, n_batches: int +) -> None: + check_quantile_loss_extmem( + n_samples_per_batch, + n_features, + n_batches, + "hist", + "cpu", + ) + check_quantile_loss_extmem( + n_samples_per_batch, + n_features, + n_batches, + "approx", + "cpu", + )
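
Usage note (illustrative, not part of the patch): with this change, the CPU "hist" and "approx" tree methods can train the "reg:quantileerror" objective on an external-memory DMatrix built from a data iterator, mirroring the new check_quantile_loss_extmem helper above. The sketch below is a minimal, hedged example; the Batches iterator class is a hypothetical stand-in for tm.IteratorForTest, and only xgboost.DataIter, xgboost.DMatrix, and xgboost.train are assumed here. On CUDA devices this objective still raises an error for external memory, as exercised by the added GPU test.

    import numpy as np
    import xgboost as xgb

    class Batches(xgb.DataIter):
        """Yield a few in-memory batches to emulate external memory (hypothetical helper)."""

        def __init__(self, n_batches: int, n_samples: int, n_features: int) -> None:
            rng = np.random.default_rng(0)
            self._data = [
                (rng.normal(size=(n_samples, n_features)), rng.normal(size=n_samples))
                for _ in range(n_batches)
            ]
            self._it = 0
            # cache_prefix enables the external-memory (cached) code path.
            super().__init__(cache_prefix="cache")

        def next(self, input_data) -> int:
            if self._it == len(self._data):
                return 0  # signal end of iteration
            X, y = self._data[self._it]
            input_data(data=X, label=y)
            self._it += 1
            return 1

        def reset(self) -> None:
            self._it = 0

    # Build an external-memory DMatrix from the iterator and train with the
    # quantile objective on the CPU, as now supported by this patch.
    Xy = xgb.DMatrix(Batches(n_batches=4, n_samples=256, n_features=16))
    params = {
        "tree_method": "hist",  # or "approx"
        "device": "cpu",
        "objective": "reg:quantileerror",
        "quantile_alpha": [0.2, 0.8],
    }
    booster = xgb.train(params, Xy, num_boost_round=8)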