diff --git a/src/tree/hist/evaluate_splits.h b/src/tree/hist/evaluate_splits.h index 0a09718ef..3c5299547 100644 --- a/src/tree/hist/evaluate_splits.h +++ b/src/tree/hist/evaluate_splits.h @@ -389,6 +389,7 @@ class HistEvaluator { tree_evaluator_.AddSplit(candidate.nid, left_child, right_child, tree[candidate.nid].SplitIndex(), left_weight, right_weight); + evaluator = tree_evaluator_.GetEvaluator(); auto max_node = std::max(left_child, tree[candidate.nid].RightChild()); max_node = std::max(candidate.nid, max_node); diff --git a/src/tree/split_evaluator.h b/src/tree/split_evaluator.h index d19755d37..b65c1d055 100644 --- a/src/tree/split_evaluator.h +++ b/src/tree/split_evaluator.h @@ -48,6 +48,8 @@ class TreeEvaluator { monotone_.HostVector().resize(n_features, 0); has_constraint_ = false; } else { + CHECK_LE(p.monotone_constraints.size(), n_features) + << "The size of monotone constraint should be less or equal to the number of features."; monotone_.HostVector() = p.monotone_constraints; monotone_.HostVector().resize(n_features, 0); // Initialised to some small size, can grow if needed diff --git a/tests/cpp/tree/test_constraints.cc b/tests/cpp/tree/test_constraints.cc index fa923a621..e6b36e797 100644 --- a/tests/cpp/tree/test_constraints.cc +++ b/tests/cpp/tree/test_constraints.cc @@ -6,6 +6,9 @@ #include #include "../../../src/tree/constraints.h" +#include "../../../src/tree/hist/evaluate_splits.h" +#include "../../../src/tree/hist/expand_entry.h" +#include "../helpers.h" namespace xgboost { namespace tree { @@ -56,5 +59,38 @@ TEST(CPUFeatureInteractionConstraint, Basic) { ASSERT_FALSE(constraints.Query(1, 5)); } +TEST(CPUMonoConstraint, Basic) { + std::size_t kRows{64}, kCols{16}; + Context ctx; + + TrainParam param; + std::vector mono(kCols, 1); + I32Array arr; + for (std::size_t i = 0; i < kCols; ++i) { + arr.GetArray().push_back(mono[i]); + } + Json jarr{std::move(arr)}; + std::string str_mono; + Json::Dump(jarr, &str_mono); + str_mono.front() = '('; + str_mono.back() = ')'; + + param.UpdateAllowUnknown(Args{{"monotone_constraints", str_mono}}); + + auto Xy = RandomDataGenerator{kRows, kCols, 0.0}.GenerateDMatrix(true); + auto sampler = std::make_shared(); + + HistEvaluator evalutor{param, Xy->Info(), ctx.Threads(), sampler}; + evalutor.InitRoot(GradStats{2.0, 2.0}); + + SplitEntry split; + split.Update(1.0f, 0, 3.0, false, false, GradStats{1.0, 1.0}, GradStats{1.0, 1.0}); + CPUExpandEntry entry{0, 0, split}; + RegTree tree; + tree.param.UpdateAllowUnknown(Args{{"num_feature", std::to_string(kCols)}}); + evalutor.ApplyTreeSplit(entry, &tree); + + ASSERT_TRUE(evalutor.Evaluator().has_constraint); +} } // namespace tree } // namespace xgboost