Add approx partitioner. (#7467)
This commit is contained in:
76
tests/cpp/tree/test_approx.cc
Normal file
76
tests/cpp/tree/test_approx.cc
Normal file
@@ -0,0 +1,76 @@
|
||||
/*!
|
||||
* Copyright 2021 XGBoost contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "../../../src/tree/updater_approx.h"
|
||||
#include "../helpers.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
TEST(Approx, Partitioner) {
|
||||
size_t n_samples = 1024, n_features = 1, base_rowid = 0;
|
||||
ApproxRowPartitioner partitioner{n_samples, base_rowid};
|
||||
ASSERT_EQ(partitioner.base_rowid, base_rowid);
|
||||
ASSERT_EQ(partitioner.Size(), 1);
|
||||
ASSERT_EQ(partitioner.Partitions()[0].Size(), n_samples);
|
||||
|
||||
auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true);
|
||||
GenericParameter ctx;
|
||||
ctx.InitAllowUnknown(Args{});
|
||||
std::vector<CPUExpandEntry> candidates{{0, 0, 0.4}};
|
||||
|
||||
for (auto const &page : Xy->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, 64})) {
|
||||
bst_feature_t split_ind = 0;
|
||||
{
|
||||
auto min_value = page.cut.MinValues()[split_ind];
|
||||
RegTree tree;
|
||||
tree.ExpandNode(
|
||||
/*nid=*/0, /*split_index=*/0, /*split_value=*/min_value,
|
||||
/*default_left=*/true, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
|
||||
/*left_sum=*/0.0f,
|
||||
/*right_sum=*/0.0f);
|
||||
ApproxRowPartitioner partitioner{n_samples, base_rowid};
|
||||
candidates.front().split.split_value = min_value;
|
||||
candidates.front().split.sindex = 0;
|
||||
candidates.front().split.sindex |= (1U << 31);
|
||||
partitioner.UpdatePosition(&ctx, page, candidates, &tree);
|
||||
ASSERT_EQ(partitioner.Size(), 3);
|
||||
ASSERT_EQ(partitioner[1].Size(), 0);
|
||||
ASSERT_EQ(partitioner[2].Size(), n_samples);
|
||||
}
|
||||
{
|
||||
ApproxRowPartitioner partitioner{n_samples, base_rowid};
|
||||
auto ptr = page.cut.Ptrs()[split_ind + 1];
|
||||
float split_value = page.cut.Values().at(ptr / 2);
|
||||
RegTree tree;
|
||||
tree.ExpandNode(
|
||||
/*nid=*/RegTree::kRoot, /*split_index=*/split_ind,
|
||||
/*split_value=*/split_value,
|
||||
/*default_left=*/true, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
|
||||
/*left_sum=*/0.0f,
|
||||
/*right_sum=*/0.0f);
|
||||
auto left_nidx = tree[RegTree::kRoot].LeftChild();
|
||||
candidates.front().split.split_value = split_value;
|
||||
candidates.front().split.sindex = 0;
|
||||
candidates.front().split.sindex |= (1U << 31);
|
||||
partitioner.UpdatePosition(&ctx, page, candidates, &tree);
|
||||
|
||||
auto elem = partitioner[left_nidx];
|
||||
ASSERT_LT(elem.Size(), n_samples);
|
||||
ASSERT_GT(elem.Size(), 1);
|
||||
for (auto it = elem.begin; it != elem.end; ++it) {
|
||||
auto value = page.cut.Values().at(page.index[*it]);
|
||||
ASSERT_LE(value, split_value);
|
||||
}
|
||||
auto right_nidx = tree[RegTree::kRoot].RightChild();
|
||||
elem = partitioner[right_nidx];
|
||||
for (auto it = elem.begin; it != elem.end; ++it) {
|
||||
auto value = page.cut.Values().at(page.index[*it]);
|
||||
ASSERT_GT(value, split_value) << *it;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
Reference in New Issue
Block a user