[EM] Support CPU quantile objective for external memory. (#10751)
parent 12c6b7ceea
commit d6ebcfb032
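
For context, the user-facing behavior this commit enables is roughly: build an external-memory DMatrix from a `DataIter` and train the `reg:quantileerror` objective on CPU. The sketch below is illustrative only; the `SyntheticIter` class and its parameters are made up for the example (in-memory batches standing in for on-disk data) and are not part of the commit.

    import numpy as np
    import xgboost as xgb


    class SyntheticIter(xgb.DataIter):
        """Illustrative iterator: synthetic in-memory batches stand in for on-disk data."""

        def __init__(self, n_batches: int = 4) -> None:
            rng = np.random.default_rng(0)
            self._batches = [
                (rng.normal(size=(256, 8)), rng.normal(size=256)) for _ in range(n_batches)
            ]
            self._it = 0
            super().__init__(cache_prefix="cache")

        def next(self, input_data) -> bool:
            if self._it == len(self._batches):
                return False  # signal the end of iteration
            X, y = self._batches[self._it]
            input_data(data=X, label=y)
            self._it += 1
            return True

        def reset(self) -> None:
            self._it = 0


    # External-memory DMatrix built from the iterator, trained with the quantile
    # objective on CPU -- the combination this commit enables.
    Xy = xgb.DMatrix(SyntheticIter())
    booster = xgb.train(
        {
            "tree_method": "hist",
            "device": "cpu",
            "objective": "reg:quantileerror",
            "quantile_alpha": [0.2, 0.8],
        },
        Xy,
    )
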
@@ -163,6 +163,37 @@ def check_quantile_loss(tree_method: str, weighted: bool) -> None:
         np.testing.assert_allclose(predts[:, i], predt_multi[:, i])


+def check_quantile_loss_extmem(
+    n_samples_per_batch: int,
+    n_features: int,
+    n_batches: int,
+    tree_method: str,
+    device: str,
+) -> None:
+    """Check external memory with the quantile objective."""
+    it = tm.IteratorForTest(
+        *tm.make_batches(n_samples_per_batch, n_features, n_batches, device != "cpu"),
+        cache="cache",
+        on_host=False,
+    )
+    Xy_it = xgb.DMatrix(it)
+    params = {
+        "tree_method": tree_method,
+        "objective": "reg:quantileerror",
+        "device": device,
+        "quantile_alpha": [0.2, 0.8],
+    }
+    booster_it = xgb.train(params, Xy_it)
+    X, y, w = it.as_arrays()
+    Xy = xgb.DMatrix(X, y, weight=w)
+    booster = xgb.train(params, Xy)
+
+    predt_it = booster_it.predict(Xy_it)
+    predt = booster.predict(Xy)
+
+    np.testing.assert_allclose(predt, predt_it)
+
+
 def check_cut(
     n_entries: int, indptr: np.ndarray, data: np.ndarray, dtypes: Any
 ) -> None:
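
The helper above trains twice on the same batches, once through the iterator-backed DMatrix and once on the concatenated in-memory arrays from `it.as_arrays()`, then checks that the predictions match. With two entries in `quantile_alpha` the prediction has one column per quantile; a minimal in-core sketch of that output shape (synthetic data, not part of the commit):

    import numpy as np
    import xgboost as xgb

    rng = np.random.default_rng(0)
    X, y = rng.normal(size=(512, 4)), rng.normal(size=512)
    booster = xgb.train(
        {"objective": "reg:quantileerror", "quantile_alpha": [0.2, 0.8]},
        xgb.DMatrix(X, y),
        num_boost_round=8,
    )
    predt = booster.predict(xgb.DMatrix(X))
    assert predt.shape == (512, 2)  # one prediction column per requested quantile
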
@@ -41,6 +41,8 @@ constexpr StringView InconsistentMaxBin() {
          "and consistent with the Booster being trained.";
 }

+constexpr StringView InvalidMaxBin() { return "`max_bin` must be equal to or greater than 2."; }
+
 constexpr StringView UnknownDevice() { return "Unknown device type."; }

 inline void MaxFeatureSize(std::uint64_t n_features) {
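
`InvalidMaxBin` backs the `CHECK_GE(max_bin, 2)` validations added to the CPU and GPU sketch containers below. From Python, an out-of-range `max_bin` is expected to surface as an `XGBoostError` whose message mentions `max_bin`; a rough, unverified sketch of that behavior (synthetic data, assumed error path):

    import numpy as np
    import xgboost as xgb

    rng = np.random.default_rng(0)
    X, y = rng.normal(size=(64, 4)), rng.normal(size=64)
    try:
        # max_bin below the allowed minimum of 2; expected to be rejected.
        xgb.train({"tree_method": "hist", "max_bin": 1}, xgb.DMatrix(X, y), num_boost_round=1)
    except xgb.core.XGBoostError as err:
        assert "max_bin" in str(err)
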
@@ -367,23 +367,21 @@ class PartitionBuilder {
   // Copy row partitions into global cache for reuse in objective
   template <typename Invalidp>
   void LeafPartition(Context const* ctx, RegTree const& tree, RowSetCollection const& row_set,
-                     std::vector<bst_node_t>* p_position, Invalidp invalidp) const {
-    auto& h_pos = *p_position;
-    h_pos.resize(row_set.Data()->size(), std::numeric_limits<bst_node_t>::max());
-
+                     Span<bst_node_t> position, Invalidp invalidp) const {
     auto p_begin = row_set.Data()->data();
     // For each node, walk through all the samples that fall in this node.
-    ParallelFor(row_set.Size(), ctx->Threads(), [&](size_t i) {
+    auto p_pos = position.data();
+    ParallelFor(row_set.Size(), ctx->Threads(), [&](auto i) {
       auto const& node = row_set[i];
       if (node.node_id < 0) {
         return;
       }
       CHECK(tree.IsLeaf(node.node_id));
       if (node.begin()) {  // guard for empty node.
-        size_t ptr_offset = node.end() - p_begin;
+        std::size_t ptr_offset = node.end() - p_begin;
         CHECK_LE(ptr_offset, row_set.Data()->size()) << node.node_id;
         for (auto idx = node.begin(); idx != node.end(); ++idx) {
-          h_pos[*idx] = tree::SamplePosition::Encode(node.node_id, !invalidp(*idx));
+          p_pos[*idx] = tree::SamplePosition::Encode(node.node_id, !invalidp(*idx));
         }
       }
     });
@@ -8,6 +8,7 @@
 #include <utility>

 #include "../collective/aggregator.h"
+#include "../common/error_msg.h"  // for InvalidMaxBin
 #include "../data/adapter.h"
 #include "categorical.h"
 #include "hist_util.h"
@@ -16,15 +17,16 @@ namespace xgboost::common {
 template <typename WQSketch>
 SketchContainerImpl<WQSketch>::SketchContainerImpl(Context const *ctx,
                                                    std::vector<bst_idx_t> columns_size,
-                                                   int32_t max_bins,
+                                                   bst_bin_t max_bin,
                                                    Span<FeatureType const> feature_types,
                                                    bool use_group)
     : feature_types_(feature_types.cbegin(), feature_types.cend()),
       columns_size_{std::move(columns_size)},
-      max_bins_{max_bins},
+      max_bins_{max_bin},
       use_group_ind_{use_group},
       n_threads_{ctx->Threads()} {
   monitor_.Init(__func__);
+  CHECK_GE(max_bin, 2) << error::InvalidMaxBin();
   CHECK_NE(columns_size_.size(), 0);
   sketches_.resize(columns_size_.size());
   CHECK_GE(n_threads_, 1);
@@ -8,6 +8,7 @@

 #include "categorical.h"
 #include "device_helpers.cuh"
+#include "error_msg.h"  // for InvalidMaxBin
 #include "quantile.h"
 #include "timer.h"
 #include "xgboost/data.h"
@@ -96,7 +97,7 @@ class SketchContainer {
    * \param num_rows Total number of rows in known dataset (typically the rows in current worker).
    * \param device GPU ID.
    */
-  SketchContainer(HostDeviceVector<FeatureType> const& feature_types, int32_t max_bin,
+  SketchContainer(HostDeviceVector<FeatureType> const& feature_types, bst_bin_t max_bin,
                   bst_feature_t num_columns, bst_idx_t num_rows, DeviceOrd device)
       : num_rows_{num_rows}, num_columns_{num_columns}, num_bins_{max_bin}, device_{device} {
     CHECK(device.IsCUDA());
@@ -117,6 +118,7 @@ class SketchContainer {
     has_categorical_ =
         !d_feature_types.empty() &&
         thrust::any_of(dh::tbegin(d_feature_types), dh::tend(d_feature_types), common::IsCatOp{});
+    CHECK_GE(max_bin, 2) << error::InvalidMaxBin();

     timer_.Init(__func__);
   }
@@ -802,10 +802,10 @@ class SketchContainerImpl {
   /* \brief Initialize necessary info.
    *
    * \param columns_size Size of each column.
-   * \param max_bins maximum number of bins for each feature.
+   * \param max_bin maximum number of bins for each feature.
    * \param use_group whether is assigned to group to data instance.
    */
-  SketchContainerImpl(Context const *ctx, std::vector<bst_idx_t> columns_size, bst_bin_t max_bins,
+  SketchContainerImpl(Context const *ctx, std::vector<bst_idx_t> columns_size, bst_bin_t max_bin,
                       common::Span<FeatureType const> feature_types, bool use_group);

   static bool UseGroup(MetaInfo const &info) {
@@ -218,7 +218,7 @@ void GBTree::DoBoost(DMatrix* p_fmat, linalg::Matrix<GradientPair>* in_gpair,
                      model_.learner_model_param->OutputLength());
   CHECK_NE(n_groups, 0);

-  if (!p_fmat->SingleColBlock() && obj->Task().UpdateTreeLeaf()) {
+  if (!p_fmat->SingleColBlock() && obj->Task().UpdateTreeLeaf() && this->ctx_->IsCUDA()) {
     LOG(FATAL) << "Current objective doesn't support external memory.";
   }

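
This is the behavioral core of the commit: the external-memory guard now fires only when the booster runs on CUDA, so CPU training with leaf-updating objectives such as `reg:quantileerror` proceeds, while the GPU path still rejects it. A rough sketch of the expected Python-level behavior, reusing the hypothetical `SyntheticIter` from the first sketch; the exception type on the GPU path is not pinned down here, though the new GPU test further below expects a `ValueError` matching "external memory".

    import xgboost as xgb

    params = {
        "tree_method": "hist",
        "objective": "reg:quantileerror",
        "quantile_alpha": [0.2, 0.8],
    }
    Xy = xgb.DMatrix(SyntheticIter())  # external-memory DMatrix from the first sketch

    xgb.train({**params, "device": "cpu"}, Xy)  # now supported

    try:
        xgb.train({**params, "device": "cuda"}, Xy)  # still rejected on GPU
    except Exception as err:
        assert "external memory" in str(err)
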
@@ -301,34 +301,37 @@ class CommonRowPartitioner {
   auto const& operator[](bst_node_t nidx) const { return row_set_collection_[nidx]; }

   void LeafPartition(Context const* ctx, RegTree const& tree, common::Span<float const> hess,
-                     std::vector<bst_node_t>* p_out_position) const {
-    partition_builder_.LeafPartition(ctx, tree, this->Partitions(), p_out_position,
-                                     [&](size_t idx) -> bool { return hess[idx] - .0f == .0f; });
+                     common::Span<bst_node_t> out_position) const {
+    partition_builder_.LeafPartition(
+        ctx, tree, this->Partitions(), out_position,
+        [&](size_t idx) -> bool { return hess[idx - this->base_rowid] - .0f == .0f; });
   }

   void LeafPartition(Context const* ctx, RegTree const& tree,
                      linalg::TensorView<GradientPair const, 2> gpair,
-                     std::vector<bst_node_t>* p_out_position) const {
+                     common::Span<bst_node_t> out_position) const {
     if (gpair.Shape(1) > 1) {
       partition_builder_.LeafPartition(
-          ctx, tree, this->Partitions(), p_out_position, [&](std::size_t idx) -> bool {
-            auto sample = gpair.Slice(idx, linalg::All());
+          ctx, tree, this->Partitions(), out_position, [&](std::size_t idx) -> bool {
+            auto sample = gpair.Slice(idx - this->base_rowid, linalg::All());
             return std::all_of(linalg::cbegin(sample), linalg::cend(sample),
                                [](GradientPair const& g) { return g.GetHess() - .0f == .0f; });
           });
     } else {
       auto s = gpair.Slice(linalg::All(), 0);
-      partition_builder_.LeafPartition(
-          ctx, tree, this->Partitions(), p_out_position,
-          [&](std::size_t idx) -> bool { return s(idx).GetHess() - .0f == .0f; });
+      partition_builder_.LeafPartition(ctx, tree, this->Partitions(), out_position,
                                        [&](std::size_t idx) -> bool {
+                                         return s(idx - this->base_rowid).GetHess() - .0f == .0f;
+                                       });
     }
   }
   void LeafPartition(Context const* ctx, RegTree const& tree,
                      common::Span<GradientPair const> gpair,
-                     std::vector<bst_node_t>* p_out_position) const {
-    partition_builder_.LeafPartition(
-        ctx, tree, this->Partitions(), p_out_position,
-        [&](std::size_t idx) -> bool { return gpair[idx].GetHess() - .0f == .0f; });
+                     common::Span<bst_node_t> out_position) const {
+    partition_builder_.LeafPartition(ctx, tree, this->Partitions(), out_position,
                                      [&](std::size_t idx) -> bool {
+                                       return gpair[idx - this->base_rowid].GetHess() - .0f == .0f;
+                                     });
   }

  private:
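
Conceptually, the updaters now allocate a single leaf-position buffer for the whole dataset and let each page's partitioner fill in the rows it owns, with the page's `base_rowid` lining up page rows against global rows (the gradient and hessian lookups above are shifted by `base_rowid` for the same reason). A purely conceptual sketch of that bookkeeping in plain Python (not XGBoost API; the numbers are made up):

    import numpy as np

    n_rows_total = 10
    positions = np.full(n_rows_total, -1)  # one slot per row of the full dataset

    # Hypothetical per-page results: (base_rowid, leaf id assigned to each row of the page).
    pages = [(0, [1, 1, 2, 1]), (4, [2, 2, 1]), (7, [1, 2, 2])]

    for base_rowid, page_leaves in pages:
        for local_idx, leaf in enumerate(page_leaves):
            # A page knows its rows locally; base_rowid maps them onto global row indices,
            # so every partitioner writes into its own disjoint slice of the buffer.
            positions[base_rowid + local_idx] = leaf

    assert (positions >= 0).all()  # every row received a leaf position
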
@@ -154,8 +154,10 @@ class GlobalApproxBuilder {
     if (!task_->UpdateTreeLeaf()) {
       return;
     }
+    p_out_position->resize(hess.size());
     for (auto const &part : partitioner_) {
-      part.LeafPartition(ctx_, tree, hess, p_out_position);
+      part.LeafPartition(ctx_, tree, hess,
+                         common::Span{p_out_position->data(), p_out_position->size()});
     }
     monitor_->Stop(__func__);
   }
@@ -126,7 +126,7 @@ class MultiTargetHistBuilder {
   std::vector<CommonRowPartitioner> partitioner_;
   // Pointer to last updated tree, used for update prediction cache.
   RegTree const *p_last_tree_{nullptr};
-  DMatrix const * p_last_fmat_{nullptr};
+  DMatrix const *p_last_fmat_{nullptr};

   ObjInfo const *task_{nullptr};

@@ -254,8 +254,10 @@ class MultiTargetHistBuilder {
       monitor_->Stop(__func__);
       return;
     }
+    p_out_position->resize(gpair.Shape(0));
     for (auto const &part : partitioner_) {
-      part.LeafPartition(ctx_, tree, gpair, p_out_position);
+      part.LeafPartition(ctx_, tree, gpair,
+                         common::Span{p_out_position->data(), p_out_position->size()});
     }
     monitor_->Stop(__func__);
   }
@@ -461,8 +463,10 @@ class HistUpdater {
       monitor_->Stop(__func__);
       return;
     }
+    p_out_position->resize(gpair.Shape(0));
     for (auto const &part : partitioner_) {
-      part.LeafPartition(ctx_, tree, gpair, p_out_position);
+      part.LeafPartition(ctx_, tree, gpair,
+                         common::Span{p_out_position->data(), p_out_position->size()});
     }
     monitor_->Stop(__func__);
   }
@@ -521,7 +525,9 @@ class QuantileHistMaker : public TreeUpdater {

     linalg::Matrix<GradientPair> sample_out;
     auto h_sample_out = h_gpair;
-    auto need_copy = [&] { return trees.size() > 1 || n_targets > 1; };
+    auto need_copy = [&] {
+      return trees.size() > 1 || n_targets > 1;
+    };
     if (need_copy()) {
       // allocate buffer
       sample_out = decltype(sample_out){h_gpair.Shape(), ctx_->Device(), linalg::Order::kF};
@@ -1,5 +1,5 @@
 /**
- * Copyright 2022-2023 by XGBoost contributors.
+ * Copyright 2022-2024, XGBoost contributors.
  */
 #include <gtest/gtest.h>
 #include <xgboost/base.h>  // for bst_node_t
@@ -43,14 +43,15 @@ void TestLeafPartition(size_t n_samples) {

   std::vector<size_t> h_nptr;
   float split_value{0};
-  for (auto const& page : Xy->GetBatches<GHistIndexMatrix>(&ctx, BatchParam{64, 0.2})) {
+  bst_feature_t const split_ind = 0;
+
+  for (auto const& page : Xy->GetBatches<GHistIndexMatrix>(&ctx, BatchParam{64, 0.2})) {
     auto ptr = page.cut.Ptrs()[split_ind + 1];
     split_value = page.cut.Values().at(ptr / 2);
     GetSplit(&tree, split_value, &candidates);
     partitioner.UpdatePosition(&ctx, page, candidates, &tree);
-    std::vector<bst_node_t> position;
-    partitioner.LeafPartition(&ctx, tree, hess, &position);
+    std::vector<bst_node_t> position(page.Size());
+    partitioner.LeafPartition(&ctx, tree, hess, position);
     std::sort(position.begin(), position.end());
     size_t beg = std::distance(
         position.begin(),
@@ -76,13 +77,59 @@ void TestLeafPartition(size_t n_samples) {
     auto batch = page.GetView();
     size_t left{0};
     for (size_t i = 0; i < batch.Size(); ++i) {
-      if (not_sampled(i) && batch[i].front().fvalue < split_value) {
+      if (not_sampled(i) && batch[i][split_ind].fvalue < split_value) {
        left++;
      }
    }
    ASSERT_EQ(left, h_nptr[1] - h_nptr[0]);  // equal to number of sampled assigned to left
  }
 }
+
+void TestExternalMemory() {
+  Context ctx;
+  bst_bin_t max_bin = 32;
+  auto p_fmat =
+      RandomDataGenerator{256, 16, 0.0f}.Batches(4).GenerateSparsePageDMatrix("temp", true);
+  std::vector<CommonRowPartitioner> partitioners;
+
+  RegTree tree;
+  std::vector<CPUExpandEntry> candidates{{0, 0}};
+
+  auto gpair = GenerateRandomGradients(p_fmat->Info().num_row_);
+  auto t_gpair = linalg::MakeTensorView(&ctx, gpair.ConstHostSpan(), p_fmat->Info().num_row_, 1);
+  std::vector<bst_node_t> position(p_fmat->Info().num_row_);
+
+  auto param = BatchParam{max_bin, TrainParam::DftSparseThreshold()};
+  float split_value{0.0f};
+  bst_feature_t const split_ind = 0;
+  for (auto const& page : p_fmat->GetBatches<GHistIndexMatrix>(&ctx, param)) {
+    if (partitioners.empty()) {
+      auto ptr = page.cut.Ptrs()[split_ind + 1];
+      split_value = page.cut.Values().at(ptr / 2);
+      GetSplit(&tree, split_value, &candidates);
+    }
+
+    partitioners.emplace_back(&ctx, page.Size(), page.base_rowid, false);
+    partitioners.back().UpdatePosition(&ctx, page, candidates, &tree);
+    partitioners.back().LeafPartition(&ctx, tree, t_gpair, position);
+  }
+
+  bst_idx_t n_left{0};
+  for (auto const& page : p_fmat->GetBatches<SparsePage>()) {
+    auto batch = page.GetView();
+    for (size_t i = 0; i < batch.Size(); ++i) {
+      if (batch[i][split_ind].fvalue < split_value) {
+        n_left++;
+      }
+    }
+  }
+  auto n_left_pos = std::count_if(position.cbegin(), position.cend(),
+                                  [&](auto v) { return v == tree[RegTree::kRoot].LeftChild(); });
+  ASSERT_EQ(n_left, n_left_pos);
+  std::sort(position.begin(), position.end());
+  auto end_it = std::unique(position.begin(), position.end());
+  ASSERT_EQ(std::distance(position.begin(), end_it), 2);
+}
 }  // anonymous namespace

 TEST(CommonRowPartitioner, LeafPartition) {
@@ -90,4 +137,6 @@ TEST(CommonRowPartitioner, LeafPartition) {
     TestLeafPartition(n_samples);
   }
 }
+
+TEST(CommonRowPartitioner, LeafPartitionExternalMemory) { TestExternalMemory(); }
 }  // namespace xgboost::tree
@@ -4,6 +4,7 @@ import pytest
 from hypothesis import given, settings, strategies

 from xgboost.testing import no_cupy
+from xgboost.testing.updater import check_quantile_loss_extmem

 sys.path.append("tests/python")
 from test_data_iterator import run_data_iterator
@@ -56,3 +57,8 @@ def test_cpu_data_iterator() -> None:
         use_cupy=True,
         on_host=False,
     )
+
+
+def test_quantile_objective() -> None:
+    with pytest.raises(ValueError, match="external memory"):
+        check_quantile_loss_extmem(2, 2, 2, "hist", "cuda")
@@ -12,6 +12,7 @@ import xgboost as xgb
 from xgboost import testing as tm
 from xgboost.data import SingleBatchInternalIter as SingleBatch
 from xgboost.testing import IteratorForTest, make_batches, non_increasing
+from xgboost.testing.updater import check_quantile_loss_extmem

 pytestmark = tm.timeout(30)

@@ -276,3 +277,28 @@ def test_cat_check() -> None:
     Xy = xgb.DMatrix(it, enable_categorical=True)
     with pytest.raises(ValueError, match="categorical features"):
         xgb.train({"booster": "gblinear"}, Xy)
+
+
+@given(
+    strategies.integers(1, 64),
+    strategies.integers(1, 8),
+    strategies.integers(1, 4),
+)
+@settings(deadline=None, max_examples=10, print_blob=True)
+def test_quantile_objective(
+    n_samples_per_batch: int, n_features: int, n_batches: int
+) -> None:
+    check_quantile_loss_extmem(
+        n_samples_per_batch,
+        n_features,
+        n_batches,
+        "hist",
+        "cpu",
+    )
+    check_quantile_loss_extmem(
+        n_samples_per_batch,
+        n_features,
+        n_batches,
+        "approx",
+        "cpu",
+    )