diff --git a/src/tree/split_evaluator.cc b/src/tree/split_evaluator.cc index c48415ec8..f12af9b19 100644 --- a/src/tree/split_evaluator.cc +++ b/src/tree/split_evaluator.cc @@ -61,6 +61,10 @@ bst_float SplitEvaluator::ComputeSplitScore(bst_uint nodeid, return ComputeSplitScore(nodeid, featureid, left_stats, right_stats, left_weight, right_weight); } +bool SplitEvaluator::CheckFeatureConstraint(bst_uint nodeid, bst_uint featureid) const { + return true; +} + //! \brief Encapsulates the parameters for ElasticNet struct ElasticNetParams : public dmlc::Parameter { bst_float reg_lambda; @@ -153,6 +157,10 @@ class ElasticNet final : public SplitEvaluator { return w; } + bool CheckFeatureConstraint(bst_uint nodeid, bst_uint featureid) const override { + return true; + } + private: ElasticNetParams params_; @@ -297,6 +305,10 @@ class MonotonicConstraint final : public SplitEvaluator { } } + bool CheckFeatureConstraint(bst_uint nodeid, bst_uint featureid) const override { + return true; + } + private: MonotonicConstraintParams params_; std::unique_ptr inner_; @@ -481,6 +493,10 @@ class InteractionConstraint final : public SplitEvaluator { } } + bool CheckFeatureConstraint(bst_uint nodeid, bst_uint featureid) const override { + return CheckInteractionConstraint(featureid, nodeid); + } + private: InteractionConstraintParams params_; std::unique_ptr inner_; diff --git a/src/tree/split_evaluator.h b/src/tree/split_evaluator.h index 142b71097..8d8d61db3 100644 --- a/src/tree/split_evaluator.h +++ b/src/tree/split_evaluator.h @@ -69,6 +69,11 @@ class SplitEvaluator { bst_uint featureid, bst_float leftweight, bst_float rightweight); + + // Check whether a given feature is feasible for a given node. + // Use this function to narrow the search space for split candidates + virtual bool CheckFeatureConstraint(bst_uint nodeid, + bst_uint featureid) const = 0; }; struct SplitEvaluatorReg diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index c6381c95f..c9f569dd2 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -570,14 +570,20 @@ void QuantileHistMaker::Builder::EvaluateSplit(const int nid, best_split_tloc_[tid] = snode_[nid].best; } GHistRow node_hist = hist[nid]; + #pragma omp parallel for schedule(dynamic) num_threads(nthread) - for (bst_omp_uint i = 0; i < nfeature; ++i) { - const bst_uint fid = feature_set[i]; - const unsigned tid = omp_get_thread_num(); - this->EnumerateSplit(-1, gmat, node_hist, snode_[nid], info, - &best_split_tloc_[tid], fid, nid); - this->EnumerateSplit(+1, gmat, node_hist, snode_[nid], info, - &best_split_tloc_[tid], fid, nid); + for (bst_omp_uint i = 0; i < nfeature; ++i) { // NOLINT(*) + const auto feature_id = static_cast(feature_set[i]); + const auto tid = static_cast(omp_get_thread_num()); + const auto node_id = static_cast(nid); + // Narrow search space by dropping features that are not feasible under the + // given set of constraints (e.g. feature interaction constraints) + if (spliteval_->CheckFeatureConstraint(node_id, feature_id)) { + this->EnumerateSplit(-1, gmat, node_hist, snode_[nid], info, + &best_split_tloc_[tid], feature_id, node_id); + this->EnumerateSplit(+1, gmat, node_hist, snode_[nid], info, + &best_split_tloc_[tid], feature_id, node_id); + } } for (unsigned tid = 0; tid < nthread; ++tid) { snode_[nid].best.Update(best_split_tloc_[tid]);