[EM] Support CPU quantile objective for external memory. (#10751)

This commit is contained in:
Jiaming Yuan 2024-08-27 04:16:57 +08:00 committed by GitHub
parent 12c6b7ceea
commit d6ebcfb032
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 163 additions and 36 deletions

View File

@ -163,6 +163,37 @@ def check_quantile_loss(tree_method: str, weighted: bool) -> None:
np.testing.assert_allclose(predts[:, i], predt_multi[:, i]) np.testing.assert_allclose(predts[:, i], predt_multi[:, i])
def check_quantile_loss_extmem(
    n_samples_per_batch: int,
    n_features: int,
    n_batches: int,
    tree_method: str,
    device: str,
) -> None:
    """Check external memory with the quantile objective.

    One booster is trained on a ``DMatrix`` built from a batch iterator, a
    second one on a ``DMatrix`` built from the arrays returned by the same
    iterator.  The predictions of both boosters must be close to each other.

    Parameters
    ----------
    n_samples_per_batch, n_features, n_batches :
        Forwarded to ``tm.make_batches`` to generate the input batches.
    tree_method :
        Value of the ``tree_method`` training parameter.
    device :
        Value of the ``device`` training parameter.  Any value other than
        ``"cpu"`` passes ``True`` as the fourth argument of ``tm.make_batches``.
    """
    not_cpu = device != "cpu"
    batches = tm.make_batches(n_samples_per_batch, n_features, n_batches, not_cpu)
    it = tm.IteratorForTest(*batches, cache="cache", on_host=False)

    params = {
        "tree_method": tree_method,
        "objective": "reg:quantileerror",
        "device": device,
        "quantile_alpha": [0.2, 0.8],
    }

    # Train from the iterator first, mirroring the order data is consumed in.
    dm_iter = xgb.DMatrix(it)
    bst_iter = xgb.train(params, dm_iter)

    # Reference model trained on the arrays exposed by the iterator.
    data, label, weight = it.as_arrays()
    dm_arrays = xgb.DMatrix(data, label, weight=weight)
    bst_arrays = xgb.train(params, dm_arrays)

    from_iter = bst_iter.predict(dm_iter)
    from_arrays = bst_arrays.predict(dm_arrays)

    np.testing.assert_allclose(from_arrays, from_iter)
def check_cut( def check_cut(
n_entries: int, indptr: np.ndarray, data: np.ndarray, dtypes: Any n_entries: int, indptr: np.ndarray, data: np.ndarray, dtypes: Any
) -> None: ) -> None:

View File

@ -41,6 +41,8 @@ constexpr StringView InconsistentMaxBin() {
"and consistent with the Booster being trained."; "and consistent with the Booster being trained.";
} }
// Error text stating that `max_bin` must be at least 2.
constexpr StringView InvalidMaxBin() {
  return "`max_bin` must be equal to or greater than 2.";
}
constexpr StringView UnknownDevice() { return "Unknown device type."; } constexpr StringView UnknownDevice() { return "Unknown device type."; }
inline void MaxFeatureSize(std::uint64_t n_features) { inline void MaxFeatureSize(std::uint64_t n_features) {

View File

@ -367,23 +367,21 @@ class PartitionBuilder {
// Copy row partitions into global cache for reuse in objective // Copy row partitions into global cache for reuse in objective
template <typename Invalidp> template <typename Invalidp>
void LeafPartition(Context const* ctx, RegTree const& tree, RowSetCollection const& row_set, void LeafPartition(Context const* ctx, RegTree const& tree, RowSetCollection const& row_set,
std::vector<bst_node_t>* p_position, Invalidp invalidp) const { Span<bst_node_t> position, Invalidp invalidp) const {
auto& h_pos = *p_position;
h_pos.resize(row_set.Data()->size(), std::numeric_limits<bst_node_t>::max());
auto p_begin = row_set.Data()->data(); auto p_begin = row_set.Data()->data();
// For each node, walk through all the samples that fall in this node. // For each node, walk through all the samples that fall in this node.
ParallelFor(row_set.Size(), ctx->Threads(), [&](size_t i) { auto p_pos = position.data();
ParallelFor(row_set.Size(), ctx->Threads(), [&](auto i) {
auto const& node = row_set[i]; auto const& node = row_set[i];
if (node.node_id < 0) { if (node.node_id < 0) {
return; return;
} }
CHECK(tree.IsLeaf(node.node_id)); CHECK(tree.IsLeaf(node.node_id));
if (node.begin()) { // guard for empty node. if (node.begin()) { // guard for empty node.
size_t ptr_offset = node.end() - p_begin; std::size_t ptr_offset = node.end() - p_begin;
CHECK_LE(ptr_offset, row_set.Data()->size()) << node.node_id; CHECK_LE(ptr_offset, row_set.Data()->size()) << node.node_id;
for (auto idx = node.begin(); idx != node.end(); ++idx) { for (auto idx = node.begin(); idx != node.end(); ++idx) {
h_pos[*idx] = tree::SamplePosition::Encode(node.node_id, !invalidp(*idx)); p_pos[*idx] = tree::SamplePosition::Encode(node.node_id, !invalidp(*idx));
} }
} }
}); });

View File

@ -8,6 +8,7 @@
#include <utility> #include <utility>
#include "../collective/aggregator.h" #include "../collective/aggregator.h"
#include "../common/error_msg.h" // for InvalidMaxBin
#include "../data/adapter.h" #include "../data/adapter.h"
#include "categorical.h" #include "categorical.h"
#include "hist_util.h" #include "hist_util.h"
@ -16,15 +17,16 @@ namespace xgboost::common {
template <typename WQSketch> template <typename WQSketch>
SketchContainerImpl<WQSketch>::SketchContainerImpl(Context const *ctx, SketchContainerImpl<WQSketch>::SketchContainerImpl(Context const *ctx,
std::vector<bst_idx_t> columns_size, std::vector<bst_idx_t> columns_size,
int32_t max_bins, bst_bin_t max_bin,
Span<FeatureType const> feature_types, Span<FeatureType const> feature_types,
bool use_group) bool use_group)
: feature_types_(feature_types.cbegin(), feature_types.cend()), : feature_types_(feature_types.cbegin(), feature_types.cend()),
columns_size_{std::move(columns_size)}, columns_size_{std::move(columns_size)},
max_bins_{max_bins}, max_bins_{max_bin},
use_group_ind_{use_group}, use_group_ind_{use_group},
n_threads_{ctx->Threads()} { n_threads_{ctx->Threads()} {
monitor_.Init(__func__); monitor_.Init(__func__);
CHECK_GE(max_bin, 2) << error::InvalidMaxBin();
CHECK_NE(columns_size_.size(), 0); CHECK_NE(columns_size_.size(), 0);
sketches_.resize(columns_size_.size()); sketches_.resize(columns_size_.size());
CHECK_GE(n_threads_, 1); CHECK_GE(n_threads_, 1);

View File

@ -8,6 +8,7 @@
#include "categorical.h" #include "categorical.h"
#include "device_helpers.cuh" #include "device_helpers.cuh"
#include "error_msg.h" // for InvalidMaxBin
#include "quantile.h" #include "quantile.h"
#include "timer.h" #include "timer.h"
#include "xgboost/data.h" #include "xgboost/data.h"
@ -96,7 +97,7 @@ class SketchContainer {
* \param num_rows Total number of rows in known dataset (typically the rows in current worker). * \param num_rows Total number of rows in known dataset (typically the rows in current worker).
* \param device GPU ID. * \param device GPU ID.
*/ */
SketchContainer(HostDeviceVector<FeatureType> const& feature_types, int32_t max_bin, SketchContainer(HostDeviceVector<FeatureType> const& feature_types, bst_bin_t max_bin,
bst_feature_t num_columns, bst_idx_t num_rows, DeviceOrd device) bst_feature_t num_columns, bst_idx_t num_rows, DeviceOrd device)
: num_rows_{num_rows}, num_columns_{num_columns}, num_bins_{max_bin}, device_{device} { : num_rows_{num_rows}, num_columns_{num_columns}, num_bins_{max_bin}, device_{device} {
CHECK(device.IsCUDA()); CHECK(device.IsCUDA());
@ -117,6 +118,7 @@ class SketchContainer {
has_categorical_ = has_categorical_ =
!d_feature_types.empty() && !d_feature_types.empty() &&
thrust::any_of(dh::tbegin(d_feature_types), dh::tend(d_feature_types), common::IsCatOp{}); thrust::any_of(dh::tbegin(d_feature_types), dh::tend(d_feature_types), common::IsCatOp{});
CHECK_GE(max_bin, 2) << error::InvalidMaxBin();
timer_.Init(__func__); timer_.Init(__func__);
} }

View File

@ -802,10 +802,10 @@ class SketchContainerImpl {
/* \brief Initialize necessary info. /* \brief Initialize necessary info.
* *
* \param columns_size Size of each column. * \param columns_size Size of each column.
* \param max_bins maximum number of bins for each feature. * \param max_bin maximum number of bins for each feature.
* \param use_group whether is assigned to group to data instance. * \param use_group whether is assigned to group to data instance.
*/ */
SketchContainerImpl(Context const *ctx, std::vector<bst_idx_t> columns_size, bst_bin_t max_bins, SketchContainerImpl(Context const *ctx, std::vector<bst_idx_t> columns_size, bst_bin_t max_bin,
common::Span<FeatureType const> feature_types, bool use_group); common::Span<FeatureType const> feature_types, bool use_group);
static bool UseGroup(MetaInfo const &info) { static bool UseGroup(MetaInfo const &info) {

View File

@ -218,7 +218,7 @@ void GBTree::DoBoost(DMatrix* p_fmat, linalg::Matrix<GradientPair>* in_gpair,
model_.learner_model_param->OutputLength()); model_.learner_model_param->OutputLength());
CHECK_NE(n_groups, 0); CHECK_NE(n_groups, 0);
if (!p_fmat->SingleColBlock() && obj->Task().UpdateTreeLeaf()) { if (!p_fmat->SingleColBlock() && obj->Task().UpdateTreeLeaf() && this->ctx_->IsCUDA()) {
LOG(FATAL) << "Current objective doesn't support external memory."; LOG(FATAL) << "Current objective doesn't support external memory.";
} }

View File

@ -301,34 +301,37 @@ class CommonRowPartitioner {
auto const& operator[](bst_node_t nidx) const { return row_set_collection_[nidx]; } auto const& operator[](bst_node_t nidx) const { return row_set_collection_[nidx]; }
void LeafPartition(Context const* ctx, RegTree const& tree, common::Span<float const> hess, void LeafPartition(Context const* ctx, RegTree const& tree, common::Span<float const> hess,
std::vector<bst_node_t>* p_out_position) const { common::Span<bst_node_t> out_position) const {
partition_builder_.LeafPartition(ctx, tree, this->Partitions(), p_out_position, partition_builder_.LeafPartition(
[&](size_t idx) -> bool { return hess[idx] - .0f == .0f; }); ctx, tree, this->Partitions(), out_position,
[&](size_t idx) -> bool { return hess[idx - this->base_rowid] - .0f == .0f; });
} }
void LeafPartition(Context const* ctx, RegTree const& tree, void LeafPartition(Context const* ctx, RegTree const& tree,
linalg::TensorView<GradientPair const, 2> gpair, linalg::TensorView<GradientPair const, 2> gpair,
std::vector<bst_node_t>* p_out_position) const { common::Span<bst_node_t> out_position) const {
if (gpair.Shape(1) > 1) { if (gpair.Shape(1) > 1) {
partition_builder_.LeafPartition( partition_builder_.LeafPartition(
ctx, tree, this->Partitions(), p_out_position, [&](std::size_t idx) -> bool { ctx, tree, this->Partitions(), out_position, [&](std::size_t idx) -> bool {
auto sample = gpair.Slice(idx, linalg::All()); auto sample = gpair.Slice(idx - this->base_rowid, linalg::All());
return std::all_of(linalg::cbegin(sample), linalg::cend(sample), return std::all_of(linalg::cbegin(sample), linalg::cend(sample),
[](GradientPair const& g) { return g.GetHess() - .0f == .0f; }); [](GradientPair const& g) { return g.GetHess() - .0f == .0f; });
}); });
} else { } else {
auto s = gpair.Slice(linalg::All(), 0); auto s = gpair.Slice(linalg::All(), 0);
partition_builder_.LeafPartition( partition_builder_.LeafPartition(ctx, tree, this->Partitions(), out_position,
ctx, tree, this->Partitions(), p_out_position, [&](std::size_t idx) -> bool {
[&](std::size_t idx) -> bool { return s(idx).GetHess() - .0f == .0f; }); return s(idx - this->base_rowid).GetHess() - .0f == .0f;
});
} }
} }
void LeafPartition(Context const* ctx, RegTree const& tree, void LeafPartition(Context const* ctx, RegTree const& tree,
common::Span<GradientPair const> gpair, common::Span<GradientPair const> gpair,
std::vector<bst_node_t>* p_out_position) const { common::Span<bst_node_t> out_position) const {
partition_builder_.LeafPartition( partition_builder_.LeafPartition(ctx, tree, this->Partitions(), out_position,
ctx, tree, this->Partitions(), p_out_position, [&](std::size_t idx) -> bool {
[&](std::size_t idx) -> bool { return gpair[idx].GetHess() - .0f == .0f; }); return gpair[idx - this->base_rowid].GetHess() - .0f == .0f;
});
} }
private: private:

View File

@ -154,8 +154,10 @@ class GlobalApproxBuilder {
if (!task_->UpdateTreeLeaf()) { if (!task_->UpdateTreeLeaf()) {
return; return;
} }
p_out_position->resize(hess.size());
for (auto const &part : partitioner_) { for (auto const &part : partitioner_) {
part.LeafPartition(ctx_, tree, hess, p_out_position); part.LeafPartition(ctx_, tree, hess,
common::Span{p_out_position->data(), p_out_position->size()});
} }
monitor_->Stop(__func__); monitor_->Stop(__func__);
} }

View File

@ -126,7 +126,7 @@ class MultiTargetHistBuilder {
std::vector<CommonRowPartitioner> partitioner_; std::vector<CommonRowPartitioner> partitioner_;
// Pointer to last updated tree, used for update prediction cache. // Pointer to last updated tree, used for update prediction cache.
RegTree const *p_last_tree_{nullptr}; RegTree const *p_last_tree_{nullptr};
DMatrix const * p_last_fmat_{nullptr}; DMatrix const *p_last_fmat_{nullptr};
ObjInfo const *task_{nullptr}; ObjInfo const *task_{nullptr};
@ -254,8 +254,10 @@ class MultiTargetHistBuilder {
monitor_->Stop(__func__); monitor_->Stop(__func__);
return; return;
} }
p_out_position->resize(gpair.Shape(0));
for (auto const &part : partitioner_) { for (auto const &part : partitioner_) {
part.LeafPartition(ctx_, tree, gpair, p_out_position); part.LeafPartition(ctx_, tree, gpair,
common::Span{p_out_position->data(), p_out_position->size()});
} }
monitor_->Stop(__func__); monitor_->Stop(__func__);
} }
@ -461,8 +463,10 @@ class HistUpdater {
monitor_->Stop(__func__); monitor_->Stop(__func__);
return; return;
} }
p_out_position->resize(gpair.Shape(0));
for (auto const &part : partitioner_) { for (auto const &part : partitioner_) {
part.LeafPartition(ctx_, tree, gpair, p_out_position); part.LeafPartition(ctx_, tree, gpair,
common::Span{p_out_position->data(), p_out_position->size()});
} }
monitor_->Stop(__func__); monitor_->Stop(__func__);
} }
@ -521,7 +525,9 @@ class QuantileHistMaker : public TreeUpdater {
linalg::Matrix<GradientPair> sample_out; linalg::Matrix<GradientPair> sample_out;
auto h_sample_out = h_gpair; auto h_sample_out = h_gpair;
auto need_copy = [&] { return trees.size() > 1 || n_targets > 1; }; auto need_copy = [&] {
return trees.size() > 1 || n_targets > 1;
};
if (need_copy()) { if (need_copy()) {
// allocate buffer // allocate buffer
sample_out = decltype(sample_out){h_gpair.Shape(), ctx_->Device(), linalg::Order::kF}; sample_out = decltype(sample_out){h_gpair.Shape(), ctx_->Device(), linalg::Order::kF};

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2022-2023 by XGBoost contributors. * Copyright 2022-2024, XGBoost contributors.
*/ */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <xgboost/base.h> // for bst_node_t #include <xgboost/base.h> // for bst_node_t
@ -43,14 +43,15 @@ void TestLeafPartition(size_t n_samples) {
std::vector<size_t> h_nptr; std::vector<size_t> h_nptr;
float split_value{0}; float split_value{0};
bst_feature_t const split_ind = 0;
for (auto const& page : Xy->GetBatches<GHistIndexMatrix>(&ctx, BatchParam{64, 0.2})) { for (auto const& page : Xy->GetBatches<GHistIndexMatrix>(&ctx, BatchParam{64, 0.2})) {
bst_feature_t const split_ind = 0;
auto ptr = page.cut.Ptrs()[split_ind + 1]; auto ptr = page.cut.Ptrs()[split_ind + 1];
split_value = page.cut.Values().at(ptr / 2); split_value = page.cut.Values().at(ptr / 2);
GetSplit(&tree, split_value, &candidates); GetSplit(&tree, split_value, &candidates);
partitioner.UpdatePosition(&ctx, page, candidates, &tree); partitioner.UpdatePosition(&ctx, page, candidates, &tree);
std::vector<bst_node_t> position; std::vector<bst_node_t> position(page.Size());
partitioner.LeafPartition(&ctx, tree, hess, &position); partitioner.LeafPartition(&ctx, tree, hess, position);
std::sort(position.begin(), position.end()); std::sort(position.begin(), position.end());
size_t beg = std::distance( size_t beg = std::distance(
position.begin(), position.begin(),
@ -76,13 +77,59 @@ void TestLeafPartition(size_t n_samples) {
auto batch = page.GetView(); auto batch = page.GetView();
size_t left{0}; size_t left{0};
for (size_t i = 0; i < batch.Size(); ++i) { for (size_t i = 0; i < batch.Size(); ++i) {
if (not_sampled(i) && batch[i].front().fvalue < split_value) { if (not_sampled(i) && batch[i][split_ind].fvalue < split_value) {
left++; left++;
} }
} }
ASSERT_EQ(left, h_nptr[1] - h_nptr[0]); // equal to number of sampled assigned to left ASSERT_EQ(left, h_nptr[1] - h_nptr[0]); // equal to number of sampled assigned to left
} }
} }
void TestExternalMemory() {
Context ctx;
bst_bin_t max_bin = 32;
auto p_fmat =
RandomDataGenerator{256, 16, 0.0f}.Batches(4).GenerateSparsePageDMatrix("temp", true);
std::vector<CommonRowPartitioner> partitioners;
RegTree tree;
std::vector<CPUExpandEntry> candidates{{0, 0}};
auto gpair = GenerateRandomGradients(p_fmat->Info().num_row_);
auto t_gpair = linalg::MakeTensorView(&ctx, gpair.ConstHostSpan(), p_fmat->Info().num_row_, 1);
std::vector<bst_node_t> position(p_fmat->Info().num_row_);
auto param = BatchParam{max_bin, TrainParam::DftSparseThreshold()};
float split_value{0.0f};
bst_feature_t const split_ind = 0;
for (auto const& page : p_fmat->GetBatches<GHistIndexMatrix>(&ctx, param)) {
if (partitioners.empty()) {
auto ptr = page.cut.Ptrs()[split_ind + 1];
split_value = page.cut.Values().at(ptr / 2);
GetSplit(&tree, split_value, &candidates);
}
partitioners.emplace_back(&ctx, page.Size(), page.base_rowid, false);
partitioners.back().UpdatePosition(&ctx, page, candidates, &tree);
partitioners.back().LeafPartition(&ctx, tree, t_gpair, position);
}
bst_idx_t n_left{0};
for (auto const& page : p_fmat->GetBatches<SparsePage>()) {
auto batch = page.GetView();
for (size_t i = 0; i < batch.Size(); ++i) {
if (batch[i][split_ind].fvalue < split_value) {
n_left++;
}
}
}
auto n_left_pos = std::count_if(position.cbegin(), position.cend(),
[&](auto v) { return v == tree[RegTree::kRoot].LeftChild(); });
ASSERT_EQ(n_left, n_left_pos);
std::sort(position.begin(), position.end());
auto end_it = std::unique(position.begin(), position.end());
ASSERT_EQ(std::distance(position.begin(), end_it), 2);
}
} // anonymous namespace } // anonymous namespace
TEST(CommonRowPartitioner, LeafPartition) { TEST(CommonRowPartitioner, LeafPartition) {
@ -90,4 +137,6 @@ TEST(CommonRowPartitioner, LeafPartition) {
TestLeafPartition(n_samples); TestLeafPartition(n_samples);
} }
} }
// Runs the multi-page leaf partition check implemented by TestExternalMemory.
TEST(CommonRowPartitioner, LeafPartitionExternalMemory) {
  TestExternalMemory();
}
} // namespace xgboost::tree } // namespace xgboost::tree

View File

@ -4,6 +4,7 @@ import pytest
from hypothesis import given, settings, strategies from hypothesis import given, settings, strategies
from xgboost.testing import no_cupy from xgboost.testing import no_cupy
from xgboost.testing.updater import check_quantile_loss_extmem
sys.path.append("tests/python") sys.path.append("tests/python")
from test_data_iterator import run_data_iterator from test_data_iterator import run_data_iterator
@ -56,3 +57,8 @@ def test_cpu_data_iterator() -> None:
use_cupy=True, use_cupy=True,
on_host=False, on_host=False,
) )
def test_quantile_objective() -> None:
    """The quantile check on ``cuda`` must raise an "external memory" ``ValueError``."""
    expected_failure = pytest.raises(ValueError, match="external memory")
    with expected_failure:
        check_quantile_loss_extmem(
            n_samples_per_batch=2,
            n_features=2,
            n_batches=2,
            tree_method="hist",
            device="cuda",
        )

View File

@ -12,6 +12,7 @@ import xgboost as xgb
from xgboost import testing as tm from xgboost import testing as tm
from xgboost.data import SingleBatchInternalIter as SingleBatch from xgboost.data import SingleBatchInternalIter as SingleBatch
from xgboost.testing import IteratorForTest, make_batches, non_increasing from xgboost.testing import IteratorForTest, make_batches, non_increasing
from xgboost.testing.updater import check_quantile_loss_extmem
pytestmark = tm.timeout(30) pytestmark = tm.timeout(30)
@ -276,3 +277,28 @@ def test_cat_check() -> None:
Xy = xgb.DMatrix(it, enable_categorical=True) Xy = xgb.DMatrix(it, enable_categorical=True)
with pytest.raises(ValueError, match="categorical features"): with pytest.raises(ValueError, match="categorical features"):
xgb.train({"booster": "gblinear"}, Xy) xgb.train({"booster": "gblinear"}, Xy)
@given(
    strategies.integers(1, 64),
    strategies.integers(1, 8),
    strategies.integers(1, 4),
)
@settings(deadline=None, max_examples=10, print_blob=True)
def test_quantile_objective(
    n_samples_per_batch: int, n_features: int, n_batches: int
) -> None:
    """Run the quantile objective check on CPU for ``hist``, then ``approx``."""
    for tree_method in ("hist", "approx"):
        check_quantile_loss_extmem(
            n_samples_per_batch, n_features, n_batches, tree_method, "cpu"
        )