From 55968ed3fa64ae7a047bc24aee87adf296098738 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sat, 6 May 2023 01:07:54 +0800 Subject: [PATCH] Fix monotone constraints on CPU. (#9122) --- src/tree/hist/evaluate_splits.h | 1 + src/tree/split_evaluator.h | 2 ++ tests/cpp/tree/test_constraints.cc | 34 ++++++++++++++++++++++++++++++ 3 files changed, 37 insertions(+) diff --git a/src/tree/hist/evaluate_splits.h b/src/tree/hist/evaluate_splits.h index 8d13a48af..ec1ce769f 100644 --- a/src/tree/hist/evaluate_splits.h +++ b/src/tree/hist/evaluate_splits.h @@ -412,6 +412,7 @@ class HistEvaluator { tree_evaluator_.AddSplit(candidate.nid, left_child, right_child, tree[candidate.nid].SplitIndex(), left_weight, right_weight); + evaluator = tree_evaluator_.GetEvaluator(); snode_.resize(tree.GetNodes().size()); snode_.at(left_child).stats = candidate.split.left_sum; diff --git a/src/tree/split_evaluator.h b/src/tree/split_evaluator.h index c036cc3ed..a3b33e757 100644 --- a/src/tree/split_evaluator.h +++ b/src/tree/split_evaluator.h @@ -49,6 +49,8 @@ class TreeEvaluator { monotone_.HostVector().resize(n_features, 0); has_constraint_ = false; } else { + CHECK_LE(p.monotone_constraints.size(), n_features) + << "The size of monotone constraint should be less or equal to the number of features."; monotone_.HostVector() = p.monotone_constraints; monotone_.HostVector().resize(n_features, 0); // Initialised to some small size, can grow if needed diff --git a/tests/cpp/tree/test_constraints.cc b/tests/cpp/tree/test_constraints.cc index fa923a621..913dd8712 100644 --- a/tests/cpp/tree/test_constraints.cc +++ b/tests/cpp/tree/test_constraints.cc @@ -6,6 +6,8 @@ #include #include "../../../src/tree/constraints.h" +#include "../../../src/tree/hist/evaluate_splits.h" +#include "../helpers.h" namespace xgboost { namespace tree { @@ -56,5 +58,37 @@ TEST(CPUFeatureInteractionConstraint, Basic) { ASSERT_FALSE(constraints.Query(1, 5)); } +TEST(CPUMonoConstraint, Basic) { + std::size_t kRows{64}, kCols{16}; + Context ctx; + + TrainParam param; + std::vector mono(kCols, 1); + I32Array arr; + for (std::size_t i = 0; i < kCols; ++i) { + arr.GetArray().push_back(mono[i]); + } + Json jarr{std::move(arr)}; + std::string str_mono; + Json::Dump(jarr, &str_mono); + str_mono.front() = '('; + str_mono.back() = ')'; + + param.UpdateAllowUnknown(Args{{"monotone_constraints", str_mono}}); + + auto Xy = RandomDataGenerator{kRows, kCols, 0.0}.GenerateDMatrix(true); + auto sampler = std::make_shared(); + + HistEvaluator evalutor{&ctx, ¶m, Xy->Info(), sampler}; + evalutor.InitRoot(GradStats{2.0, 2.0}); + + SplitEntry split; + split.Update(1.0f, 0, 3.0, false, false, GradStats{1.0, 1.0}, GradStats{1.0, 1.0}); + CPUExpandEntry entry{0, 0, split}; + RegTree tree{1, static_cast(kCols)}; + evalutor.ApplyTreeSplit(entry, &tree); + + ASSERT_TRUE(evalutor.Evaluator().has_constraint); +} } // namespace tree } // namespace xgboost