Merge approx tests. (#10583)
This commit is contained in:
parent
5a92ffe3ca
commit
a6a8a55ffa
@ -4,10 +4,12 @@
|
|||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
#include "../../../src/tree/common_row_partitioner.h"
|
#include "../../../src/tree/common_row_partitioner.h"
|
||||||
|
#include "../../../src/tree/param.h" // for TrainParam
|
||||||
#include "../collective/test_worker.h" // for TestDistributedGlobal
|
#include "../collective/test_worker.h" // for TestDistributedGlobal
|
||||||
#include "../helpers.h"
|
#include "../helpers.h"
|
||||||
#include "test_column_split.h" // for TestColumnSplit
|
#include "test_column_split.h" // for TestColumnSplit
|
||||||
#include "test_partitioner.h"
|
#include "test_partitioner.h"
|
||||||
|
#include "xgboost/tree_model.h" // for RegTree
|
||||||
|
|
||||||
namespace xgboost::tree {
|
namespace xgboost::tree {
|
||||||
namespace {
|
namespace {
|
||||||
@ -76,6 +78,53 @@ TEST(Approx, Partitioner) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(Approx, InteractionConstraint) {
|
||||||
|
auto constexpr kRows = 32;
|
||||||
|
auto constexpr kCols = 16;
|
||||||
|
auto p_dmat = GenerateCatDMatrix(kRows, kCols, 0.6f, false);
|
||||||
|
Context ctx;
|
||||||
|
|
||||||
|
linalg::Matrix<GradientPair> gpair({kRows}, ctx.Device());
|
||||||
|
gpair.Data()->Copy(GenerateRandomGradients(kRows));
|
||||||
|
|
||||||
|
ObjInfo task{ObjInfo::kRegression};
|
||||||
|
{
|
||||||
|
// With constraints
|
||||||
|
RegTree tree{1, kCols};
|
||||||
|
|
||||||
|
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)};
|
||||||
|
TrainParam param;
|
||||||
|
param.UpdateAllowUnknown(
|
||||||
|
Args{{"interaction_constraints", "[[0, 1]]"}, {"num_feature", std::to_string(kCols)}});
|
||||||
|
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||||
|
updater->Configure(Args{});
|
||||||
|
updater->Update(¶m, &gpair, p_dmat.get(), position, {&tree});
|
||||||
|
|
||||||
|
ASSERT_EQ(tree.NumExtraNodes(), 4);
|
||||||
|
ASSERT_EQ(tree[0].SplitIndex(), 1);
|
||||||
|
|
||||||
|
ASSERT_EQ(tree[tree[0].LeftChild()].SplitIndex(), 0);
|
||||||
|
ASSERT_EQ(tree[tree[0].RightChild()].SplitIndex(), 0);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// Without constraints
|
||||||
|
RegTree tree{1u, kCols};
|
||||||
|
|
||||||
|
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)};
|
||||||
|
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||||
|
TrainParam param;
|
||||||
|
param.Init(Args{});
|
||||||
|
updater->Configure(Args{});
|
||||||
|
updater->Update(¶m, &gpair, p_dmat.get(), position, {&tree});
|
||||||
|
|
||||||
|
ASSERT_EQ(tree.NumExtraNodes(), 10);
|
||||||
|
ASSERT_EQ(tree[0].SplitIndex(), 1);
|
||||||
|
|
||||||
|
ASSERT_NE(tree[tree[0].LeftChild()].SplitIndex(), 0);
|
||||||
|
ASSERT_NE(tree[tree[0].RightChild()].SplitIndex(), 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
void TestColumnSplitPartitioner(size_t n_samples, size_t base_rowid, std::shared_ptr<DMatrix> Xy,
|
void TestColumnSplitPartitioner(size_t n_samples, size_t base_rowid, std::shared_ptr<DMatrix> Xy,
|
||||||
std::vector<float>* hess, float min_value, float mid_value,
|
std::vector<float>* hess, float min_value, float mid_value,
|
||||||
|
|||||||
@ -23,9 +23,13 @@ inline std::shared_ptr<DMatrix> GenerateCatDMatrix(std::size_t rows, std::size_t
|
|||||||
for (size_t i = 0; i < ft.size(); ++i) {
|
for (size_t i = 0; i < ft.size(); ++i) {
|
||||||
ft[i] = (i % 3 == 0) ? FeatureType::kNumerical : FeatureType::kCategorical;
|
ft[i] = (i % 3 == 0) ? FeatureType::kNumerical : FeatureType::kCategorical;
|
||||||
}
|
}
|
||||||
return RandomDataGenerator(rows, cols, 0.6f).Seed(3).Type(ft).MaxCategory(17).GenerateDMatrix();
|
return RandomDataGenerator(rows, cols, sparsity)
|
||||||
|
.Seed(3)
|
||||||
|
.Type(ft)
|
||||||
|
.MaxCategory(17)
|
||||||
|
.GenerateDMatrix();
|
||||||
} else {
|
} else {
|
||||||
return RandomDataGenerator{rows, cols, 0.6f}.Seed(3).GenerateDMatrix();
|
return RandomDataGenerator{rows, cols, sparsity}.Seed(3).GenerateDMatrix();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,59 +0,0 @@
|
|||||||
/**
|
|
||||||
* Copyright 2019-2024, XGBoost Contributors
|
|
||||||
*/
|
|
||||||
#include <gtest/gtest.h>
|
|
||||||
#include <xgboost/tree_model.h>
|
|
||||||
#include <xgboost/tree_updater.h>
|
|
||||||
|
|
||||||
#include "../../../src/tree/param.h" // for TrainParam
|
|
||||||
#include "../helpers.h"
|
|
||||||
#include "test_column_split.h" // for GenerateCatDMatrix
|
|
||||||
|
|
||||||
namespace xgboost::tree {
|
|
||||||
TEST(GrowHistMaker, InteractionConstraint) {
|
|
||||||
auto constexpr kRows = 32;
|
|
||||||
auto constexpr kCols = 16;
|
|
||||||
auto p_dmat = GenerateCatDMatrix(kRows, kCols, 0.0, false);
|
|
||||||
Context ctx;
|
|
||||||
|
|
||||||
linalg::Matrix<GradientPair> gpair({kRows}, ctx.Device());
|
|
||||||
gpair.Data()->Copy(GenerateRandomGradients(kRows));
|
|
||||||
|
|
||||||
ObjInfo task{ObjInfo::kRegression};
|
|
||||||
{
|
|
||||||
// With constraints
|
|
||||||
RegTree tree{1, kCols};
|
|
||||||
|
|
||||||
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)};
|
|
||||||
TrainParam param;
|
|
||||||
param.UpdateAllowUnknown(
|
|
||||||
Args{{"interaction_constraints", "[[0, 1]]"}, {"num_feature", std::to_string(kCols)}});
|
|
||||||
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
|
||||||
updater->Configure(Args{});
|
|
||||||
updater->Update(¶m, &gpair, p_dmat.get(), position, {&tree});
|
|
||||||
|
|
||||||
ASSERT_EQ(tree.NumExtraNodes(), 4);
|
|
||||||
ASSERT_EQ(tree[0].SplitIndex(), 1);
|
|
||||||
|
|
||||||
ASSERT_EQ(tree[tree[0].LeftChild()].SplitIndex(), 0);
|
|
||||||
ASSERT_EQ(tree[tree[0].RightChild()].SplitIndex(), 0);
|
|
||||||
}
|
|
||||||
{
|
|
||||||
// Without constraints
|
|
||||||
RegTree tree{1u, kCols};
|
|
||||||
|
|
||||||
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)};
|
|
||||||
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
|
||||||
TrainParam param;
|
|
||||||
param.Init(Args{});
|
|
||||||
updater->Configure(Args{});
|
|
||||||
updater->Update(¶m, &gpair, p_dmat.get(), position, {&tree});
|
|
||||||
|
|
||||||
ASSERT_EQ(tree.NumExtraNodes(), 10);
|
|
||||||
ASSERT_EQ(tree[0].SplitIndex(), 1);
|
|
||||||
|
|
||||||
ASSERT_NE(tree[tree[0].LeftChild()].SplitIndex(), 0);
|
|
||||||
ASSERT_NE(tree[tree[0].RightChild()].SplitIndex(), 0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} // namespace xgboost::tree
|
|
||||||
Loading…
x
Reference in New Issue
Block a user