[TREE] add interaction constraints (#3466)
* add interaction constraints * enable both interaction and monotonic constraints at the same time * fix lint * add R test, fix lint, update demo * Use dmlc::JSONReader to express interaction constraints as nested lists; Use sparse arrays for bookkeeping * Add Python test for interaction constraints * make R interaction constraints parameter based on feature index instead of column names, fix R coding style * Fix lint * Add BlueTea88 to CONTRIBUTORS.md * Short circuit when no constraint is specified; address review comments * Add tutorial for feature interaction constraints * allow interaction constraints to be passed as string, remove redundant column_names argument * Fix typo * Address review comments * Add comments to Python test
This commit is contained in:
committed by
Philip Hyunsu Cho
parent
dee0b69674
commit
9254c58e4d
@@ -194,7 +194,7 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
|
||||
.describe("Number of rows in a GPU batch, used for finding quantiles on GPU; "
|
||||
"-1 to use all rows assignted to a GPU, and 0 to auto-deduce");
|
||||
DMLC_DECLARE_FIELD(split_evaluator)
|
||||
.set_default("elastic_net,monotonic")
|
||||
.set_default("elastic_net,monotonic,interaction")
|
||||
.describe("The criteria to use for ranking splits");
|
||||
// add alias of parameters
|
||||
DMLC_DECLARE_ALIAS(reg_lambda, lambda);
|
||||
|
||||
@@ -4,8 +4,11 @@
|
||||
* \brief Contains implementations of different split evaluators.
|
||||
*/
|
||||
#include "split_evaluator.h"
|
||||
#include <dmlc/json.h>
|
||||
#include <dmlc/registry.h>
|
||||
#include <algorithm>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
#include <limits>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
@@ -303,5 +306,196 @@ XGBOOST_REGISTER_SPLIT_EVALUATOR(MonotonicConstraint, "monotonic")
|
||||
return new MonotonicConstraint(std::move(inner));
|
||||
});
|
||||
|
||||
/*! \brief Encapsulates the parameters required by the InteractionConstraint
|
||||
split evaluator
|
||||
*/
|
||||
struct InteractionConstraintParams
|
||||
: public dmlc::Parameter<InteractionConstraintParams> {
|
||||
std::string interaction_constraints;
|
||||
bst_uint num_feature;
|
||||
|
||||
DMLC_DECLARE_PARAMETER(InteractionConstraintParams) {
|
||||
DMLC_DECLARE_FIELD(interaction_constraints)
|
||||
.set_default("")
|
||||
.describe("Constraints for interaction representing permitted interactions."
|
||||
"The constraints must be specified in the form of a nest list,"
|
||||
"e.g. [[0, 1], [2, 3, 4]], where each inner list is a group of"
|
||||
"indices of features that are allowed to interact with each other."
|
||||
"See tutorial for more information");
|
||||
DMLC_DECLARE_FIELD(num_feature)
|
||||
.describe("Number of total features used");
|
||||
}
|
||||
};
|
||||
|
||||
DMLC_REGISTER_PARAMETER(InteractionConstraintParams);
|
||||
|
||||
/*! \brief Enforces that the tree is monotonically increasing/decreasing with respect to a user specified set of
|
||||
features.
|
||||
*/
|
||||
class InteractionConstraint final : public SplitEvaluator {
|
||||
public:
|
||||
explicit InteractionConstraint(std::unique_ptr<SplitEvaluator> inner) {
|
||||
if (!inner) {
|
||||
LOG(FATAL) << "InteractionConstraint must be given an inner evaluator";
|
||||
}
|
||||
inner_ = std::move(inner);
|
||||
}
|
||||
|
||||
void Init(const std::vector<std::pair<std::string, std::string> >& args)
|
||||
override {
|
||||
inner_->Init(args);
|
||||
params_.InitAllowUnknown(args);
|
||||
Reset();
|
||||
}
|
||||
|
||||
void Reset() override {
|
||||
if (params_.interaction_constraints.empty()) {
|
||||
return; // short-circuit if no constraint is specified
|
||||
}
|
||||
|
||||
// Parse interaction constraints
|
||||
std::istringstream iss(params_.interaction_constraints);
|
||||
dmlc::JSONReader reader(&iss);
|
||||
// Read std::vector<std::vector<bst_uint>> first and then
|
||||
// convert to std::vector<std::unordered_set<bst_uint>>
|
||||
std::vector<std::vector<bst_uint>> tmp;
|
||||
reader.Read(&tmp);
|
||||
for (const auto& e : tmp) {
|
||||
interaction_constraints_.emplace_back(e.begin(), e.end());
|
||||
}
|
||||
|
||||
// Initialise interaction constraints record with all variables permitted for the first node
|
||||
int_cont_.clear();
|
||||
int_cont_.resize(1, std::unordered_set<bst_uint>());
|
||||
int_cont_[0].reserve(params_.num_feature);
|
||||
for (bst_uint i = 0; i < params_.num_feature; ++i) {
|
||||
int_cont_[0].insert(i);
|
||||
}
|
||||
|
||||
// Initialise splits record
|
||||
splits_.clear();
|
||||
splits_.resize(1, std::unordered_set<bst_uint>());
|
||||
}
|
||||
|
||||
SplitEvaluator* GetHostClone() const override {
|
||||
if (params_.interaction_constraints.empty()) {
|
||||
// No interaction constraints specified, just return a clone of inner
|
||||
return inner_->GetHostClone();
|
||||
} else {
|
||||
auto c = new InteractionConstraint(
|
||||
std::unique_ptr<SplitEvaluator>(inner_->GetHostClone()));
|
||||
c->params_ = this->params_;
|
||||
c->Reset();
|
||||
return c;
|
||||
}
|
||||
}
|
||||
|
||||
bst_float ComputeSplitScore(bst_uint nodeid,
|
||||
bst_uint featureid,
|
||||
const GradStats& left_stats,
|
||||
const GradStats& right_stats,
|
||||
bst_float left_weight,
|
||||
bst_float right_weight) const override {
|
||||
// Return negative infinity score if feature is not permitted by interaction constraints
|
||||
if (!CheckInteractionConstraint(featureid, nodeid)) {
|
||||
return -std::numeric_limits<bst_float>::infinity();
|
||||
}
|
||||
|
||||
// Otherwise, get score from inner evaluator
|
||||
bst_float score = inner_->ComputeSplitScore(
|
||||
nodeid, featureid, left_stats, right_stats, left_weight, right_weight);
|
||||
return score;
|
||||
}
|
||||
|
||||
bst_float ComputeScore(bst_uint parentID, const GradStats& stats, bst_float weight)
|
||||
const override {
|
||||
return inner_->ComputeScore(parentID, stats, weight);
|
||||
}
|
||||
|
||||
bst_float ComputeWeight(bst_uint parentID, const GradStats& stats)
|
||||
const override {
|
||||
return inner_->ComputeWeight(parentID, stats);
|
||||
}
|
||||
|
||||
void AddSplit(bst_uint nodeid,
|
||||
bst_uint leftid,
|
||||
bst_uint rightid,
|
||||
bst_uint featureid,
|
||||
bst_float leftweight,
|
||||
bst_float rightweight) override {
|
||||
inner_->AddSplit(nodeid, leftid, rightid, featureid, leftweight, rightweight);
|
||||
|
||||
if (params_.interaction_constraints.empty()) {
|
||||
return; // short-circuit if no constraint is specified
|
||||
}
|
||||
bst_uint newsize = std::max(leftid, rightid) + 1;
|
||||
|
||||
// Record previous splits for child nodes
|
||||
std::unordered_set<bst_uint> feature_splits = splits_[nodeid]; // fid history of current node
|
||||
feature_splits.insert(featureid); // add feature of current node
|
||||
splits_.resize(newsize);
|
||||
splits_[leftid] = feature_splits;
|
||||
splits_[rightid] = feature_splits;
|
||||
|
||||
// Resize constraints record, initialise all features to be not permitted for new nodes
|
||||
int_cont_.resize(newsize, std::unordered_set<bst_uint>());
|
||||
|
||||
// Permit features used in previous splits
|
||||
for (bst_uint fid : feature_splits) {
|
||||
int_cont_[leftid].insert(fid);
|
||||
int_cont_[rightid].insert(fid);
|
||||
}
|
||||
|
||||
// Loop across specified interactions in constraints
|
||||
for (const auto& constraint : interaction_constraints_) {
|
||||
bst_uint flag = 1; // flags whether the specified interaction is still relevant
|
||||
|
||||
// Test relevance of specified interaction by checking all previous features are included
|
||||
for (bst_uint checkvar : feature_splits) {
|
||||
if (constraint.count(checkvar) == 0) {
|
||||
flag = 0;
|
||||
break; // interaction is not relevant due to unmet constraint
|
||||
}
|
||||
}
|
||||
|
||||
// If interaction is still relevant, permit all other features in the interaction
|
||||
if (flag == 1) {
|
||||
for (bst_uint k : constraint) {
|
||||
int_cont_[leftid].insert(k);
|
||||
int_cont_[rightid].insert(k);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
InteractionConstraintParams params_;
|
||||
std::unique_ptr<SplitEvaluator> inner_;
|
||||
// interaction_constraints_[constraint_id] contains a single interaction
|
||||
// constraint, which specifies a group of feature IDs that can interact
|
||||
// with each other
|
||||
std::vector< std::unordered_set<bst_uint> > interaction_constraints_;
|
||||
// int_cont_[nid] contains the set of all feature IDs that are allowed to
|
||||
// be used for a split at node nid
|
||||
std::vector< std::unordered_set<bst_uint> > int_cont_;
|
||||
// splits_[nid] contains the set of all feature IDs that have been used for
|
||||
// splits in node nid and its parents
|
||||
std::vector< std::unordered_set<bst_uint> > splits_;
|
||||
|
||||
// Check interaction constraints. Returns true if a given feature ID is
|
||||
// permissible in a given node; returns false otherwise
|
||||
inline bool CheckInteractionConstraint(bst_uint featureid, bst_uint nodeid) const {
|
||||
// short-circuit if no constraint is specified
|
||||
return (params_.interaction_constraints.empty()
|
||||
|| int_cont_[nodeid].count(featureid) > 0);
|
||||
}
|
||||
};
|
||||
|
||||
// Register the evaluator under the name "interaction"; the factory wraps the
// evaluator it is handed in a newly allocated InteractionConstraint.
XGBOOST_REGISTER_SPLIT_EVALUATOR(InteractionConstraint, "interaction")
.describe("Enforces interaction constraints on tree features")
.set_body([](std::unique_ptr<SplitEvaluator> wrapped) {
    auto* evaluator = new InteractionConstraint(std::move(wrapped));
    return evaluator;
  });
|
||||
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user