[EM] Support CPU quantile objective for external memory. (#10751)

This commit is contained in:
Jiaming Yuan 2024-08-27 04:16:57 +08:00 committed by GitHub
parent 12c6b7ceea
commit d6ebcfb032
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 163 additions and 36 deletions

View File

@ -163,6 +163,37 @@ def check_quantile_loss(tree_method: str, weighted: bool) -> None:
np.testing.assert_allclose(predts[:, i], predt_multi[:, i])
def check_quantile_loss_extmem(
n_samples_per_batch: int,
n_features: int,
n_batches: int,
tree_method: str,
device: str,
) -> None:
"""Check external memory with the quantile objective."""
it = tm.IteratorForTest(
*tm.make_batches(n_samples_per_batch, n_features, n_batches, device != "cpu"),
cache="cache",
on_host=False,
)
Xy_it = xgb.DMatrix(it)
params = {
"tree_method": tree_method,
"objective": "reg:quantileerror",
"device": device,
"quantile_alpha": [0.2, 0.8],
}
booster_it = xgb.train(params, Xy_it)
X, y, w = it.as_arrays()
Xy = xgb.DMatrix(X, y, weight=w)
booster = xgb.train(params, Xy)
predt_it = booster_it.predict(Xy_it)
predt = booster.predict(Xy)
np.testing.assert_allclose(predt, predt_it)
def check_cut(
n_entries: int, indptr: np.ndarray, data: np.ndarray, dtypes: Any
) -> None:

View File

@ -41,6 +41,8 @@ constexpr StringView InconsistentMaxBin() {
"and consistent with the Booster being trained.";
}
constexpr StringView InvalidMaxBin() { return "`max_bin` must be equal to or greater than 2."; }
constexpr StringView UnknownDevice() { return "Unknown device type."; }
inline void MaxFeatureSize(std::uint64_t n_features) {

View File

@ -367,23 +367,21 @@ class PartitionBuilder {
// Copy row partitions into global cache for reuse in objective
template <typename Invalidp>
void LeafPartition(Context const* ctx, RegTree const& tree, RowSetCollection const& row_set,
std::vector<bst_node_t>* p_position, Invalidp invalidp) const {
auto& h_pos = *p_position;
h_pos.resize(row_set.Data()->size(), std::numeric_limits<bst_node_t>::max());
Span<bst_node_t> position, Invalidp invalidp) const {
auto p_begin = row_set.Data()->data();
// For each node, walk through all the samples that fall in this node.
ParallelFor(row_set.Size(), ctx->Threads(), [&](size_t i) {
auto p_pos = position.data();
ParallelFor(row_set.Size(), ctx->Threads(), [&](auto i) {
auto const& node = row_set[i];
if (node.node_id < 0) {
return;
}
CHECK(tree.IsLeaf(node.node_id));
if (node.begin()) { // guard for empty node.
size_t ptr_offset = node.end() - p_begin;
std::size_t ptr_offset = node.end() - p_begin;
CHECK_LE(ptr_offset, row_set.Data()->size()) << node.node_id;
for (auto idx = node.begin(); idx != node.end(); ++idx) {
h_pos[*idx] = tree::SamplePosition::Encode(node.node_id, !invalidp(*idx));
p_pos[*idx] = tree::SamplePosition::Encode(node.node_id, !invalidp(*idx));
}
}
});

View File

@ -8,6 +8,7 @@
#include <utility>
#include "../collective/aggregator.h"
#include "../common/error_msg.h" // for InvalidMaxBin
#include "../data/adapter.h"
#include "categorical.h"
#include "hist_util.h"
@ -16,15 +17,16 @@ namespace xgboost::common {
template <typename WQSketch>
SketchContainerImpl<WQSketch>::SketchContainerImpl(Context const *ctx,
std::vector<bst_idx_t> columns_size,
int32_t max_bins,
bst_bin_t max_bin,
Span<FeatureType const> feature_types,
bool use_group)
: feature_types_(feature_types.cbegin(), feature_types.cend()),
columns_size_{std::move(columns_size)},
max_bins_{max_bins},
max_bins_{max_bin},
use_group_ind_{use_group},
n_threads_{ctx->Threads()} {
monitor_.Init(__func__);
CHECK_GE(max_bin, 2) << error::InvalidMaxBin();
CHECK_NE(columns_size_.size(), 0);
sketches_.resize(columns_size_.size());
CHECK_GE(n_threads_, 1);

View File

@ -8,6 +8,7 @@
#include "categorical.h"
#include "device_helpers.cuh"
#include "error_msg.h" // for InvalidMaxBin
#include "quantile.h"
#include "timer.h"
#include "xgboost/data.h"
@ -96,7 +97,7 @@ class SketchContainer {
* \param num_rows Total number of rows in known dataset (typically the rows in current worker).
* \param device GPU ID.
*/
SketchContainer(HostDeviceVector<FeatureType> const& feature_types, int32_t max_bin,
SketchContainer(HostDeviceVector<FeatureType> const& feature_types, bst_bin_t max_bin,
bst_feature_t num_columns, bst_idx_t num_rows, DeviceOrd device)
: num_rows_{num_rows}, num_columns_{num_columns}, num_bins_{max_bin}, device_{device} {
CHECK(device.IsCUDA());
@ -117,6 +118,7 @@ class SketchContainer {
has_categorical_ =
!d_feature_types.empty() &&
thrust::any_of(dh::tbegin(d_feature_types), dh::tend(d_feature_types), common::IsCatOp{});
CHECK_GE(max_bin, 2) << error::InvalidMaxBin();
timer_.Init(__func__);
}

View File

@ -802,10 +802,10 @@ class SketchContainerImpl {
/* \brief Initialize necessary info.
*
* \param columns_size Size of each column.
* \param max_bins maximum number of bins for each feature.
* \param max_bin maximum number of bins for each feature.
* \param use_group whether is assigned to group to data instance.
*/
SketchContainerImpl(Context const *ctx, std::vector<bst_idx_t> columns_size, bst_bin_t max_bins,
SketchContainerImpl(Context const *ctx, std::vector<bst_idx_t> columns_size, bst_bin_t max_bin,
common::Span<FeatureType const> feature_types, bool use_group);
static bool UseGroup(MetaInfo const &info) {

View File

@ -218,7 +218,7 @@ void GBTree::DoBoost(DMatrix* p_fmat, linalg::Matrix<GradientPair>* in_gpair,
model_.learner_model_param->OutputLength());
CHECK_NE(n_groups, 0);
if (!p_fmat->SingleColBlock() && obj->Task().UpdateTreeLeaf()) {
if (!p_fmat->SingleColBlock() && obj->Task().UpdateTreeLeaf() && this->ctx_->IsCUDA()) {
LOG(FATAL) << "Current objective doesn't support external memory.";
}

View File

@ -301,34 +301,37 @@ class CommonRowPartitioner {
auto const& operator[](bst_node_t nidx) const { return row_set_collection_[nidx]; }
void LeafPartition(Context const* ctx, RegTree const& tree, common::Span<float const> hess,
std::vector<bst_node_t>* p_out_position) const {
partition_builder_.LeafPartition(ctx, tree, this->Partitions(), p_out_position,
[&](size_t idx) -> bool { return hess[idx] - .0f == .0f; });
common::Span<bst_node_t> out_position) const {
partition_builder_.LeafPartition(
ctx, tree, this->Partitions(), out_position,
[&](size_t idx) -> bool { return hess[idx - this->base_rowid] - .0f == .0f; });
}
void LeafPartition(Context const* ctx, RegTree const& tree,
linalg::TensorView<GradientPair const, 2> gpair,
std::vector<bst_node_t>* p_out_position) const {
common::Span<bst_node_t> out_position) const {
if (gpair.Shape(1) > 1) {
partition_builder_.LeafPartition(
ctx, tree, this->Partitions(), p_out_position, [&](std::size_t idx) -> bool {
auto sample = gpair.Slice(idx, linalg::All());
ctx, tree, this->Partitions(), out_position, [&](std::size_t idx) -> bool {
auto sample = gpair.Slice(idx - this->base_rowid, linalg::All());
return std::all_of(linalg::cbegin(sample), linalg::cend(sample),
[](GradientPair const& g) { return g.GetHess() - .0f == .0f; });
});
} else {
auto s = gpair.Slice(linalg::All(), 0);
partition_builder_.LeafPartition(
ctx, tree, this->Partitions(), p_out_position,
[&](std::size_t idx) -> bool { return s(idx).GetHess() - .0f == .0f; });
partition_builder_.LeafPartition(ctx, tree, this->Partitions(), out_position,
[&](std::size_t idx) -> bool {
return s(idx - this->base_rowid).GetHess() - .0f == .0f;
});
}
}
void LeafPartition(Context const* ctx, RegTree const& tree,
common::Span<GradientPair const> gpair,
std::vector<bst_node_t>* p_out_position) const {
partition_builder_.LeafPartition(
ctx, tree, this->Partitions(), p_out_position,
[&](std::size_t idx) -> bool { return gpair[idx].GetHess() - .0f == .0f; });
common::Span<bst_node_t> out_position) const {
partition_builder_.LeafPartition(ctx, tree, this->Partitions(), out_position,
[&](std::size_t idx) -> bool {
return gpair[idx - this->base_rowid].GetHess() - .0f == .0f;
});
}
private:

View File

@ -154,8 +154,10 @@ class GlobalApproxBuilder {
if (!task_->UpdateTreeLeaf()) {
return;
}
p_out_position->resize(hess.size());
for (auto const &part : partitioner_) {
part.LeafPartition(ctx_, tree, hess, p_out_position);
part.LeafPartition(ctx_, tree, hess,
common::Span{p_out_position->data(), p_out_position->size()});
}
monitor_->Stop(__func__);
}

View File

@ -254,8 +254,10 @@ class MultiTargetHistBuilder {
monitor_->Stop(__func__);
return;
}
p_out_position->resize(gpair.Shape(0));
for (auto const &part : partitioner_) {
part.LeafPartition(ctx_, tree, gpair, p_out_position);
part.LeafPartition(ctx_, tree, gpair,
common::Span{p_out_position->data(), p_out_position->size()});
}
monitor_->Stop(__func__);
}
@ -461,8 +463,10 @@ class HistUpdater {
monitor_->Stop(__func__);
return;
}
p_out_position->resize(gpair.Shape(0));
for (auto const &part : partitioner_) {
part.LeafPartition(ctx_, tree, gpair, p_out_position);
part.LeafPartition(ctx_, tree, gpair,
common::Span{p_out_position->data(), p_out_position->size()});
}
monitor_->Stop(__func__);
}
@ -521,7 +525,9 @@ class QuantileHistMaker : public TreeUpdater {
linalg::Matrix<GradientPair> sample_out;
auto h_sample_out = h_gpair;
auto need_copy = [&] { return trees.size() > 1 || n_targets > 1; };
auto need_copy = [&] {
return trees.size() > 1 || n_targets > 1;
};
if (need_copy()) {
// allocate buffer
sample_out = decltype(sample_out){h_gpair.Shape(), ctx_->Device(), linalg::Order::kF};

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022-2023 by XGBoost contributors.
* Copyright 2022-2024, XGBoost contributors.
*/
#include <gtest/gtest.h>
#include <xgboost/base.h> // for bst_node_t
@ -43,14 +43,15 @@ void TestLeafPartition(size_t n_samples) {
std::vector<size_t> h_nptr;
float split_value{0};
for (auto const& page : Xy->GetBatches<GHistIndexMatrix>(&ctx, BatchParam{64, 0.2})) {
bst_feature_t const split_ind = 0;
for (auto const& page : Xy->GetBatches<GHistIndexMatrix>(&ctx, BatchParam{64, 0.2})) {
auto ptr = page.cut.Ptrs()[split_ind + 1];
split_value = page.cut.Values().at(ptr / 2);
GetSplit(&tree, split_value, &candidates);
partitioner.UpdatePosition(&ctx, page, candidates, &tree);
std::vector<bst_node_t> position;
partitioner.LeafPartition(&ctx, tree, hess, &position);
std::vector<bst_node_t> position(page.Size());
partitioner.LeafPartition(&ctx, tree, hess, position);
std::sort(position.begin(), position.end());
size_t beg = std::distance(
position.begin(),
@ -76,13 +77,59 @@ void TestLeafPartition(size_t n_samples) {
auto batch = page.GetView();
size_t left{0};
for (size_t i = 0; i < batch.Size(); ++i) {
if (not_sampled(i) && batch[i].front().fvalue < split_value) {
if (not_sampled(i) && batch[i][split_ind].fvalue < split_value) {
left++;
}
}
ASSERT_EQ(left, h_nptr[1] - h_nptr[0]); // equal to number of sampled assigned to left
}
}
void TestExternalMemory() {
Context ctx;
bst_bin_t max_bin = 32;
auto p_fmat =
RandomDataGenerator{256, 16, 0.0f}.Batches(4).GenerateSparsePageDMatrix("temp", true);
std::vector<CommonRowPartitioner> partitioners;
RegTree tree;
std::vector<CPUExpandEntry> candidates{{0, 0}};
auto gpair = GenerateRandomGradients(p_fmat->Info().num_row_);
auto t_gpair = linalg::MakeTensorView(&ctx, gpair.ConstHostSpan(), p_fmat->Info().num_row_, 1);
std::vector<bst_node_t> position(p_fmat->Info().num_row_);
auto param = BatchParam{max_bin, TrainParam::DftSparseThreshold()};
float split_value{0.0f};
bst_feature_t const split_ind = 0;
for (auto const& page : p_fmat->GetBatches<GHistIndexMatrix>(&ctx, param)) {
if (partitioners.empty()) {
auto ptr = page.cut.Ptrs()[split_ind + 1];
split_value = page.cut.Values().at(ptr / 2);
GetSplit(&tree, split_value, &candidates);
}
partitioners.emplace_back(&ctx, page.Size(), page.base_rowid, false);
partitioners.back().UpdatePosition(&ctx, page, candidates, &tree);
partitioners.back().LeafPartition(&ctx, tree, t_gpair, position);
}
bst_idx_t n_left{0};
for (auto const& page : p_fmat->GetBatches<SparsePage>()) {
auto batch = page.GetView();
for (size_t i = 0; i < batch.Size(); ++i) {
if (batch[i][split_ind].fvalue < split_value) {
n_left++;
}
}
}
auto n_left_pos = std::count_if(position.cbegin(), position.cend(),
[&](auto v) { return v == tree[RegTree::kRoot].LeftChild(); });
ASSERT_EQ(n_left, n_left_pos);
std::sort(position.begin(), position.end());
auto end_it = std::unique(position.begin(), position.end());
ASSERT_EQ(std::distance(position.begin(), end_it), 2);
}
} // anonymous namespace
TEST(CommonRowPartitioner, LeafPartition) {
@ -90,4 +137,6 @@ TEST(CommonRowPartitioner, LeafPartition) {
TestLeafPartition(n_samples);
}
}
TEST(CommonRowPartitioner, LeafPartitionExternalMemory) { TestExternalMemory(); }
} // namespace xgboost::tree

View File

@ -4,6 +4,7 @@ import pytest
from hypothesis import given, settings, strategies
from xgboost.testing import no_cupy
from xgboost.testing.updater import check_quantile_loss_extmem
sys.path.append("tests/python")
from test_data_iterator import run_data_iterator
@ -56,3 +57,8 @@ def test_cpu_data_iterator() -> None:
use_cupy=True,
on_host=False,
)
def test_quantile_objective() -> None:
with pytest.raises(ValueError, match="external memory"):
check_quantile_loss_extmem(2, 2, 2, "hist", "cuda")

View File

@ -12,6 +12,7 @@ import xgboost as xgb
from xgboost import testing as tm
from xgboost.data import SingleBatchInternalIter as SingleBatch
from xgboost.testing import IteratorForTest, make_batches, non_increasing
from xgboost.testing.updater import check_quantile_loss_extmem
pytestmark = tm.timeout(30)
@ -276,3 +277,28 @@ def test_cat_check() -> None:
Xy = xgb.DMatrix(it, enable_categorical=True)
with pytest.raises(ValueError, match="categorical features"):
xgb.train({"booster": "gblinear"}, Xy)
@given(
strategies.integers(1, 64),
strategies.integers(1, 8),
strategies.integers(1, 4),
)
@settings(deadline=None, max_examples=10, print_blob=True)
def test_quantile_objective(
n_samples_per_batch: int, n_features: int, n_batches: int
) -> None:
check_quantile_loss_extmem(
n_samples_per_batch,
n_features,
n_batches,
"hist",
"cpu",
)
check_quantile_loss_extmem(
n_samples_per_batch,
n_features,
n_batches,
"approx",
"cpu",
)