From 4c74336384e2e4e32434850d9f2e76c8ed923238 Mon Sep 17 00:00:00 2001 From: Xu Xiao Date: Wed, 1 May 2019 11:59:58 +0800 Subject: [PATCH] Use feature interaction constraints to narrow search space for split candidates (#4341) * Use feature interaction constraints to narrow search space for split candidates. * fix clang-tidy broken at updater_quantile_hist.cc:535:3 * make const * fix * try to fix exception thrown in java_test * fix suspected mistake which causes an EvaluateSplit error * try fix * Fix bug: feature ID and node ID swapped in argument * Rename CheckValidation() to CheckFeatureConstraint() for clarity * Do not create temporary vector validFeatures, to enable parallelism --- src/tree/split_evaluator.cc | 16 ++++++++++++++++ src/tree/split_evaluator.h | 5 +++++ src/tree/updater_quantile_hist.cc | 20 +++++++++++++------- 3 files changed, 34 insertions(+), 7 deletions(-) diff --git a/src/tree/split_evaluator.cc b/src/tree/split_evaluator.cc index c48415ec8..f12af9b19 100644 --- a/src/tree/split_evaluator.cc +++ b/src/tree/split_evaluator.cc @@ -61,6 +61,10 @@ bst_float SplitEvaluator::ComputeSplitScore(bst_uint nodeid, return ComputeSplitScore(nodeid, featureid, left_stats, right_stats, left_weight, right_weight); } +bool SplitEvaluator::CheckFeatureConstraint(bst_uint nodeid, bst_uint featureid) const { + return true; +} + //! 
\brief Encapsulates the parameters for ElasticNet struct ElasticNetParams : public dmlc::Parameter<ElasticNetParams> { bst_float reg_lambda; @@ -153,6 +157,10 @@ class ElasticNet final : public SplitEvaluator { return w; } + bool CheckFeatureConstraint(bst_uint nodeid, bst_uint featureid) const override { + return true; + } + private: ElasticNetParams params_; @@ -297,6 +305,10 @@ class MonotonicConstraint final : public SplitEvaluator { } } + bool CheckFeatureConstraint(bst_uint nodeid, bst_uint featureid) const override { + return true; + } + private: MonotonicConstraintParams params_; std::unique_ptr<SplitEvaluator> inner_; @@ -481,6 +493,10 @@ class InteractionConstraint final : public SplitEvaluator { } } + bool CheckFeatureConstraint(bst_uint nodeid, bst_uint featureid) const override { + return CheckInteractionConstraint(featureid, nodeid); + } + private: InteractionConstraintParams params_; std::unique_ptr<SplitEvaluator> inner_; diff --git a/src/tree/split_evaluator.h b/src/tree/split_evaluator.h index 142b71097..8d8d61db3 100644 --- a/src/tree/split_evaluator.h +++ b/src/tree/split_evaluator.h @@ -69,6 +69,11 @@ class SplitEvaluator { bst_uint featureid, bst_float leftweight, bst_float rightweight); + + // Check whether a given feature is feasible for a given node. 
+ // Use this function to narrow the search space for split candidates + virtual bool CheckFeatureConstraint(bst_uint nodeid, + bst_uint featureid) const = 0; }; struct SplitEvaluatorReg diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index c6381c95f..c9f569dd2 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -570,14 +570,20 @@ void QuantileHistMaker::Builder::EvaluateSplit(const int nid, best_split_tloc_[tid] = snode_[nid].best; } GHistRow node_hist = hist[nid]; + #pragma omp parallel for schedule(dynamic) num_threads(nthread) - for (bst_omp_uint i = 0; i < nfeature; ++i) { - const bst_uint fid = feature_set[i]; - const unsigned tid = omp_get_thread_num(); - this->EnumerateSplit(-1, gmat, node_hist, snode_[nid], info, - &best_split_tloc_[tid], fid, nid); - this->EnumerateSplit(+1, gmat, node_hist, snode_[nid], info, - &best_split_tloc_[tid], fid, nid); + for (bst_omp_uint i = 0; i < nfeature; ++i) { // NOLINT(*) + const auto feature_id = static_cast<bst_uint>(feature_set[i]); + const auto tid = static_cast<unsigned>(omp_get_thread_num()); + const auto node_id = static_cast<bst_uint>(nid); + // Narrow search space by dropping features that are not feasible under the + // given set of constraints (e.g. feature interaction constraints) + if (spliteval_->CheckFeatureConstraint(node_id, feature_id)) { + this->EnumerateSplit(-1, gmat, node_hist, snode_[nid], info, + &best_split_tloc_[tid], feature_id, node_id); + this->EnumerateSplit(+1, gmat, node_hist, snode_[nid], info, + &best_split_tloc_[tid], feature_id, node_id); + } } for (unsigned tid = 0; tid < nthread; ++tid) { snode_[nid].best.Update(best_split_tloc_[tid]);