[TREE] add interaction constraints (#3466)

* add interaction constraints

* enable both interaction and monotonic constraints at the same time

* fix lint

* add R test, fix lint, update demo

* Use dmlc::JSONReader to express interaction constraints as nested lists; Use sparse arrays for bookkeeping

* Add Python test for interaction constraints

* make R interaction constraints parameter based on feature index instead of column names, fix R coding style

* Fix lint

* Add BlueTea88 to CONTRIBUTORS.md

* Short circuit when no constraint is specified; address review comments

* Add tutorial for feature interaction constraints

* allow interaction constraints to be passed as string, remove redundant column_names argument

* Fix typo

* Address review comments

* Add comments to Python test
This commit is contained in:
Andrew Thia
2018-09-05 02:35:39 +10:00
committed by Philip Hyunsu Cho
parent dee0b69674
commit 9254c58e4d
12 changed files with 581 additions and 3 deletions

View File

@@ -194,7 +194,7 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
.describe("Number of rows in a GPU batch, used for finding quantiles on GPU; "
"-1 to use all rows assignted to a GPU, and 0 to auto-deduce");
DMLC_DECLARE_FIELD(split_evaluator)
.set_default("elastic_net,monotonic")
.set_default("elastic_net,monotonic,interaction")
.describe("The criteria to use for ranking splits");
// add alias of parameters
DMLC_DECLARE_ALIAS(reg_lambda, lambda);

View File

@@ -4,8 +4,11 @@
* \brief Contains implementations of different split evaluators.
*/
#include "split_evaluator.h"
#include <dmlc/json.h>
#include <dmlc/registry.h>
#include <algorithm>
#include <unordered_set>
#include <vector>
#include <limits>
#include <string>
#include <sstream>
@@ -303,5 +306,196 @@ XGBOOST_REGISTER_SPLIT_EVALUATOR(MonotonicConstraint, "monotonic")
return new MonotonicConstraint(std::move(inner));
});
/*! \brief Encapsulates the parameters required by the InteractionConstraint
 *  split evaluator
 */
struct InteractionConstraintParams
    : public dmlc::Parameter<InteractionConstraintParams> {
  /*! \brief Permitted feature interactions written as a nested list of feature
   *  indices, e.g. [[0, 1], [2, 3, 4]]. Defaults to the empty string, which
   *  the InteractionConstraint evaluator treats as "no constraint".
   */
  std::string interaction_constraints;
  /*! \brief Total number of features; no default value is declared for it. */
  bst_uint num_feature;

  DMLC_DECLARE_PARAMETER(InteractionConstraintParams) {
    // Adjacent string literals are concatenated verbatim, so every fragment
    // except the last must end with a space to keep the help text readable.
    DMLC_DECLARE_FIELD(interaction_constraints)
        .set_default("")
        .describe("Constraints for interaction representing permitted interactions. "
                  "The constraints must be specified in the form of a nested list, "
                  "e.g. [[0, 1], [2, 3, 4]], where each inner list is a group of "
                  "indices of features that are allowed to interact with each other. "
                  "See tutorial for more information");
    DMLC_DECLARE_FIELD(num_feature)
        .describe("Number of total features used");
  }
};
DMLC_REGISTER_PARAMETER(InteractionConstraintParams);
/*! \brief Enforces that the tree is monotonically increasing/decreasing with respect to a user specified set of
features.
*/
class InteractionConstraint final : public SplitEvaluator {
public:
explicit InteractionConstraint(std::unique_ptr<SplitEvaluator> inner) {
if (!inner) {
LOG(FATAL) << "InteractionConstraint must be given an inner evaluator";
}
inner_ = std::move(inner);
}
void Init(const std::vector<std::pair<std::string, std::string> >& args)
override {
inner_->Init(args);
params_.InitAllowUnknown(args);
Reset();
}
void Reset() override {
if (params_.interaction_constraints.empty()) {
return; // short-circuit if no constraint is specified
}
// Parse interaction constraints
std::istringstream iss(params_.interaction_constraints);
dmlc::JSONReader reader(&iss);
// Read std::vector<std::vector<bst_uint>> first and then
// convert to std::vector<std::unordered_set<bst_uint>>
std::vector<std::vector<bst_uint>> tmp;
reader.Read(&tmp);
for (const auto& e : tmp) {
interaction_constraints_.emplace_back(e.begin(), e.end());
}
// Initialise interaction constraints record with all variables permitted for the first node
int_cont_.clear();
int_cont_.resize(1, std::unordered_set<bst_uint>());
int_cont_[0].reserve(params_.num_feature);
for (bst_uint i = 0; i < params_.num_feature; ++i) {
int_cont_[0].insert(i);
}
// Initialise splits record
splits_.clear();
splits_.resize(1, std::unordered_set<bst_uint>());
}
SplitEvaluator* GetHostClone() const override {
if (params_.interaction_constraints.empty()) {
// No interaction constraints specified, just return a clone of inner
return inner_->GetHostClone();
} else {
auto c = new InteractionConstraint(
std::unique_ptr<SplitEvaluator>(inner_->GetHostClone()));
c->params_ = this->params_;
c->Reset();
return c;
}
}
bst_float ComputeSplitScore(bst_uint nodeid,
bst_uint featureid,
const GradStats& left_stats,
const GradStats& right_stats,
bst_float left_weight,
bst_float right_weight) const override {
// Return negative infinity score if feature is not permitted by interaction constraints
if (!CheckInteractionConstraint(featureid, nodeid)) {
return -std::numeric_limits<bst_float>::infinity();
}
// Otherwise, get score from inner evaluator
bst_float score = inner_->ComputeSplitScore(
nodeid, featureid, left_stats, right_stats, left_weight, right_weight);
return score;
}
bst_float ComputeScore(bst_uint parentID, const GradStats& stats, bst_float weight)
const override {
return inner_->ComputeScore(parentID, stats, weight);
}
bst_float ComputeWeight(bst_uint parentID, const GradStats& stats)
const override {
return inner_->ComputeWeight(parentID, stats);
}
void AddSplit(bst_uint nodeid,
bst_uint leftid,
bst_uint rightid,
bst_uint featureid,
bst_float leftweight,
bst_float rightweight) override {
inner_->AddSplit(nodeid, leftid, rightid, featureid, leftweight, rightweight);
if (params_.interaction_constraints.empty()) {
return; // short-circuit if no constraint is specified
}
bst_uint newsize = std::max(leftid, rightid) + 1;
// Record previous splits for child nodes
std::unordered_set<bst_uint> feature_splits = splits_[nodeid]; // fid history of current node
feature_splits.insert(featureid); // add feature of current node
splits_.resize(newsize);
splits_[leftid] = feature_splits;
splits_[rightid] = feature_splits;
// Resize constraints record, initialise all features to be not permitted for new nodes
int_cont_.resize(newsize, std::unordered_set<bst_uint>());
// Permit features used in previous splits
for (bst_uint fid : feature_splits) {
int_cont_[leftid].insert(fid);
int_cont_[rightid].insert(fid);
}
// Loop across specified interactions in constraints
for (const auto& constraint : interaction_constraints_) {
bst_uint flag = 1; // flags whether the specified interaction is still relevant
// Test relevance of specified interaction by checking all previous features are included
for (bst_uint checkvar : feature_splits) {
if (constraint.count(checkvar) == 0) {
flag = 0;
break; // interaction is not relevant due to unmet constraint
}
}
// If interaction is still relevant, permit all other features in the interaction
if (flag == 1) {
for (bst_uint k : constraint) {
int_cont_[leftid].insert(k);
int_cont_[rightid].insert(k);
}
}
}
}
private:
InteractionConstraintParams params_;
std::unique_ptr<SplitEvaluator> inner_;
// interaction_constraints_[constraint_id] contains a single interaction
// constraint, which specifies a group of feature IDs that can interact
// with each other
std::vector< std::unordered_set<bst_uint> > interaction_constraints_;
// int_cont_[nid] contains the set of all feature IDs that are allowed to
// be used for a split at node nid
std::vector< std::unordered_set<bst_uint> > int_cont_;
// splits_[nid] contains the set of all feature IDs that have been used for
// splits in node nid and its parents
std::vector< std::unordered_set<bst_uint> > splits_;
// Check interaction constraints. Returns true if a given feature ID is
// permissible in a given node; returns false otherwise
inline bool CheckInteractionConstraint(bst_uint featureid, bst_uint nodeid) const {
// short-circuit if no constraint is specified
return (params_.interaction_constraints.empty()
|| int_cont_[nodeid].count(featureid) > 0);
}
};
// Register the InteractionConstraint evaluator under the name "interaction".
// The registered factory takes ownership of the inner evaluator it is given
// and returns a new InteractionConstraint wrapping it.
// NOTE(review): presumably this name is what entries of the comma-separated
// `split_evaluator` training parameter are matched against -- confirm against
// the registry lookup code.
XGBOOST_REGISTER_SPLIT_EVALUATOR(InteractionConstraint, "interaction")
.describe("Enforces interaction constraints on tree features")
.set_body([](std::unique_ptr<SplitEvaluator> inner) {
return new InteractionConstraint(std::move(inner));
});
} // namespace tree
} // namespace xgboost